Forráskód Böngészése

修改提取正圆4个点逻辑,有初步较好效果

RenLiqiang 4 hónapja
szülő
commit
ce4aa432ca

+ 0 - 0
models/line_detect/heads/arc/__init__.py


+ 0 - 0
models/line_detect/heads/circle/__init__.py


+ 217 - 11
models/line_detect/heads/head_losses.py

@@ -3,6 +3,7 @@ from matplotlib import pyplot as plt
 
 import torch.nn.functional as F
 from torch import nn
+from torch.cuda import device
 
 
 class DiceLoss(nn.Module):
@@ -438,12 +439,211 @@ def heatmaps_to_points(maps, rois,num_points=2):
 
     return point_preds,point_end_scores
 
+# 分4块
+def find_max_heat_point_in_each_part(feature_map, box):
+    """
+    在给定的特征图上,根据box中心点往上移3,往右移3作为新的中心点,
+    并将特征图划分为4个部分,之后在每个部分中找到热度值最大的点。
+
+    Args:
+        feature_map (torch.Tensor): 形状为 [C, H, W] 的特征图
+        box (torch.Tensor): 形状为 [4] 的边界框 [x_min, y_min, x_max, y_max]
+
+    Returns:
+        list: 每个区域中热度最高的点的位置和其对应的热度值 [(y1, x1, heat1), ..., (y4, x4, heat4)]
+    """
+    device = feature_map.device
+    C, H, W = feature_map.shape
+
+    # 计算box的中心点(cx, cy)
+    cx = (box[0] + box[2]) // 2
+    cy = (box[1] + box[3]) // 2
+
+    # 偏移中心点
+    new_cx = min(max(cx + 3, 0), W - 1)  # 向右移3
+    new_cy = min(max(cy - 3, 0), H - 1)  # 向上移3
+
+    # 创建坐标网格
+    y_coords, x_coords = torch.meshgrid(
+        torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij'
+    )
+
+    # 划分四个区域
+    mask_q1 = (y_coords < new_cy) & (x_coords < new_cx)  # 左上
+    mask_q2 = (y_coords < new_cy) & (x_coords >= new_cx)  # 右上
+    mask_q3 = (y_coords >= new_cy) & (x_coords < new_cx)  # 左下
+    mask_q4 = (y_coords >= new_cy) & (x_coords >= new_cx)  # 右下
+
+    # def process_region(mask):
+    #     region = feature_map[:, :, mask].squeeze()
+    #     if len(region.shape) == 0:  # 如果区域为空,则跳过
+    #         return None, None
+    #     # 找到最大热度值的点及其位置
+    #     (y, x), heat_val = non_maximum_suppression(region[0])
+    #     # 将相对坐标转换回全局坐标
+    #     y_global = y + torch.where(mask)[0].min().item()
+    #     x_global = x + torch.where(mask)[1].min().item()
+    #     return (y_global, x_global), heat_val
+    #
+    # results = []
+    # for mask in [mask_q1, mask_q2, mask_q3, mask_q4]:
+    #     point, heat_val = process_region(mask)
+    #     if point is not None:
+    #         # results.append((point[0], point[1], heat_val))
+    #         results.append((point[0], point[1]))
+    #     else:
+    #         results.append(None)
+    masks = [mask_q1, mask_q2, mask_q3, mask_q4]
+    results = []
+
+    # 假设使用第一个通道作为热力图
+    heatmap = feature_map[0]  # [H, W]
+
+    def process_region(mask):
+        # 应用 mask,只保留该区域
+        masked_heatmap = heatmap.clone()  # 复制以避免修改原数据
+        masked_heatmap[~mask] = 0  # 非区域置0
+
+        def non_maximum_suppression_2d(heatmap, kernel_size=3):
+            """
+            对 2D 热力图做非极大值抑制,保留局部最大值点。
+
+            Args:
+                heatmap (torch.Tensor): [H, W],输入热力图
+                kernel_size (int): 池化窗口大小,用于比较是否为局部最大值
+
+            Returns:
+                torch.Tensor: 与 heatmap 同形状的 mask,局部最大值位置为 True
+            """
+            pad = (kernel_size - 1) // 2
+            max_pool = torch.nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)
+            maxima = max_pool(heatmap.unsqueeze(0)).squeeze(0)
+            # 局部最大值且值大于0
+            peaks = (heatmap == maxima) & (heatmap > 0)
+            return peaks
+
+        # 1. 先做 NMS 得到候选局部极大值点
+        nms_mask = non_maximum_suppression_2d(masked_heatmap, kernel_size=3)  # [H, W] bool
+        candidate_peaks = masked_heatmap * nms_mask.float()  # 只保留 NMS 后的峰值
+
+        # 2. 找出所有候选点中值最大的一个
+        if candidate_peaks.max() <= 0:
+            return None
+
+        # 找到最大值的位置
+        max_val, max_idx = torch.max(candidate_peaks.view(-1), dim=0)
+        y, x = divmod(max_idx.item(), W)
+
+        return (x, y)  # 返回 (y, x)
+
+    for mask in masks:
+        point = process_region(mask)
+        results.append(point)
+
+    return results
+
+
+def non_maximum_suppression_2d(heatmap, kernel_size=3):
+    pad = (kernel_size - 1) // 2
+    max_pool = torch.nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)
+    maxima = max_pool(heatmap.unsqueeze(0)).squeeze(0)
+    peaks = (heatmap == maxima) & (heatmap > 0)
+    return peaks
+
+def find_max_heat_point_in_edge_centers(feature_map, box):
+
+    device = feature_map.device
+    C, H, W = feature_map.shape
+
+    # ¼ÆËã box ÖÐÐÄ
+    cx = (box[0] + box[2]) / 2
+    cy = (box[1] + box[3]) / 2
+
+    # ¸ù¾Ý box ¿í¸ß¼ÆËã¾Å¹¬¸ñ·Ö½çÏß
+    box_width = box[2] - box[0]
+    box_height = box[3] - box[1]
+
+    x_left = cx - box_width / 6
+    x_right = cx + box_width / 6
+    y_top = cy - box_height / 6
+    y_bottom = cy + box_height / 6
+
+    # ´´½¨Íø¸ñ
+    y_coords, x_coords = torch.meshgrid(
+        torch.arange(H, device=device),
+        torch.arange(W, device=device),
+        indexing='ij'
+    )
+
+    # ¶¨ÒåËĸö¡°±ßÖС±ÇøÓòµÄ mask
+    mask1 = (x_coords < x_left) & (y_coords < y_top)
+    mask_top_middle    = (x_coords >= x_left) & (x_coords < x_right) & (y_coords < y_top)
+    mask3 = (x_coords >= x_right) & (y_coords < y_top)
+
+    mask_left_middle = (x_coords < x_left) & (y_coords >= y_top) & (y_coords < y_bottom)
+    mask_right_middle  = (x_coords >= x_right) & (y_coords >= y_top) & (y_coords < y_bottom)
+
+    mask4 = (x_coords < x_left) & (y_coords >= y_bottom)
+    mask_bottom_middle = (x_coords >= x_left) & (x_coords < x_right) & (y_coords >= y_bottom)
+    mask_right_bottom = (x_coords >= x_right) & (y_coords >= y_bottom)
+
+    # masks = [
+    #     # mask1,
+    #     mask_top_middle,
+    #     # mask3,
+    #     mask_left_middle,
+    #     mask_right_middle,
+    #     # mask4,
+    #     mask_bottom_middle,
+    #     mask_right_bottom
+    # ]
+
+    masks = [
+        mask_top_middle,
+        mask_right_middle,
+        mask_bottom_middle,
+        mask_left_middle
+    ]
+
+    # ʹÓõÚÒ»¸öͨµÀ×÷ΪÈÈÁ¦Í¼
+    heatmap = feature_map[0]  # [H, W]
+
+    results = []
+
+    for mask in masks:
+        masked_heatmap = heatmap.clone()
+        masked_heatmap[~mask] = 0  # ·ÇÄ¿±êÇøÓòÖà 0
+
+        # # NMS ÒÖÖÆ
+        # nms_mask = non_maximum_suppression_2d(masked_heatmap, kernel_size=3)
+        # candidate_peaks = masked_heatmap * nms_mask.float()
+        #
+        # if candidate_peaks.max() <= 0:
+        #     results.append(None)
+        #     continue
+        #
+        # # ÕÒ×î´óֵλÖÃ
+        # max_val, max_idx = torch.max(candidate_peaks.view(-1), dim=0)
+        # y, x = divmod(max_idx.item(), W)
+        flatten_point_roi_map = masked_heatmap.reshape(1, -1)
+        point_score, point_index = torch.topk(flatten_point_roi_map, k=1)
+        point_x =point_index % W
+        point_y = torch.div(point_index - point_x, W, rounding_mode="floor")
+
+        results.append((point_x, point_y))
+
+    return results  # [(y_top, x_top), (y_right, x_right), (y_bottom, x_bottom), (y_left, x_left)]
+
+
+
 def heatmaps_to_circle_points(maps, rois,num_points=2):
 
 
     point_preds = torch.zeros((len(rois), 4, 2), dtype=torch.float32, device=maps.device)
     point_end_scores = torch.zeros((len(rois),4, 1), dtype=torch.float32, device=maps.device)
 
