浏览代码

debug train line_detect_maxvitfpn

lstrlq 5 月之前
父节点
当前提交
8ac0f02074

+ 5 - 4
models/line_detect/line_detect.py

@@ -412,7 +412,9 @@ def linedetect_maxvitfpn(
     if num_points is None:
     if num_points is None:
         num_points = 3
         num_points = 3
 
 
-    maxvit = MaxVitBackbone(input_size=(224*2,224*2))
+    size=224*2
+
+    maxvit = MaxVitBackbone(input_size=(size,size))
     # print(maxvit.named_children())
     # print(maxvit.named_children())
 
 
     # for i,layer in enumerate(maxvit.named_children()):
     # for i,layer in enumerate(maxvit.named_children()):
@@ -420,7 +422,6 @@ def linedetect_maxvitfpn(
 
 
     in_channels_list = [64, 64, 128, 256, 512]
     in_channels_list = [64, 64, 128, 256, 512]
     featmap_names = ['0', '1', '2', '3', '4', 'pool']
     featmap_names = ['0', '1', '2', '3', '4', 'pool']
-    # print(f'featmap_names:{featmap_names}')
     roi_pooler = MultiScaleRoIAlign(
     roi_pooler = MultiScaleRoIAlign(
         featmap_names=featmap_names,
         featmap_names=featmap_names,
         output_size=7,
         output_size=7,
@@ -437,8 +438,8 @@ def linedetect_maxvitfpn(
 
 
     model = LineDetect(
     model = LineDetect(
         backbone=backbone_with_fpn,
         backbone=backbone_with_fpn,
-        min_size=224 * 2,
-        max_size=224 * 2,
+        min_size=size,
+        max_size=size,
         num_classes=91,  # COCO 数据集有 91 类
         num_classes=91,  # COCO 数据集有 91 类
         rpn_anchor_generator=get_anchor_generator(backbone_with_fpn, test_input=test_input),
         rpn_anchor_generator=get_anchor_generator(backbone_with_fpn, test_input=test_input),
         box_roi_pool=roi_pooler
         box_roi_pool=roi_pooler

+ 3 - 0
models/line_detect/loi_heads.py

@@ -1623,6 +1623,9 @@ class RoIHeads(nn.Module):
                     if not loss_line_iou :
                     if not loss_line_iou :
                         loss_line_iou=torch.tensor(0.0,device=cs_features.device)
                         loss_line_iou=torch.tensor(0.0,device=cs_features.device)
 
 
+                    if not loss_point:
+                        loss_point=torch.tensor(0.0,device=cs_features.device)
+
                     loss_line = {"loss_line": loss_line}
                     loss_line = {"loss_line": loss_line}
                     loss_line_iou = {'loss_line_iou': loss_line_iou}
                     loss_line_iou = {'loss_line_iou': loss_line_iou}
                     loss_point={"loss_point":loss_point}
                     loss_point={"loss_point":loss_point}

+ 1 - 1
models/line_detect/train.yaml

@@ -11,7 +11,7 @@ io:
 train_params:
 train_params:
   resume_from:
   resume_from:
   num_workers: 8
   num_workers: 8
-  batch_size: 1
+  batch_size: 2
   max_epoch: 80000
   max_epoch: 80000
 #  augmentation: True
 #  augmentation: True
   augmentation: False
   augmentation: False

+ 2 - 2
models/line_detect/trainer.py

@@ -205,14 +205,14 @@ class Trainer(BaseTrainer):
         self.writer.add_image("z-obj", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
         self.writer.add_image("z-obj", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
 
 
 
 
-        if type==1:
+        if type==1 and 'points' in result:
             keypoint_img = draw_keypoints(boxed_image, result['points'], colors='red', width=3)
             keypoint_img = draw_keypoints(boxed_image, result['points'], colors='red', width=3)
 
 
             self.writer.add_image("z-output", keypoint_img, epoch)
             self.writer.add_image("z-output", keypoint_img, epoch)
         # print("lines shape:", result['lines'].shape)
         # print("lines shape:", result['lines'].shape)
 
 
 
 
-        if type==2:
+        if type==2 and 'lines' in result:
             # 用自己写的函数画线段
             # 用自己写的函数画线段
             # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
             # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
             print(f"shape of linescore:{result['liness_scores'].shape}")
             print(f"shape of linescore:{result['liness_scores'].shape}")