Selaa lähdekoodia

single point 和line 进行初步整合

lstrlq 5 kuukautta sitten
vanhempi
commit
e5e1c3d916

+ 57 - 0
models/base/high_reso_resnet.py

@@ -130,6 +130,50 @@ def resnet50fpn(out_channels=256):
         out_channels=out_channels,
     )
 
def resnet101fpn(out_channels=256):
    """Build a ResNet-101 backbone wrapped with a Feature Pyramid Network.

    Args:
        out_channels: number of channels in every FPN output feature map.

    Returns:
        BackboneWithFPN exposing five pyramid levels keyed '0'..'4',
        one per encoder stage of the backbone.
    """
    # [3, 4, 23, 3] Bottleneck counts per stage is the standard ResNet-101 layout.
    backbone = ResNet(Bottleneck, [3, 4, 23, 3])

    # Map each encoder stage attribute to the FPN level name it feeds.
    return_layers = {
        'encoder0': '0',
        'encoder1': '1',
        'encoder2': '2',
        'encoder3': '3',
        'encoder4': '4',
    }

    # Output channel widths of the five stages (stem + four Bottleneck stages).
    in_channels_list = [64, 256, 512, 1024, 2048]

    return BackboneWithFPN(
        backbone,
        return_layers=return_layers,
        in_channels_list=in_channels_list,
        out_channels=out_channels,
    )
+
+
def resnet152fpn(out_channels=256):
    """Build a ResNet-152 backbone wrapped with a Feature Pyramid Network.

    Args:
        out_channels: number of channels in every FPN output feature map.

    Returns:
        BackboneWithFPN exposing five pyramid levels keyed '0'..'4',
        one per encoder stage of the backbone.
    """
    # [3, 8, 36, 3] Bottleneck counts per stage is the standard ResNet-152 layout.
    backbone = ResNet(Bottleneck, [3, 8, 36, 3])

    # Map each encoder stage attribute to the FPN level name it feeds.
    return_layers = {
        'encoder0': '0',
        'encoder1': '1',
        'encoder2': '2',
        'encoder3': '3',
        'encoder4': '4',
    }

    # Output channel widths of the five stages (stem + four Bottleneck stages).
    in_channels_list = [64, 256, 512, 1024, 2048]

    return BackboneWithFPN(
        backbone,
        return_layers=return_layers,
        in_channels_list=in_channels_list,
        out_channels=out_channels,
    )
+
+
 
 class ResNet(nn.Module):
     def __init__(self, block: Type[Union[Bottleneck]], layers: List[int],):
@@ -147,6 +191,19 @@ class ResNet(nn.Module):
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
         )
+        # self.encoder0 = nn.Sequential(
+        #     nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, bias=False),
+        #     self._norm_layer(self.inplanes),
+        #     nn.ReLU(inplace=True),
+        #     nn.Conv2d(self.inplanes, self.inplanes, kernel_size=3, padding=1, bias=False),
+        #     self._norm_layer(self.inplanes),
+        #     nn.ReLU(inplace=True),
+        #     nn.Conv2d(self.inplanes, self.inplanes, kernel_size=3, padding=1, bias=False),
+        #     self._norm_layer(self.inplanes),
+        #     nn.ReLU(inplace=True),
+        #     nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
+        # )
+
         self.encoder1 = self._make_layer(block, 64, layers[0],stride=2)
         self.encoder2 = self._make_layer(block, 128, layers[1], stride=2)
         self.encoder3 = self._make_layer(block, 256, layers[2], stride=2)

+ 39 - 2
models/line_detect/line_detect.py

@@ -31,7 +31,7 @@ from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
 from ..base.base_detection_net import BaseDetectionNet
 import torch.nn.functional as F
 
