Преглед на файлове

尝试修复训练时 proposals_per_image 的形状为 (0, 4)（即该图像没有任何 proposal）时出错的 bug

lstrlq преди 5 месеца
родител
ревизия
ab7c9a8415
променени са 3 файла, в които са добавени 16 реда и са изтрити 12 реда
  1. + 5 - 3
      models/line_detect/heads/head_losses.py
  2. + 4 - 4
      models/line_detect/train.yaml
  3. + 7 - 5
      models/line_detect/train_demo.py

+ 5 - 3
models/line_detect/heads/head_losses.py

@@ -381,9 +381,11 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
         print(f'line_proposals_per_image:{proposals_per_image.shape}')
         print(f'gt_lines:{gt_lines}')
-        kp = gt_kp_in_image[midx]
-        gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
-        gs_heatmaps.append(gs_heatmaps_per_img)
+        if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
+            kp = gt_kp_in_image[midx]
+
+            gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
+            gs_heatmaps.append(gs_heatmaps_per_img)
         # print(f'heatmaps_per_image:{heatmaps_per_image.shape}')
 
         # heatmaps.append(heatmaps_per_image.view(-1))

+ 4 - 4
models/line_detect/train.yaml

@@ -12,10 +12,10 @@ io:
 train_params:
   resume_from:
   num_workers: 8
-  batch_size: 1
-  max_epoch: 80000
-#  augmentation: True
-  augmentation: False
+  batch_size: 2
+  max_epoch: 8000000
+  augmentation: True
+#  augmentation: False
   optim:
     name: Adam
     lr: 4.0e-4

+ 7 - 5
models/line_detect/train_demo.py

@@ -1,7 +1,8 @@
 import torch
 
 from models.line_detect.line_detect import linedetect_newresnet18fpn, linedetect_resnet50_fpn, linedetect_resnet18_fpn, \
-    linedetect_newresnet50fpn, linedetect_maxvitfpn, linedetect_high_maxvitfpn, linedetect_swin_transformer_fpn
+    linedetect_newresnet50fpn, linedetect_maxvitfpn, linedetect_high_maxvitfpn, linedetect_swin_transformer_fpn, \
+    linedetect_newresnet101fpn
 
 from models.line_net.trainer import Trainer
 
@@ -16,10 +17,11 @@ if __name__ == '__main__':
     # model = lineDetect_resnet18_fpn()
 
     # model=linedetect_resnet18_fpn()
-    # model=linedetect_newresnet50fpn(num_points=3)
-    # model = linedetect_newresnet50fpn(num_points=3)
+    # model=linedetect_newresnet18fpn(num_points=3)
+    # model = linedetect_newresnet101fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
     # model=linedetect_maxvitfpn()
-    # model=linedetect_high_maxvitfpn()
-    model=linedetect_swin_transformer_fpn(type='t')
+    model=linedetect_high_maxvitfpn()
+    model.load_weights(r'/data/share/rlq/weights/250718maxvit_best_val.pth')
+    # model=linedetect_swin_transformer_fpn(type='t')
     model.start_train(cfg='train.yaml')