|
@@ -32,7 +32,7 @@ import matplotlib.pyplot as plt
|
|
|
import matplotlib as mpl
|
|
|
from skimage import io
|
|
|
import os.path as osp
|
|
|
-
|
|
|
+from torchvision.utils import draw_bounding_boxes
|
|
|
|
|
|
FEATURE_DIM = 8
|
|
|
|
|
@@ -583,15 +583,15 @@ def imshow(im):
|
|
|
plt.colorbar(sm, fraction=0.046)
|
|
|
plt.xlim([0, im.shape[0]])
|
|
|
plt.ylim([im.shape[0], 0])
|
|
|
+ plt.show()
|
|
|
|
|
|
|
|
|
-def _plot_samples(self, i, index, result, targets, prefix):
|
|
|
- fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
|
|
|
- img = io.imread(fn)
|
|
|
- imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
|
|
|
-
|
|
|
+def _plot_samples(img, i, result, prefix, epoch):
|
|
|
+ print(f"prefix:{prefix}")
|
|
|
def draw_vecl(lines, sline, juncs, junts, fn):
|
|
|
- imshow(img)
|
|
|
+ if not os.path.exists(fn):
|
|
|
+ os.makedirs(fn)
|
|
|
+ imshow(img.permute(1, 2, 0))
|
|
|
if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
if i > 0 and (lines[i] == lines[0]).all():
|
|
@@ -609,27 +609,18 @@ def _plot_samples(self, i, index, result, targets, prefix):
|
|
|
plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
|
|
|
plt.savefig(fn), plt.close()
|
|
|
|
|
|
- junc = targets[i]["junc"].cpu().numpy() * 4
|
|
|
- jtyp = targets[i]["jtyp"].cpu().numpy()
|
|
|
- juncs = junc[jtyp == 0]
|
|
|
- junts = junc[jtyp == 1]
|
|
|
rjuncs = result["juncs"][i].cpu().numpy() * 4
|
|
|
rjunts = None
|
|
|
if "junts" in result:
|
|
|
rjunts = result["junts"][i].cpu().numpy() * 4
|
|
|
|
|
|
- lpre = targets[i]["lpre"].cpu().numpy() * 4
|
|
|
- vecl_target = targets[i]["lpre_label"].cpu().numpy()
|
|
|
vecl_result = result["lines"][i].cpu().numpy() * 4
|
|
|
score = result["score"][i].cpu().numpy()
|
|
|
- lpre = lpre[vecl_target == 1]
|
|
|
|
|
|
- draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
|
|
|
draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
|
|
|
|
|
|
- img = cv2.imread(f"{prefix}_vecl_a.jpg")
|
|
|
img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
|
|
|
- self.writer.add_images(f"{self.epoch}", torch.tensor([img, img1]), dataformats='NHWC')
|
|
|
+ writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch)
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
@@ -716,16 +707,18 @@ if __name__ == '__main__':
|
|
|
optimizer.step()
|
|
|
writer_loss(writer, losses, epoch)
|
|
|
|
|
|
- model.eval()
|
|
|
- with torch.no_grad():
|
|
|
- for batch_idx, (imgs, targets) in enumerate(data_loader_val):
|
|
|
- pred = model(move_to_device(imgs, device))
|
|
|
- print(f"perd:{pred}")
|
|
|
-
|
|
|
- # if batch_idx == 0:
|
|
|
- # viz = osp.join(cfg['io']['logdir'], "viz", f"{epoch}")
|
|
|
- # H = pred["wires"]
|
|
|
- # _plot_samples(0, 0, H, targets["wires"], f"{viz}/{epoch}")
|
|
|
+ model.eval()
|
|
|
+ with torch.no_grad():
|
|
|
+ for batch_idx, (imgs, targets) in enumerate(data_loader_val):
|
|
|
+ pred = model(move_to_device(imgs, device))
|
|
|
+ # print(f"pred:{pred}")
|
|
|
+
|
|
|
+ if batch_idx == 0:
|
|
|
+ result = pred[1]['wires'] # pred[0].keys() ['boxes', 'labels', 'scores']
|
|
|
+ print(imgs[0].shape) # [3,512,512]
|
|
|
+ # imshow(imgs[0].permute(1, 2, 0)) # 改为(512, 512, 3)
|
|
|
+ _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch)
|
|
|
+
|
|
|
|
|
|
# imgs, targets = next(iter(data_loader))
|
|
|
#
|