-from ..base.high_reso_resnet import resnet50fpn, resnet18fpn
+from ..base.high_reso_resnet import resnet50fpn, resnet18fpn, resnet101fpn
 
 __all__ = [
     "LineDetect",
@@ -50,7 +50,7 @@ class LineDetect(BaseDetectionNet):
     def __init__(
             self,
             backbone,
-            num_classes=2,
+            num_classes=3,
             # transform parameters
             min_size=512,
             max_size=512,
@@ -362,6 +362,43 @@ def linedetect_newresnet50fpn(
 
     return model
 
def linedetect_newresnet101fpn(
        *,
        num_classes: Optional[int] = None,
        num_points: Optional[int] = None,
        **kwargs: Any,
) -> LineDetect:
    """Construct a LineDetect model on a ResNet-101 FPN backbone.

    Args:
        num_classes: number of detection classes; defaults to 3 to match
            the LineDetect constructor default.
        num_points: number of key points per line; defaults to 3.
        **kwargs: forwarded unchanged to the LineDetect constructor.

    Returns:
        A LineDetect instance whose RoI pooler and anchor generator are
        sized to the five FPN levels plus the extra 'pool' level.
    """
    if num_classes is None:
        num_classes = 3
    if num_points is None:
        num_points = 3

    backbone = resnet101fpn()

    # One RoI-align entry per FPN level, plus the extra max-pool level
    # appended by the FPN.
    featmap_names = ['0', '1', '2', '3', '4', 'pool']
    roi_pooler = MultiScaleRoIAlign(
        featmap_names=featmap_names,
        output_size=7,
        sampling_ratio=2,
    )

    # Anchor sizes double per pyramid level starting at 16 px; every level
    # gets the same three aspect ratios.
    num_features = len(featmap_names)
    anchor_sizes = tuple((int(16 * 2 ** i),) for i in range(num_features))
    aspect_ratios = ((0.5, 1.0, 2.0),) * num_features
    anchor_generator = AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)

    return LineDetect(
        backbone,
        num_classes,
        num_points=num_points,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
        **kwargs,
    )
+
 
 
 def linedetect_resnet18_fpn(

+ 35 - 19
models/line_detect/loi_heads.py

@@ -201,7 +201,7 @@ def single_point_to_heatmap(keypoints, rois, heatmap_size):
     y = keypoints[..., 1].unsqueeze(1)
 
 
-    gs = generate_gaussian_heatmaps(x, y,num_points=1, heatmap_size=heatmap_size, sigma=2.0)
+    gs = generate_gaussian_heatmaps(x, y,num_points=1, heatmap_size=heatmap_size, sigma=1.0)
     # show_heatmap(gs[0],'target')
     all_roi_heatmap = []
     for roi, heatmap in zip(rois, gs):
@@ -648,9 +648,12 @@ def lines_features_align(features, proposals, img_size):
                 align_feat_list.append(align_feat)
 
     # print(f'align_feat_list:{align_feat_list}')
-    feats_tensor = torch.cat(align_feat_list)
+    if len(align_feat_list) > 0:
+        feats_tensor = torch.cat(align_feat_list)
 
-    print(f'align features :{feats_tensor.shape}')
+        print(f'align features :{feats_tensor.shape}')
+    else:
+        feats_tensor = None
 
     return feats_tensor
 
@@ -1487,7 +1490,7 @@ class RoIHeads(nn.Module):
 
             # print(f'line_features from line_roi_pool:{line_features.shape}')
             #(b,256,512,512)
-            line_features = self.channel_compress(features['0'])
+            cs_features = self.channel_compress(features['0'])
             #(b.8,512,512)
 
 
@@ -1501,26 +1504,38 @@ class RoIHeads(nn.Module):
                 print(f'ap_proposal:{ap.shape}')
 
             filtered_proposals = [proposal for proposal in all_proposals if proposal.shape[0] > 0]
-            filtered_proposals_tensor=torch.cat(filtered_proposals)
+            if len(filtered_proposals) > 0:
+                filtered_proposals_tensor=torch.cat(filtered_proposals)
+                print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
+
             line_proposals_tensor=torch.cat(line_proposals)
 
             print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
-            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
+
 
             point_proposals_tensor=torch.cat(point_proposals)
             print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
 
 
             # line_features = lines_features_align(line_features, filtered_proposals, image_shapes)
-            line_features = lines_features_align(line_features, point_proposals, image_shapes)
-            print(f'line_features from features_align:{line_features.shape}')
 
-            line_features = self.line_head(line_features)
+            point_features = lines_features_align(cs_features, point_proposals, image_shapes)
+
+            line_features = lines_features_align(cs_features, line_proposals, image_shapes)
+
+
+
+
+
+
+            print(f'line_features from features_align:{cs_features.shape}')
+
+            cs_features = self.line_head(cs_features)
             #(N,1,512,512)
-            print(f'line_features from line_head:{line_features.shape}')
+            print(f'line_features from line_head:{cs_features.shape}')
             # line_logits = self.line_predictor(line_features)
 
-            line_logits = line_features
+            line_logits = cs_features
             print(f'line_logits:{line_logits.shape}')
 
             loss_line = {}
@@ -1540,26 +1555,27 @@ class RoIHeads(nn.Module):
                 #     line_logits, line_proposals, gt_lines, pos_matched_idxs
                 # )
                 # iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs, img_size)
+
                 gt_lines_tensor=torch.cat(gt_lines)
                 gt_points_tensor = torch.cat(gt_points)
                 print(f'gt_lines_tensor:{gt_lines_tensor.shape}')
                 print(f'gt_points_tensor:{gt_points_tensor.shape}')
-                if gt_lines_tensor.shape[0]>0 :
+                if gt_lines_tensor.shape[0]>0  and line_features is not None:
                     loss_line = lines_point_pair_loss(
-                        line_logits, line_proposals, gt_lines, line_pos_matched_idxs
+                        line_features, line_proposals, gt_lines, line_pos_matched_idxs
                     )
                     loss_line_iou = line_iou_loss(line_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
 
-                if gt_points_tensor.shape[0]>0:
+                if gt_points_tensor.shape[0]>0 and point_features is not None:
                     loss_point = compute_point_loss(
-                        line_logits, point_proposals, gt_points, point_pos_matched_idxs
+                        point_features, point_proposals, gt_points, point_pos_matched_idxs
                     )
 
                 if not loss_line:
-                    loss_line = torch.tensor(0.0, device=line_features.device)
+                    loss_line = torch.tensor(0.0, device=cs_features.device)
 
                 if not loss_line_iou:
-                    loss_line_iou = torch.tensor(0.0, device=line_features.device)
+                    loss_line_iou = torch.tensor(0.0, device=cs_features.device)
 
                 loss_line = {"loss_line": loss_line}
                 loss_line_iou = {'loss_line_iou': loss_line_iou}
@@ -1590,10 +1606,10 @@ class RoIHeads(nn.Module):
                         )
 
                     if not loss_line :
-                        loss_line=torch.tensor(0.0,device=line_features.device)
+                        loss_line=torch.tensor(0.0,device=cs_features.device)
 
                     if not loss_line_iou :
-                        loss_line_iou=torch.tensor(0.0,device=line_features.device)
+                        loss_line_iou=torch.tensor(0.0,device=cs_features.device)
 
                     loss_line = {"loss_line": loss_line}
                     loss_line_iou = {'loss_line_iou': loss_line_iou}

+ 2 - 2
models/line_detect/train_demo.py

@@ -16,8 +16,8 @@ if __name__ == '__main__':
     # model = lineDetect_resnet18_fpn()
 
     # model=linedetect_resnet18_fpn()
-    # model=linedetect_newresnet18fpn(num_points=3)
-    model = linedetect_newresnet50fpn(num_points=3)
+    model=linedetect_newresnet50fpn(num_points=3)
+    # model = linedetect_newresnet50fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
 
     model.start_train(cfg='train.yaml')