Przeglądaj źródła

debug train line on 4080

lstrlq 5 miesięcy temu
rodzic
commit
ab087b6767

+ 26 - 14
models/line_detect/loi_heads.py

@@ -983,7 +983,6 @@ class RoIHeads(nn.Module):
 
 
 
-
                 if gt_lines_tensor.shape[0]>0 :
                     print(f'start to lines_point_pair_loss')
                     loss_line = lines_point_pair_loss(
@@ -997,10 +996,12 @@ class RoIHeads(nn.Module):
                         point_features, point_proposals, gt_points, point_pos_matched_idxs
                     )
 
-                if not loss_line:
+                if  loss_line is None:
+                    print(f'loss_line is None111')
                     loss_line = torch.tensor(0.0, device=cs_features.device)
 
-                if not loss_line_iou:
+                if loss_line_iou is None:
+                    print(f'loss_line_iou is None111')
                     loss_line_iou = torch.tensor(0.0, device=cs_features.device)
 
                 loss_line = {"loss_line": loss_line}
@@ -1015,20 +1016,25 @@ class RoIHeads(nn.Module):
                 if targets is not None:
                     h, w = targets[0]["img_size"]
                     img_size = h
-                    gt_lines = [t["lines"] for t in targets]
-                    gt_points = [t["points"] for t in targets]
+                    gt_lines = [t["lines"] for t in targets if "lines" in t]
+                    gt_points = [t["points"] for t in targets if "points" in t]
 
-                    gt_lines_tensor = torch.cat(gt_lines)
-                    gt_points_tensor = torch.cat(gt_points)
+                    gt_lines_tensor = torch.zeros(0, 0)
+                    gt_points_tensor = torch.zeros(0, 0)
+                    if len(gt_lines)>0:
+                        gt_lines_tensor = torch.cat(gt_lines)
+                    if len(gt_points)>0:
+                        gt_points_tensor = torch.cat(gt_points)
 
-                    line_pos_matched_idxs = []
-                    point_pos_matched_idxs = []
+                    # line_pos_matched_idxs = []
+                    # point_pos_matched_idxs = []
 
 
                     if gt_lines_tensor.shape[0] > 0 and line_features is not None:
                         loss_line = lines_point_pair_loss(
                             line_features, line_proposals, gt_lines, line_pos_matched_idxs
                         )
+                        print(f'compute_line_loss:{loss_line}')
                         loss_line_iou = line_iou_loss(line_features , line_proposals, gt_lines, line_pos_matched_idxs,
                                                       img_size)
 
@@ -1037,14 +1043,17 @@ class RoIHeads(nn.Module):
                             point_features, point_proposals, gt_points, point_pos_matched_idxs
                         )
 
-                    if not loss_line :
+                    if  loss_line is None:
+                        print(f'loss_line is None')
                         loss_line=torch.tensor(0.0,device=cs_features.device)
 
-                    if not loss_line_iou :
+                    if  loss_line_iou is None:
+                        print(f'loss_line_iou is None')
                         loss_line_iou=torch.tensor(0.0,device=cs_features.device)
 
-                    if not loss_point:
-                        loss_point=torch.tensor(0.0,device=cs_features.device)
+                    # if  loss_point is None:
+                    #     print(f'loss_point is None')
+                    #     loss_point=torch.tensor(0.0,device=cs_features.device)
 
                     loss_line = {"loss_line": loss_line}
                     loss_line_iou = {'loss_line_iou': loss_line_iou}
@@ -1057,6 +1066,9 @@ class RoIHeads(nn.Module):
 
 
                 else:
+                    loss_point = {}
+                    loss_line = {}
+                    loss_line_iou = {}
                     if feature_logits is None or line_proposals is None:
                         raise ValueError(
                             "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
@@ -1077,7 +1089,7 @@ class RoIHeads(nn.Module):
                             r["points_scores"] = ps
 
 
-
+            print(f'loss_line11111:{loss_line}')
             losses.update(loss_line)
             losses.update(loss_line_iou)
             losses.update(loss_point)

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: \\192.168.50.222/share/zjh/Dataset_correct_xanylabel
+  datadir: /data/share/rlq/datasets/Dataset_correct_xanylabel
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000

+ 4 - 4
models/line_detect/trainer.py

@@ -191,7 +191,7 @@ class Trainer(BaseTrainer):
 
 
 
-    def writer_predict_result(self, img, result, epoch,type=1):
+    def writer_predict_result(self, img, result, epoch, typ=1):
         img = img.cpu().detach()
         im = img.permute(1, 2, 0)  # [512, 512, 3]
         self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
@@ -205,14 +205,14 @@ class Trainer(BaseTrainer):
         self.writer.add_image("z-obj", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
 
 
-        if type==1 and 'points' in result:
+        if typ==1 and 'points' in result:
             keypoint_img = draw_keypoints(boxed_image, result['points'], colors='red', width=3)
 
             self.writer.add_image("z-output", keypoint_img, epoch)
         # print("lines shape:", result['lines'].shape)
 
 
-        if type==2 and 'lines' in result:
+        if typ==2 and 'lines' in result:
             # 用自己写的函数画线段
             # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
             print(f"shape of linescore:{result['liness_scores'].shape}")
@@ -341,7 +341,7 @@ class Trainer(BaseTrainer):
                 # print(f'result:{result}')
                 t_end = time.time()
                 print(f'predict used:{t_end - t_start}')
-                self.writer_predict_result(img=imgs[0], result=result[0], epoch=epoch)
+                self.writer_predict_result(img=imgs[0], result=result[0], typ=2, epoch=epoch)
                 epoch_step+=1
 
         avg_loss = total_loss / len(data_loader)