|
@@ -33,6 +33,8 @@ import matplotlib as mpl
|
|
|
from skimage import io
|
|
|
import os.path as osp
|
|
|
from torchvision.utils import draw_bounding_boxes
|
|
|
+from torchvision import transforms
|
|
|
+from models.wirenet.postprocess import postprocess
|
|
|
|
|
|
FEATURE_DIM = 8
|
|
|
|
|
@@ -583,32 +585,91 @@ def imshow(im):
|
|
|
plt.colorbar(sm, fraction=0.046)
|
|
|
plt.xlim([0, im.shape[0]])
|
|
|
plt.ylim([im.shape[0], 0])
|
|
|
- plt.show()
|
|
|
+ # plt.show()
|
|
|
+
|
|
|
+
|
|
|
+# def _plot_samples(img, i, result, prefix, epoch):
|
|
|
+# print(f"prefix:{prefix}")
|
|
|
+# def draw_vecl(lines, sline, juncs, junts, fn):
|
|
|
+# directory = os.path.dirname(fn)
|
|
|
+# if not os.path.exists(directory):
|
|
|
+# os.makedirs(directory)
|
|
|
+# imshow(img.permute(1, 2, 0))
|
|
|
+# if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
+# for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
+# if i > 0 and (lines[i] == lines[0]).all():
|
|
|
+# break
|
|
|
+# plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
|
|
|
+# if not (juncs[0] == 0).all():
|
|
|
+# for i, j in enumerate(juncs):
|
|
|
+# if i > 0 and (i == juncs[0]).all():
|
|
|
+# break
|
|
|
+# plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
|
|
|
+# if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
|
|
|
+# for i, j in enumerate(junts):
|
|
|
+# if i > 0 and (i == junts[0]).all():
|
|
|
+# break
|
|
|
+# plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
|
|
|
+# plt.savefig(fn), plt.close()
|
|
|
+#
|
|
|
+# rjuncs = result["juncs"][i].cpu().numpy() * 4
|
|
|
+# rjunts = None
|
|
|
+# if "junts" in result:
|
|
|
+# rjunts = result["junts"][i].cpu().numpy() * 4
|
|
|
+#
|
|
|
+# vecl_result = result["lines"][i].cpu().numpy() * 4
|
|
|
+# score = result["score"][i].cpu().numpy()
|
|
|
+#
|
|
|
+# draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
|
|
|
+#
|
|
|
+# img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
|
|
|
+# writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch)
|
|
|
|
|
|
+def _plot_samples(img, i, result, prefix, epoch, writer):
|
|
|
+ # print(f"prefix:{prefix}")
|
|
|
|
|
|
-def _plot_samples(img, i, result, prefix, epoch):
|
|
|
- print(f"prefix:{prefix}")
|
|
|
def draw_vecl(lines, sline, juncs, junts, fn):
|
|
|
- if not os.path.exists(fn):
|
|
|
- os.makedirs(fn)
|
|
|
- imshow(img.permute(1, 2, 0))
|
|
|
+ # 确保目录存在
|
|
|
+ directory = os.path.dirname(fn)
|
|
|
+ if not os.path.exists(directory):
|
|
|
+ os.makedirs(directory)
|
|
|
+
|
|
|
+ # 绘制图像
|
|
|
+ plt.figure()
|
|
|
+ plt.imshow(img.permute(1, 2, 0).cpu().numpy())
|
|
|
+ plt.axis('off') # 可选:关闭坐标轴
|
|
|
+
|
|
|
if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
- for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
- if i > 0 and (lines[i] == lines[0]).all():
|
|
|
+ for idx, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
+ if idx > 0 and (lines[idx] == lines[0]).all():
|
|
|
break
|
|
|
- plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
|
|
|
+ plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=1)
|
|
|
+
|
|
|
if not (juncs[0] == 0).all():
|
|
|
- for i, j in enumerate(juncs):
|
|
|
- if i > 0 and (i == juncs[0]).all():
|
|
|
+ for idx, j in enumerate(juncs):
|
|
|
+ if idx > 0 and (j == juncs[0]).all():
|
|
|
break
|
|
|
- plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
|
|
|
+ plt.scatter(j[1], j[0], c="red", s=20, zorder=100)
|
|
|
+
|
|
|
if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
|
|
|
- for i, j in enumerate(junts):
|
|
|
- if i > 0 and (i == junts[0]).all():
|
|
|
+ for idx, j in enumerate(junts):
|
|
|
+ if idx > 0 and (j == junts[0]).all():
|
|
|
break
|
|
|
- plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
|
|
|
- plt.savefig(fn), plt.close()
|
|
|
+ plt.scatter(j[1], j[0], c="blue", s=20, zorder=100)
|
|
|
+
|
|
|
+ # plt.show()
|
|
|
+
|
|
|
+ # 将matplotlib图像转换为numpy数组
|
|
|
+ plt.tight_layout()
|
|
|
+ fig = plt.gcf()
|
|
|
+ fig.canvas.draw()
|
|
|
+ image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
|
|
|
+ fig.canvas.get_width_height()[::-1] + (3,))
|
|
|
+ plt.close()
|
|
|
|
|
|
+ return image_from_plot
|
|
|
+
|
|
|
+ # 获取结果数据并转换为numpy数组
|
|
|
rjuncs = result["juncs"][i].cpu().numpy() * 4
|
|
|
rjunts = None
|
|
|
if "junts" in result:
|
|
@@ -617,10 +678,62 @@ def _plot_samples(img, i, result, prefix, epoch):
|
|
|
vecl_result = result["lines"][i].cpu().numpy() * 4
|
|
|
score = result["score"][i].cpu().numpy()
|
|
|
|
|
|
- draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
|
|
|
-
|
|
|
- img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
|
|
|
- writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch)
|
|
|
+ # 调用绘图函数并获取图像
|
|
|
+ image_path = f"{prefix}_vecl_b.jpg"
|
|
|
+ image_array = draw_vecl(vecl_result, score, rjuncs, rjunts, image_path)
|
|
|
+
|
|
|
+ # 将numpy数组转换为torch tensor,并写入TensorBoard
|
|
|
+ image_tensor = transforms.ToTensor()(image_array)
|
|
|
+ writer.add_image(f'output_epoch', image_tensor, global_step=epoch)
|
|
|
+ writer.add_image(f'ori_epoch', img, global_step=epoch)
|
|
|
+
|
|
|
+
|
|
|
+def show_line(img, pred, prefix, epoch, write):
|
|
|
+ fn = f"{prefix}_line.jpg"
|
|
|
+ directory = os.path.dirname(fn)
|
|
|
+ if not os.path.exists(directory):
|
|
|
+ os.makedirs(directory)
|
|
|
+ print(fn)
|
|
|
+ PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
|
|
|
+ H = pred
|
|
|
+
|
|
|
+ im = img.permute(1, 2, 0)
|
|
|
+
|
|
|
+ lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
|
|
|
+ scores = H["score"][0].cpu().numpy()
|
|
|
+ for i in range(1, len(lines)):
|
|
|
+ if (lines[i] == lines[0]).all():
|
|
|
+ lines = lines[:i]
|
|
|
+ scores = scores[:i]
|
|
|
+ break
|
|
|
+
|
|
|
+ # postprocess lines to remove overlapped lines
|
|
|
+ diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
|
|
|
+ nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
|
|
|
+
|
|
|
+ for i, t in enumerate([0.5]):
|
|
|
+ plt.gca().set_axis_off()
|
|
|
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
|
|
+ plt.margins(0, 0)
|
|
|
+ for (a, b), s in zip(nlines, nscores):
|
|
|
+ if s < t:
|
|
|
+ continue
|
|
|
+ plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
|
|
|
+ plt.scatter(a[1], a[0], **PLTOPTS)
|
|
|
+ plt.scatter(b[1], b[0], **PLTOPTS)
|
|
|
+ plt.gca().xaxis.set_major_locator(plt.NullLocator())
|
|
|
+ plt.gca().yaxis.set_major_locator(plt.NullLocator())
|
|
|
+ plt.imshow(im)
|
|
|
+ plt.savefig(fn, bbox_inches="tight")
|
|
|
+ plt.show()
|
|
|
+ plt.close()
|
|
|
+
|
|
|
+
|
|
|
+ img2 = cv2.imread(fn) # 预测图
|
|
|
+ # img1 = im.resize(img2.shape) # 原图
|
|
|
+
|
|
|
+ # writer.add_images(f"{epoch}", torch.tensor([img1, img2]), dataformats='NHWC')
|
|
|
+ writer.add_image("output", img2, epoch)
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
@@ -698,27 +811,27 @@ if __name__ == '__main__':
|
|
|
print(f"epoch:{epoch}")
|
|
|
model.train()
|
|
|
|
|
|
- for imgs, targets in data_loader_train:
|
|
|
- losses = model(move_to_device(imgs, device), move_to_device(targets, device))
|
|
|
- loss = _loss(losses)
|
|
|
- print(loss)
|
|
|
- optimizer.zero_grad()
|
|
|
- loss.backward()
|
|
|
- optimizer.step()
|
|
|
- writer_loss(writer, losses, epoch)
|
|
|
-
|
|
|
- model.eval()
|
|
|
- with torch.no_grad():
|
|
|
- for batch_idx, (imgs, targets) in enumerate(data_loader_val):
|
|
|
- pred = model(move_to_device(imgs, device))
|
|
|
- # print(f"pred:{pred}")
|
|
|
-
|
|
|
- if batch_idx == 0:
|
|
|
- result = pred[1]['wires'] # pred[0].keys() ['boxes', 'labels', 'scores']
|
|
|
- print(imgs[0].shape) # [3,512,512]
|
|
|
- # imshow(imgs[0].permute(1, 2, 0)) # 改为(512, 512, 3)
|
|
|
- _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch)
|
|
|
+ # for imgs, targets in data_loader_train:
|
|
|
+ # losses = model(move_to_device(imgs, device), move_to_device(targets, device))
|
|
|
+ # loss = _loss(losses)
|
|
|
+ # print(loss)
|
|
|
+ # optimizer.zero_grad()
|
|
|
+ # loss.backward()
|
|
|
+ # optimizer.step()
|
|
|
+ # writer_loss(writer, losses, epoch)
|
|
|
|
|
|
+ model.eval()
|
|
|
+ with torch.no_grad():
|
|
|
+ for batch_idx, (imgs, targets) in enumerate(data_loader_val):
|
|
|
+ pred = model(move_to_device(imgs, device))
|
|
|
+ # print(f"pred:{pred}")
|
|
|
+
|
|
|
+ if batch_idx == 0:
|
|
|
+ result = pred[1]['wires'] # pred[0].keys() ['boxes', 'labels', 'scores']
|
|
|
+ print(imgs[0].shape) # [3,512,512]
|
|
|
+ # imshow(imgs[0].permute(1, 2, 0)) # 改为(512, 512, 3)
|
|
|
+ _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch, writer)
|
|
|
+ # show_line(imgs[0], result, f"{cfg['io']['logdir']}/{epoch}", epoch, writer)
|
|
|
|
|
|
# imgs, targets = next(iter(data_loader))
|
|
|
#
|