Explorar el Código

tensorboard 加入lmap jmap

RenLiqiang hace 8 meses
padre
commit
2c1a9b7db1
Se han modificado 4 ficheros con 21 adiciones y 12 borrados
  1. 9 9
      models/line_detect/111.py
  2. 1 0
      models/line_detect/roi_heads.py
  3. 1 1
      models/line_detect/train.yaml
  4. 10 2
      utils/log_util.py

+ 9 - 9
models/line_detect/111.py

@@ -231,15 +231,15 @@ if __name__ == '__main__':
 
     # model = LineNet('line_net.yaml')
     model=linenet_resnet50_fpn().to(device)
-    #model=linenet_resnet18_fpn()
-    # trainer = Trainer()
-    # trainer.train_cfg(model,cfg='./train.yaml')
-    # model.train_by_cfg(cfg='train.yaml')
-    # trainer = Trainer()
-    # trainer.train_cfg(model=model, cfg='train.yaml')
-    pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
-    img_path = r"C:\Users\m2337\Desktop\p\新建文件夹\2025-03-25-16-10-00_SaveLeftImage.png"
-    model.predict(pt_path, model, img_path, type=1, threshold=0, save_path=None, show=True)
+    model=linenet_resnet18_fpn()
+    trainer = Trainer()
+    trainer.train_cfg(model,cfg='./train.yaml')
+    model.train_by_cfg(cfg='train.yaml')
+    trainer = Trainer()
+    trainer.train_cfg(model=model, cfg='train.yaml')
+    # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
+    # img_path = r"C:\Users\m2337\Desktop\p\新建文件夹\2025-03-25-16-10-00_SaveLeftImage.png"
+    # model.predict(pt_path, model, img_path, type=1, threshold=0, save_path=None, show=True)
 
 
 

+ 1 - 0
models/line_detect/roi_heads.py

@@ -1075,6 +1075,7 @@ class RoIHeads(nn.Module):
             else:
 
                 pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
+                result.append(outputs)
                 result.append(pred)
                 loss_wirepoint = {}
             losses.update(loss_wirepoint)

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: logs/
-  datadir: D:\all\1Desktop\20250320data\0322_
+  datadir: I:/datasets/4_23jiagonggongjian
 #  datadir: I:\datasets\wirenet_1000
   resume_from:
   num_workers: 8

+ 10 - 2
utils/log_util.py

@@ -95,6 +95,14 @@ def show_line(img, pred, epoch, writer):
 
     PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
     # print(f'pred[1]:{pred[1]}')
+    heatmaps=pred[-2][0]
+    print(f'heatmaps:{heatmaps.shape}')
+    jmap = heatmaps[1: 2].cpu().detach()
+    lmap = heatmaps[2: 3].cpu().detach()
+    writer.add_image("z-jmap", jmap, epoch)
+    writer.add_image("z-lmap", lmap, epoch)
+    # plt.imshow(lmap)
+    # plt.show()
     H = pred[-1]['wires']
     lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
     scores = H["score"][0].cpu().numpy()
@@ -108,7 +116,7 @@ def show_line(img, pred, epoch, writer):
     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
     nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
 
-    for i, t in enumerate([0.001]):
+    for i, t in enumerate([0]):
         plt.gca().set_axis_off()
         plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
         plt.margins(0, 0)
@@ -138,4 +146,4 @@ def show_line(img, pred, epoch, writer):
 
         img2 = transforms.ToTensor()(img_rgb)
 
-        writer.add_image("z-output", img2, epoch)
+        writer.add_image("z-output", img2, epoch)