فهرست منبع

调试gt_arc热度图

RenLiqiang 5 ماه پیش
والد
کامیت
fa2f131ee5
4فایلهای تغییر یافته به همراه87 افزوده شده و 13 حذف شده
  1. 72 2
      models/line_detect/heads/head_losses.py
  2. 7 4
      models/line_detect/line_dataset.py
  3. 6 6
      models/line_detect/loi_heads.py
  4. 2 1
      models/line_detect/train.yaml

+ 72 - 2
models/line_detect/heads/head_losses.py

@@ -232,6 +232,75 @@ def generate_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, dev
 
     assert xs.shape == ys.shape, "x and y must have the same shape"
     print(f'xs:{xs.shape}')
+    xs=xs.squeeze(1)
+    ys = ys.squeeze(1)
+    print(f'xs1:{xs.shape}')
+    N = xs.shape[0]
+    print(f'N:{N},num_points:{num_points}')
+
+    # 创建网格
+    grid_y, grid_x = torch.meshgrid(
+        torch.arange(heatmap_size, device=device),
+        torch.arange(heatmap_size, device=device),
+        indexing='ij'
+    )
+
+    # print(f'heatmap_size:{heatmap_size}')
+    # 初始化输出热图
+    combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
+
+    for i in range(N):
+        heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
+        for j in range(num_points):
+            mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
+            mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
+            # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
+
+            # 计算距离平方
+            dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
+
+            # 计算高斯分布
+            heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
+
+            heatmap+=heatmap1
+
+
+        # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
+        # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
+        #
+        # # 计算距离平方
+        # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
+        #
+        # # 计算高斯分布
+        # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
+        #
+        # heatmap = heatmap1 + heatmap2
+
+        # 将当前热图累加到结果中
+        combined_heatmap[i] = heatmap
+
+    return combined_heatmap
+
+def generate_mask_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
+    """
+    为一组点生成并合并高斯热图。
+
+    Args:
+        xs (Tensor): 形状为 (N, 2) 的所有点的 x 坐标
+        ys (Tensor): 形状为 (N, 2) 的所有点的 y 坐标
+        heatmap_size (int): 热图大小 H=W
+        sigma (float): 高斯核标准差
+        device (str): 设备类型 ('cpu' or 'cuda')
+
+    Returns:
+        Tensor: 形状为 (H, W) 的合并后的热图
+    """
+
+    assert xs.shape == ys.shape, "x and y must have the same shape"
+    print(f'xs:{xs.shape}')
+    xs=xs.squeeze(1)
+    ys = ys.squeeze(1)
+    print(f'xs1:{xs.shape}')
     N = xs.shape[0]
     print(f'N:{N},num_points:{num_points}')
 
@@ -471,8 +540,9 @@ def arc_points_to_heatmap(keypoints, rois, heatmap_size):
 
     x = keypoints[..., 0].unsqueeze(1)
     y = keypoints[..., 1].unsqueeze(1)
-
-    gs = generate_gaussian_heatmaps(x, y, num_points=10, heatmap_size=heatmap_size, sigma=1.0)
+    num_points=x.shape[2]
+    print(f'num_points:{num_points}')
+    gs = generate_mask_gaussian_heatmaps(x, y, num_points=num_points, heatmap_size=heatmap_size, sigma=1.0)
     # show_heatmap(gs[0],'target')
     all_roi_heatmap = []
     for roi, heatmap in zip(rois, gs):

+ 7 - 4
models/line_detect/line_dataset.py

@@ -88,12 +88,13 @@ class LineDataset(BaseDataset):
             a = torch.full((lines.shape[0],), 2).unsqueeze(1)
             lines = torch.cat((lines, a), dim=1)
             target["lines"] = lines.to(torch.float32).view(-1, 2, 3)
+            # print(f'lines shape:{ target["lines"].shape}')
 
         if arc_mask is not None:
             target['arc_mask']=arc_mask
-            print(f'arc_mask dataset')
-        else:
-            print(f'not arc_mask dataset')
+            # print(f'arc_mask dataset')
+        # else:
+        #     print(f'not arc_mask dataset')
 
         target["boxes"]=boxes
         target["labels"]=labels
@@ -225,13 +226,15 @@ def get_boxes_lines(objs,shape):
         line_point_pairs=None
     else:
         line_point_pairs=torch.tensor(line_point_pairs)
+        # print(f'line_point_pairs:{line_point_pairs.shape},{line_point_pairs.dtype}')
 
     # print(f'boxes:{boxes.shape},line_point_pairs:{line_point_pairs.shape}')
 
     if len(line_mask)==0:
         line_mask=None
     else:
-        line_mask=torch.tensor(line_mask)
+        line_mask=torch.tensor(line_mask,dtype=torch.float32)
+        print(f'arc_mask shape :{line_mask.shape},{line_mask.dtype}')
     return boxes,line_point_pairs,points,line_mask, labels
 
 if __name__ == '__main__':

+ 6 - 6
models/line_detect/loi_heads.py

@@ -1213,14 +1213,14 @@ class RoIHeads(nn.Module):
                 img_size = h
 
                 gt_arcs_tensor = torch.zeros(0, 0)
-                if len(gt_arcs) > 0:
-                    gt_arcs_tensor = torch.cat(gt_arcs)
-                    print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
+                # if len(gt_arcs) > 0:
+                    # gt_arcs_tensor = torch.cat(gt_arcs)
+                    # print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
 
-                if gt_arcs_tensor.shape[0] > 0:
-                    print(f'start to compute point_loss')
+                # if gt_arcs_tensor.shape[0] > 0:
+                #     print(f'start to compute point_loss')
 
-                    loss_arc=compute_arc_loss(feature_logits,arc_proposals,gt_arcs,arc_pos_matched_idxs)
+                loss_arc=compute_arc_loss(feature_logits,arc_proposals,gt_arcs,arc_pos_matched_idxs)
 
                 if loss_arc is None:
                     print(f'loss_arc is None111')

+ 2 - 1
models/line_detect/train.yaml

@@ -4,6 +4,7 @@ io:
 #  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
   datadir: \\192.168.50.222/share/zyh/arc/a_dataset
 #  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
+#  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
@@ -14,7 +15,7 @@ io:
 train_params:
   resume_from:
   num_workers: 8
-  batch_size: 2
+  batch_size: 3
   max_epoch: 8000000
 #  augmentation: True
   augmentation: False