xue50 hai 5 meses
pai
achega
2534c9a888

BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824076.Li.85572.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824152.Li.34428.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824166.Li.84204.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824178.Li.51724.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824188.Li.17696.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824200.Li.58920.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824231.Li.107936.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824395.Li.66368.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824434.Li.6216.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824446.Li.72172.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824523.Li.84872.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824551.Li.54712.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824567.Li.28980.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824590.Li.53140.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824603.Li.23764.0


BIN=BIN
models/wirenet/logs/events.out.tfevents.1733824617.Li.132768.0


+ 20 - 27
models/wirenet/wirepoint_rcnn.py

@@ -32,7 +32,7 @@ import matplotlib.pyplot as plt
 import matplotlib as mpl
 from skimage import io
 import os.path as osp
-
+from torchvision.utils import draw_bounding_boxes
 
 FEATURE_DIM = 8
 
@@ -583,15 +583,15 @@ def imshow(im):
     plt.colorbar(sm, fraction=0.046)
     plt.xlim([0, im.shape[0]])
     plt.ylim([im.shape[0], 0])
+    plt.show()
 
 
-def _plot_samples(self, i, index, result, targets, prefix):
-    fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
-    img = io.imread(fn)
-    imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
-
+def _plot_samples(img, i, result, prefix, epoch):
+    print(f"prefix:{prefix}")
     def draw_vecl(lines, sline, juncs, junts, fn):
-        imshow(img)
+        if not os.path.exists(fn):
+            os.makedirs(fn)
+        imshow(img.permute(1, 2, 0))
         if len(lines) > 0 and not (lines[0] == 0).all():
             for i, ((a, b), s) in enumerate(zip(lines, sline)):
                 if i > 0 and (lines[i] == lines[0]).all():
@@ -609,27 +609,18 @@ def _plot_samples(self, i, index, result, targets, prefix):
                 plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
         plt.savefig(fn), plt.close()
 
-    junc = targets[i]["junc"].cpu().numpy() * 4
-    jtyp = targets[i]["jtyp"].cpu().numpy()
-    juncs = junc[jtyp == 0]
-    junts = junc[jtyp == 1]
     rjuncs = result["juncs"][i].cpu().numpy() * 4
     rjunts = None
     if "junts" in result:
         rjunts = result["junts"][i].cpu().numpy() * 4
 
-    lpre = targets[i]["lpre"].cpu().numpy() * 4
-    vecl_target = targets[i]["lpre_label"].cpu().numpy()
     vecl_result = result["lines"][i].cpu().numpy() * 4
     score = result["score"][i].cpu().numpy()
-    lpre = lpre[vecl_target == 1]
 
-    draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
     draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
 
-    img = cv2.imread(f"{prefix}_vecl_a.jpg")
     img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
-    self.writer.add_images(f"{self.epoch}", torch.tensor([img, img1]), dataformats='NHWC')
+    writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch)
 
 
 if __name__ == '__main__':
@@ -716,16 +707,18 @@ if __name__ == '__main__':
             optimizer.step()
             writer_loss(writer, losses, epoch)
 
-        model.eval()
-        with torch.no_grad():
-            for batch_idx, (imgs, targets) in enumerate(data_loader_val):
-                pred = model(move_to_device(imgs, device))
-                print(f"perd:{pred}")
-
-            # if batch_idx == 0:
-            #     viz = osp.join(cfg['io']['logdir'], "viz", f"{epoch}")
-            #     H = pred["wires"]
-            #     _plot_samples(0, 0, H, targets["wires"], f"{viz}/{epoch}")
+            model.eval()
+            with torch.no_grad():
+                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+                    pred = model(move_to_device(imgs, device))
+                    # print(f"pred:{pred}")
+
+                    if batch_idx == 0:
+                        result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores']
+                        print(imgs[0].shape)  # [3,512,512]
+                        # imshow(imgs[0].permute(1, 2, 0))  # 改为(512, 512, 3)
+                        _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch)
+
 
 # imgs, targets = next(iter(data_loader))
 #