ソースを参照

修复标注点映射问题

RenLiqiang 5 ヶ月 前
コミット
af4cac8846
2 ファイル変更49 行追加5 行削除
  1. 47 3
      models/line_detect/loi_heads.py
  2. 2 2
      models/line_detect/train.yaml

+ 47 - 3
models/line_detect/loi_heads.py

@@ -129,7 +129,6 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
     )
     return mask_loss
 
-
 def line_points_to_heatmap(keypoints, rois, heatmap_size):
     # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
     print(f'rois:{rois.shape}')
@@ -150,10 +149,54 @@ def line_points_to_heatmap(keypoints, rois, heatmap_size):
     x = keypoints[..., 0]
     y = keypoints[..., 1]
 
+    gs=generate_gaussian_heatmaps(x,y,heatmap_size,1.0)
+    # show_heatmap(gs[0],'target')
+    all_roi_heatmap=[]
+    for roi ,heatmap in zip(rois,gs):
+        # print(f'heatmap:{heatmap.shape}')
+        heatmap=heatmap.unsqueeze(0)
+        x1, y1, x2, y2 = map(int, roi)
+        roi_heatmap = torch.zeros_like(heatmap)
+        roi_heatmap[..., y1:y2+1, x1:x2+1]=heatmap[..., y1:y2+1, x1:x2+1]
+        # show_heatmap(roi_heatmap,'roi_heatmap')
+        all_roi_heatmap.append(roi_heatmap)
+
+    all_roi_heatmap=torch.cat(all_roi_heatmap)
+    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
+
+
+    return all_roi_heatmap
+
+
+"""
+修改适配的原结构的点 转热图,适用于带roi_pool版本的
+"""
+def line_points_to_heatmap_(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    print(f'rois:{rois.shape}')
+    print(f'heatmap_size:{heatmap_size}')
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    print(f'keypoints.shape:{keypoints.shape}')
+    # batch_size, num_keypoints, _ = keypoints.shape
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
     # gs=generate_gaussian_heatmaps(x,y,512,1.0)
+
+
     # print(f'gs_heatmap shape:{gs.shape}')
     #
-    # show_heatmap(gs,'target')
+    # show_heatmap(gs[0],'target')
 
     x_boundary_inds = x == rois[:, 2][:, None]
     y_boundary_inds = y == rois[:, 3][:, None]
@@ -174,7 +217,7 @@ def line_points_to_heatmap(keypoints, rois, heatmap_size):
 
     gs_heatmap=generate_gaussian_heatmaps(x,y,heatmap_size,1.0)
 
-    # show_heatmap(gs_heatmap[0],'feature')
+    show_heatmap(gs_heatmap[0],'feature')
 
     # print(f'gs_heatmap:{gs_heatmap.shape}')
     #
@@ -1266,6 +1309,7 @@ class RoIHeads(nn.Module):
                     raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
 
                 gt_lines = [t["lines"] for t in targets]
+                print(f'gt_lines:{gt_lines[0].shape}')
                 h, w = targets[0]["img_size"]
                 img_size = h
                 rcnn_loss_line = lines_point_pair_loss(

+ 2 - 2
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: /data/share/zyh/202507/a_dataset
+  datadir: \\192.168.50.222/share/zyh/202507/a_dataset
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
@@ -11,7 +11,7 @@ io:
 train_params:
   resume_from:
   num_workers: 8
-  batch_size: 4
+  batch_size: 1
   max_epoch: 80000
   augmentation: True
   optim: