Kaynağa Gözat

add resnet152fpn train lines

lstrlq 5 ay önce
ebeveyn
işleme
1bf78724b0

+ 38 - 1
models/line_detect/line_detect.py

@@ -414,7 +414,44 @@ def linedetect_newresnet101fpn(
     if num_points is None:
         num_points = 3
 
+    size=768
+    backbone =resnet101fpn(out_channels=256)
+    featmap_names=['0', '1', '2', '3','4','pool']
+    # print(f'featmap_names:{featmap_names}')
+    roi_pooler = MultiScaleRoIAlign(
+        featmap_names=featmap_names,
+        output_size=7,
+        sampling_ratio=2
+    )
+    num_features=len(featmap_names)
+    anchor_sizes = tuple((int(16 * 2 ** i),) for i in range(num_features))  # 自动生成不同大小
+    # print(f'anchor_sizes:{anchor_sizes}')
+    aspect_ratios = ((0.5, 1.0, 2.0),) * num_features
+    # print(f'aspect_ratios:{aspect_ratios}')
+
+
+    anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
+
+    model = LineDetect(backbone, num_classes,min_size=size,max_size=size, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
 
+    return model
+
+def linedetect_newresnet152fpn(
+        *,
+
+        num_classes: Optional[int] = None,
+        num_points:Optional[int] = None,
+
+        **kwargs: Any,
+) -> LineDetect:
+    # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
+    # weights_backbone = ResNet50_Weights.verify(weights_backbone)
+    if num_classes is None:
+        num_classes = 3
+    if num_points is None:
+        num_points = 3
+
+    size=1024
     backbone =resnet101fpn(out_channels=256)
     featmap_names=['0', '1', '2', '3','4','pool']
     # print(f'featmap_names:{featmap_names}')
@@ -432,7 +469,7 @@ def linedetect_newresnet101fpn(
 
     anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
 
-    model = LineDetect(backbone, num_classes, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
+    model = LineDetect(backbone, num_classes,min_size=size,max_size=size, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
 
     return model
 

+ 5 - 5
models/line_detect/train.yaml

@@ -1,12 +1,12 @@
 io:
   logdir: train_results
 
-#  datadir: /data/share/rlq/datasets/250718caisegangban
+  datadir: /data/share/rlq/datasets/250718caisegangban
 #  datadir: /data/share/zjh/Dataset_correct_xanylabel_tiff
 
 #  datadir: /data/share/rlq/datasets/250718caisegangban
 
-  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
+#  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
 #  datadir: \\192.168.50.222/share/zyh/arc/a_dataset
 #  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
 #  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
@@ -20,10 +20,10 @@ io:
 train_params:
   resume_from:
   num_workers: 8
-  batch_size: 3
+  batch_size: 2
   max_epoch: 8000000
-#  augmentation: True
-  augmentation: False
+  augmentation: True
+#  augmentation: False
   optim:
     name: Adam
     lr: 4.0e-4

+ 3 - 2
models/line_detect/train_demo.py

@@ -2,7 +2,7 @@ import torch
 
 from models.line_detect.line_detect import linedetect_newresnet18fpn, linedetect_resnet50_fpn, linedetect_resnet18_fpn, \
     linedetect_newresnet50fpn, linedetect_maxvitfpn, linedetect_high_maxvitfpn, linedetect_swin_transformer_fpn, \
-    linedetect_newresnet101fpn
+    linedetect_newresnet101fpn, linedetect_newresnet152fpn
 
 from models.line_net.trainer import Trainer
 
@@ -18,8 +18,9 @@ if __name__ == '__main__':
 
     # model=linedetect_resnet18_fpn()
     # model=linedetect_newresnet18fpn(num_points=3)
-    model=linedetect_newresnet50fpn(num_points=3)
+    # model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet101fpn(num_points=3)
+    model = linedetect_newresnet152fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
     # model=linedetect_maxvitfpn()
     # model=linedetect_high_maxvitfpn()