Ver Fonte

debug circle mask,有初步效果

admin há 1 mês atrás
pai
commit
c96d57e9b1

+ 8 - 3
libs/vision_libs/models/detection/transform.py

@@ -218,10 +218,10 @@ class GeneralizedRCNNTransform(nn.Module):
             points = resize_keypoints(points, (h, w), image.shape[-2:])
             target["points"] = points
 
-        if "arc_mask" in target:
-            arc_mask = target["arc_mask"]
+        if "circle_masks" in target:
+            arc_mask = target["circle_masks"]
             arc_mask = resize_keypoints(arc_mask, (h, w), image.shape[-2:])
-            target["arc_mask"] = arc_mask
+            target["circle_masks"] = arc_mask
 
         if "circles" in target:
             arc_mask = target["circles"]
@@ -314,6 +314,11 @@ class GeneralizedRCNNTransform(nn.Module):
                 keypoints = pred["circles"]
                 keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                 result[i]["circles"] = keypoints
+
+            if "circle_masks" in pred:
+                masks = pred["circle_masks"]
+                masks = paste_masks_in_image(masks, boxes, o_im_s)
+                result[i]["circle_masks"] = masks
         return result
 
     def __repr__(self) -> str:

+ 2 - 2
models/base/backbone_factory.py

@@ -103,13 +103,13 @@ def get_efficientnetv2_fpn(name='efficientnet_v2_m', pretrained=True):
         backbone = efficientnet_v2_l(weights=weights).features
 
     # 定义返回的层索引和名称
-    return_layers = {"1":"0", "2": "1", "3": "2", "4": "3", "5": "4"}
+    return_layers = {"1":"0", "2": "1", "3": "2", "4": "3", "6": "4"}
     input=torch.randn(1, 3, 512, 512)
     # out=backbone(input)
     # print(f'out:{out}')
     # 获取每个层输出通道数
     in_channels_list = []
-    for layer_idx in [1,2, 3, 4, 5]:
+    for layer_idx in [1,2, 3, 4, 6]:
         module = backbone[layer_idx]
         # print(f'efficientnet:{backbone}')
         if hasattr(module, 'out_channels'):

+ 22 - 0
models/base/transforms.py

@@ -49,6 +49,13 @@ class RandomHorizontalFlip:
                 lines[..., 0] = width - lines[..., 0]
                 target["lines"] = lines
 
+            # Flip circle masks
+            if "circle_masks" in target:
+                lines = target["circle_masks"].clone()
+                # 只翻转 x 坐标,y 和 visibility 不变
+                lines[..., 0] = width - lines[..., 0]
+                target["circle_masks"] = lines
+
         return img, target
 
 class RandomVerticalFlip:
@@ -74,6 +81,12 @@ class RandomVerticalFlip:
                 lines[..., 1] = height - lines[..., 1]
                 target["lines"] = lines
 
+
+            if "circle_masks" in target:
+                lines = target["circle_masks"].clone()
+                lines[..., 1] = height - lines[..., 1]
+                target["circle_masks"] = lines
+
         return img, target
 
 
@@ -128,6 +141,9 @@ class RandomResize:
         if "lines" in target:
             target["lines"] = target["lines"] * torch.tensor([scale, scale, 1], device=target["lines"].device)
 
+        if "circle_masks" in target:
+            target["circle_masks"] = target["circle_masks"] * torch.tensor([scale, scale, 1], device=target["circle_masks"].device)
+
         return img, target
 
 
@@ -280,6 +296,9 @@ class RandomRotation:
             if "lines" in target:
                 target["lines"] = self.rotate_lines(target["lines"], angle, center)
 
+            if "circle_masks" in target:
+                target["circle_masks"] = self.rotate_lines(target["circle_masks"], angle, center)
+
         return img, target
 
 
@@ -452,6 +471,9 @@ class RandomPerspective:
             if "lines" in target:
                 target["lines"] = self.perspective_lines(target["lines"], M, width, height)
 
+            if "circle_masks" in target:
+                target["circle_masks"] = self.perspective_lines(target["circle_masks"], M, width, height)
+
         return img, target
 
 class DefaultTransform(nn.Module):

+ 4 - 2
models/line_detect/heads/head_losses.py

@@ -1192,7 +1192,7 @@ import torch.nn.functional as F
 
 
 
-def heatmaps_to_arc(maps, rois, threshold=0, output_size=(128, 128)):
+def heatmaps_to_arc(maps, rois, threshold=0.5, output_size=(128, 128)):
     """
     Args:
         maps: [N, 3, H, W] - full heatmaps
@@ -1242,7 +1242,9 @@ def heatmaps_to_arc(maps, rois, threshold=0, output_size=(128, 128)):
         print(f"    roi_map_resized.shape: {roi_map_resized.shape}")
 
         # Sigmoid + threshold (NMS is disabled below)
-        nms_roi = non_maximum_suppression(roi_map_resized)  # shape: [1, H, W]
+        # nms_roi = non_maximum_suppression(roi_map_resized)  # shape: [1, H, W]
+        nms_roi = torch.sigmoid(roi_map_resized)
+
         bin_mask = (nms_roi >= threshold).float()  # shape: [1, H, W]
         print(f"    bin_mask.sum(): {bin_mask.sum().item()}")
 

+ 2 - 2
models/line_detect/train.yaml

@@ -22,8 +22,8 @@ train_params:
   num_workers: 8
   batch_size: 2
   max_epoch: 8000000
-  augmentation: True
-#  augmentation: False
+#  augmentation: True
+  augmentation: False
   optim:
     name: Adam
     lr: 4.0e-4