ソースを参照

debug single point on 4080

lstrlq 5 ヶ月 前
コミット
cf6ea4a3ad

+ 34 - 2
models/line_detect/loi_heads.py

@@ -634,6 +634,7 @@ def lines_features_align(features, proposals, img_size):
     print(f'lines_features_align features:{features.shape},proposals:{len(proposals)}')
 
     align_feat_list = []
+
     for feat, proposals_per_img in zip(features, proposals):
         print(f'lines_features_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
         if proposals_per_img.shape[0]>0:
@@ -709,6 +710,18 @@ def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
     N, K, H, W = line_logits.shape
     len_proposals = len(proposals)
 
+    empty_count = 0
+    non_empty_count = 0
+
+    for prop in proposals:
+        if prop.shape[0] == 0:
+            empty_count += 1
+        else:
+            non_empty_count += 1
+
+    print(f"Empty proposals count: {empty_count}")
+    print(f"Non-empty proposals count: {non_empty_count}")
+
     print(f'starte to compute_point_loss')
     print(f'compute_point_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
     if H != W:
@@ -1479,11 +1492,27 @@ class RoIHeads(nn.Module):
 
 
             all_proposals=line_proposals+point_proposals
+            # print(f'point_proposals:{point_proposals}')
             # print(f'all_proposals:{all_proposals}')
+            for p in point_proposals:
+                print(f'point_proposal:{p.shape}')
+
+            for ap in all_proposals:
+                print(f'ap_proposal:{ap.shape}')
+
             filtered_proposals = [proposal for proposal in all_proposals if proposal.shape[0] > 0]
+            filtered_proposals_tensor=torch.cat(filtered_proposals)
+            line_proposals_tensor=torch.cat(line_proposals)
+
+            print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
+            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
+
+            point_proposals_tensor=torch.cat(point_proposals)
+            print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
 
 
-            line_features = lines_features_align(line_features, filtered_proposals, image_shapes)
+            # line_features = lines_features_align(line_features, filtered_proposals, image_shapes)
+            line_features = lines_features_align(line_features, point_proposals, image_shapes)
             print(f'line_features from features_align:{line_features.shape}')
 
             line_features = self.line_head(line_features)
@@ -1515,7 +1544,7 @@ class RoIHeads(nn.Module):
                 gt_points_tensor = torch.cat(gt_points)
                 print(f'gt_lines_tensor:{gt_lines_tensor.shape}')
                 print(f'gt_points_tensor:{gt_points_tensor.shape}')
-                if gt_lines_tensor.shape[0]>0:
+                if gt_lines_tensor.shape[0]>0 :
                     loss_line = lines_point_pair_loss(
                         line_logits, line_proposals, gt_lines, line_pos_matched_idxs
                     )
@@ -1542,9 +1571,12 @@ class RoIHeads(nn.Module):
                     img_size = h
                     gt_lines = [t["lines"] for t in targets]
                     gt_points = [t["points"] for t in targets]
+
                     gt_lines_tensor = torch.cat(gt_lines)
                     gt_points_tensor = torch.cat(gt_points)
 
+
+
                     if gt_lines_tensor.shape[0] > 0:
                         loss_line = lines_point_pair_loss(
                             line_logits, line_proposals, gt_lines, line_pos_matched_idxs

+ 2 - 2
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: \\192.168.50.222/share/rlq/datasets/Dataset0709_
+  datadir: /data/share/zjh/Dataset0709_2
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
@@ -11,7 +11,7 @@ io:
 train_params:
   resume_from:
   num_workers: 8
-  batch_size: 2
+  batch_size: 1
   max_epoch: 80000
 #  augmentation: True
   augmentation: False

+ 2 - 2
models/line_detect/trainer.py

@@ -318,7 +318,7 @@ class Trainer(BaseTrainer):
                 losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0, device=device)
 
                 print(f'val losses:{losses}')
-                print(f'val result:{result}')
+                # print(f'val result:{result}')
             else:
                 loss_dict = model(imgs, targets)
                 losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0, device=device)
@@ -338,7 +338,7 @@ class Trainer(BaseTrainer):
                 t_start = time.time()
                 print(f'start to predict:{t_start}')
                 result = model(self.move_to_device(imgs, self.device))
-                print(f'result:{result}')
+                # print(f'result:{result}')
                 t_end = time.time()
                 print(f'predict used:{t_end - t_start}')
                 self.writer_predict_result(img=imgs[0], result=result[0], epoch=epoch)