+    print(f'rois in heatmaps_to_circle_points:{type(rois),  rois.shape}')   # <class 'torch.Tensor'>
+
     print(f'heatmaps_to_lines:{maps.shape}')
     point_maps=maps[:,0]
     print(f'point_map:{point_maps.shape}')
@@ -452,18 +652,24 @@ def heatmaps_to_circle_points(maps, rois,num_points=2):
         point_roi_map = point_maps[i].unsqueeze(0)
         print(f'point_roi_map:{point_roi_map.shape}')
         # roi_map_probs = scores_to_probs(roi_map.copy())
-        w = point_roi_map.shape[2]
-        flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
-        point_score, point_index = torch.topk(flatten_point_roi_map, k=num_points)
-        print(f'point index:{point_index}')
-        # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
-
-        point_x =point_index % w
-        point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
-
 
-        point_preds[i, :,0] = point_x
-        point_preds[i, :,1] = point_y
+        # w = point_roi_map.shape[2]
+        # flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
+        # print(f'non_maximum_suppression :{non_maximum_suppression(point_roi_map).shape}')
+        # point_score, point_index = torch.topk(flatten_point_roi_map, k=num_points)
+        # print(f'point index:{point_index}')
+        # point_x =point_index % w
+        # point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
+        # print(f'point_x:{point_x}, point_y:{point_y}')
+        # point_preds[i, :, 0] = point_x
+        # point_preds[i, :, 1] = point_y
+        roi1=rois[i]
+        result_points = find_max_heat_point_in_edge_centers(non_maximum_suppression(point_roi_map), roi1)
+
+        point_preds[i, :]=torch.tensor(result_points)
+
+        point_x = [point[0] for point in result_points]
+        point_y = [point[1] for point in result_points]
 
         point_end_scores[i, :,0] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
 

