Prechádzať zdrojové kódy

train maxvit for lines on 4080

lstrlq 5 mesiacov pred
rodič
commit
a2a168a126

+ 2 - 2
models/line_detect/line_detect.py

@@ -463,8 +463,8 @@ def linedetect_maxvitfpn(
         num_classes=3,  # COCO 数据集有 91 类
         rpn_anchor_generator=get_anchor_generator(backbone_with_fpn, test_input=test_input),
         box_roi_pool=roi_pooler,
-        detect_line=False,
-        detect_point=True,
+        detect_line=True,
+        detect_point=False,
     )
     return model
 

+ 1 - 1
models/line_detect/loi_heads.py

@@ -923,7 +923,7 @@ class RoIHeads(nn.Module):
                 else:
                     pos_matched_idxs = None
 
-            feature_logits = self.line_forward1(features, image_shapes, line_proposals)
+            feature_logits = self.line_forward3(features, image_shapes, line_proposals)
 
             loss_line = None
             loss_line_iou =None

+ 2 - 2
models/line_detect/train.yaml

@@ -16,8 +16,8 @@ train_params:
   num_workers: 8
   batch_size: 2
   max_epoch: 8000000
-#  augmentation: True
-  augmentation: False
+  augmentation: True
+#  augmentation: False
   optim:
     name: Adam
     lr: 4.0e-4

+ 3 - 3
models/line_detect/train_demo.py

@@ -21,8 +21,8 @@ if __name__ == '__main__':
     # model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet101fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
-    # model=linedetect_maxvitfpn()
+    model=linedetect_maxvitfpn()
     # model=linedetect_high_maxvitfpn()
-    # model.load_weights(r'/data/share/rlq/weights/250718maxvit_best_val.pth')
-    model=linedetect_swin_transformer_fpn(type='t')
+    model.load_weights(r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250718_151833/weights/best_val.pth')
+    # model=linedetect_swin_transformer_fpn(type='t')
     model.start_train(cfg='train.yaml')