浏览代码

WireDataset

xue50 5 月之前
父节点
当前提交
ee04576538
共有 1 个文件被更改,包括 10 次插入10 次删除
  1. 10 10
      models/wirenet/wirepoint_rcnn.py

+ 10 - 10
models/wirenet/wirepoint_rcnn.py

@@ -716,16 +716,16 @@ if __name__ == '__main__':
             optimizer.step()
             writer_loss(writer, losses, epoch)
 
-            model.eval()
-            with torch.no_grad():
-                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
-                    pred = model(move_to_device(imgs, device))
-                    print(f"perd:{pred}")
-
-                # if batch_idx == 0:
-                #     viz = osp.join(cfg['io']['logdir'], "viz", f"{epoch}")
-                #     H = pred["wires"]
-                #     _plot_samples(0, 0, H, targets["wires"], f"{viz}/{epoch}")
+        model.eval()
+        with torch.no_grad():
+            for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+                pred = model(move_to_device(imgs, device))
+                print(f"perd:{pred}")
+
+            # if batch_idx == 0:
+            #     viz = osp.join(cfg['io']['logdir'], "viz", f"{epoch}")
+            #     H = pred["wires"]
+            #     _plot_samples(0, 0, H, targets["wires"], f"{viz}/{epoch}")
 
 # imgs, targets = next(iter(data_loader))
 #