+ 0 - 0
models/line_detect/heads/line/__init__.py


+ 0 - 0
models/line_detect/heads/point/__init__.py


+ 4 - 3
models/line_detect/line_detect.py

@@ -351,9 +351,9 @@ def linedetect_newresnet18fpn(
     # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
     # weights_backbone = ResNet50_Weights.verify(weights_backbone)
     if num_classes is None:
-        num_classes = 4
+        num_classes = 5
     if num_points is None:
-        num_points = 3
+        num_points = 4
 
     size=512
     backbone =resnet18fpn()
@@ -378,7 +378,8 @@ def linedetect_newresnet18fpn(
                        rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler,
                        detect_point=False,
                        detect_line=False,
-                       detect_arc=True,
+                       detect_arc=False,
+                       detect_circle=True,
 
                        **kwargs)
 

+ 1 - 1
models/line_detect/train.yaml

@@ -7,7 +7,7 @@ io:
 #  datadir: /data/share/rlq/datasets/250718caisegangban
 
 #  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
-  datadir: /data/share/zyh/data/rgb_4point/a_dataset
+  datadir: \\192.168.50.222/share/rlq/datasets/guanban_circle
 #  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
 #  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
   data_type: rgb

+ 2 - 2
models/line_detect/train_demo.py

@@ -17,8 +17,8 @@ if __name__ == '__main__':
     # model = lineDetect_resnet18_fpn()
 
     # model=linedetect_resnet18_fpn()
-    # model=linedetect_newresnet18fpn(num_points=3)
-    model=linedetect_newresnet50fpn(num_points=4)
+    model=linedetect_newresnet18fpn(num_points=4)
+    # model=linedetect_newresnet50fpn(num_points=4)
     # model = linedetect_newresnet101fpn(num_points=3)
     # model = linedetect_newresnet152fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')