Prechádzať zdrojové kódy

添加原分辨率特征图的resnet18fpn

RenLiqiang 5 mesiacov pred
rodič
commit
d80c57019b

+ 38 - 0
models/line_detect/line_detect.py

@@ -327,6 +327,44 @@ class LinePredictor(nn.Module):
             x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
         )
 
+def linedetect_newresnet18fpn(
+        *,
+
+        num_classes: Optional[int] = None,
+        num_points:Optional[int] = None,
+
+        **kwargs: Any,
+) -> LineDetect:
+    # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
+    # weights_backbone = ResNet50_Weights.verify(weights_backbone)
+    if num_classes is None:
+        num_classes = 2
+    if num_points is None:
+        num_points = 2
+
+
+    backbone =resnet18fpn()
+    featmap_names=['0', '1', '2', '3','pool']
+    # print(f'featmap_names:{featmap_names}')
+    roi_pooler = MultiScaleRoIAlign(
+        featmap_names=featmap_names,
+        output_size=7,
+        sampling_ratio=2
+    )
+    num_features=len(featmap_names)
+    anchor_sizes = tuple((int(16 * 2 ** i),) for i in range(num_features))  # 自动生成不同大小
+    # print(f'anchor_sizes:{anchor_sizes}')
+    aspect_ratios = ((0.5, 1.0, 2.0),) * num_features
+    # print(f'aspect_ratios:{aspect_ratios}')
+
+
+    anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
+
+    model = LineDetect(backbone, num_classes, num_keypoints=num_points,rpn_anchor_generator=anchor_generator,box_roi_pool=roi_pooler, **kwargs)
+
+    return model
+
+
 
 def lineDetect_resnet18_fpn(
         *,

+ 4 - 2
models/line_detect/train_demo.py

@@ -1,6 +1,6 @@
 import torch
 
-from models.line_detect.line_detect import lineDetect_resnet18_fpn
+from models.line_detect.line_detect import linedetect_newresnet18fpn
 from models.line_net.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, linenet_newresnet50fpn, \
     get_line_net_convnext_fpn, linenet_newresnet18fpn
 from models.line_net.trainer import Trainer
@@ -13,6 +13,8 @@ if __name__ == '__main__':
     # model = linenet_resnet18_fpn()
     # model=get_line_net_convnext_fpn(num_classes=2).to(device)
     # model=linenet_newresnet50fpn()
-    model = lineDetect_resnet18_fpn()
+    # model = lineDetect_resnet18_fpn()
+
+    model=linedetect_newresnet18fpn()
 
     model.start_train(cfg='train.yaml')

+ 1 - 0
models/line_net/roi_heads.py

@@ -1048,6 +1048,7 @@ class RoIHeads(nn.Module):
                         raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
 
         if self.training:
+            print(f'targets:{targets}')
             proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
         else:
             if targets is not None: