ソースを参照

train maxvitfpn on 4080

lstrlq 5 ヶ月 前
コミット
d3e5377a16

+ 5 - 2
models/line_detect/heads/head_losses.py

@@ -145,8 +145,11 @@ def line_points_to_heatmap(keypoints, rois, heatmap_size):
         # show_heatmap(roi_heatmap[0],'roi_heatmap')
         all_roi_heatmap.append(roi_heatmap)
 
-    all_roi_heatmap = torch.cat(all_roi_heatmap)
-    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
+    if len(all_roi_heatmap) > 0:
+        all_roi_heatmap = torch.cat(all_roi_heatmap)
+        print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
+    else:
+        all_roi_heatmap = None
 
     return all_roi_heatmap
 

+ 0 - 2
models/line_detect/loi_heads.py

@@ -929,7 +929,6 @@ class RoIHeads(nn.Module):
                 gt_lines = [t["lines"] for t in targets if "lines" in t]
 
 
-
                 print(f'gt_lines:{gt_lines[0].shape}')
                 h, w = targets[0]["img_size"]
                 img_size = h
@@ -997,7 +996,6 @@ class RoIHeads(nn.Module):
 
 
                 else:
-                    loss_point = {}
                     loss_line = {}
                     loss_line_iou = {}
                     if feature_logits is None or line_proposals is None:

+ 2 - 2
models/line_detect/train.yaml

@@ -1,7 +1,7 @@
 io:
   logdir: train_results
-#  datadir: /data/share/rlq/datasets/Dataset_correct_xanylabel
-  datadir: \\192.168.50.222/share/rlq/datasets/Dataset_correct_xanylabel
+  datadir: /data/share/zjh/Dataset_correct_xanylabel
+#  datadir: \\192.168.50.222/share/rlq/datasets/Dataset_correct_xanylabel
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000