Bladeren bron

debug train lines

RenLiqiang 5 maanden geleden
bovenliggende
commit
0c5896b17c
2 gewijzigde bestanden met toevoegingen van 6 en 3 verwijderingen
  1. 4 2
      models/line_detect/heads/head_losses.py
  2. 2 1
      models/line_detect/train.yaml

+ 4 - 2
models/line_detect/heads/head_losses.py

@@ -136,11 +136,13 @@ def line_points_to_heatmap(keypoints, rois, heatmap_size):
     all_roi_heatmap = []
     for roi, heatmap in zip(rois, gs):
         # print(f'heatmap:{heatmap.shape}')
+        # show_heatmap(heatmap,'target')
         heatmap = heatmap.unsqueeze(0)
+
         x1, y1, x2, y2 = map(int, roi)
         roi_heatmap = torch.zeros_like(heatmap)
         roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
-        # show_heatmap(roi_heatmap,'roi_heatmap')
+        # show_heatmap(roi_heatmap[0],'roi_heatmap')
         all_roi_heatmap.append(roi_heatmap)
 
     all_roi_heatmap = torch.cat(all_roi_heatmap)
@@ -403,7 +405,7 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     # line_logits = line_logits.view(N * K, H * W)
     # print(f'line_logits[valid]:{line_logits[valid].shape}')
     print(f'loss1 line_logits:{line_logits.shape}')
-    line_logits = line_logits[:,2,:,:]
+    line_logits = line_logits[:,1,:,:]
     print(f'loss2 line_logits:{line_logits.shape}')
 
     # line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])

+ 2 - 1
models/line_detect/train.yaml

@@ -1,6 +1,7 @@
 io:
   logdir: train_results
-  datadir: /data/share/rlq/datasets/Dataset_correct_xanylabel
+#  datadir: /data/share/rlq/datasets/Dataset_correct_xanylabel
+  datadir: \\192.168.50.222/share/rlq/datasets/Dataset_correct_xanylabel
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000