瀏覽代碼

尝试修复训练过程中proposals可能为0导致报错的bug

admin 4 月之前
父節點
當前提交
072c86c49c
共有 4 個文件被更改,包括 273 次插入和 232 次删除
  1. 23 11
      models/base/backbone_factory.py
  2. 3 3
      models/line_detect/line_detect.py
  3. 243 214
      models/line_detect/loi_heads.py
  4. 4 4
      models/line_detect/train_demo.py

+ 23 - 11
models/base/backbone_factory.py

@@ -103,12 +103,15 @@ def get_efficientnetv2_fpn(name='efficientnet_v2_m', pretrained=True):
         backbone = efficientnet_v2_l(weights=weights).features
 
     # 定义返回的层索引和名称
-    return_layers = {"2": "0", "3": "1", "4": "2", "5": "3"}
-
+    return_layers = {"1":"0", "2": "1", "3": "2", "4": "3", "5": "4"}
+    input=torch.randn(1, 3, 512, 512)
+    # out=backbone(input)
+    # print(f'out:{out}')
     # 获取每个层输出通道数
     in_channels_list = []
-    for layer_idx in [2, 3, 4, 5]:
+    for layer_idx in [1,2, 3, 4, 5]:
         module = backbone[layer_idx]
+        # print(f'efficientnet:{backbone}')
         if hasattr(module, 'out_channels'):
             in_channels_list.append(module.out_channels)
         elif hasattr(module[-1], 'out_channels'):
@@ -118,12 +121,19 @@ def get_efficientnetv2_fpn(name='efficientnet_v2_m', pretrained=True):
             raise ValueError(f"Cannot determine out_channels for layer {layer_idx}")
 
     # 使用BackboneWithFPN包装backbone
+    print(f'in_channels_list: {in_channels_list}')
     backbone_with_fpn = BackboneWithFPN(
         backbone=backbone,
         return_layers=return_layers,
         in_channels_list=in_channels_list,
         out_channels=256
     )
+    out=backbone_with_fpn(input)
+    print(f"out0:{out['0'].shape}")
+    print(f"out1:{out['1'].shape}")
+    print(f"out2:{out['2'].shape}")
+    print(f"out3:{out['3'].shape}")
+    print(f"out4:{out['4'].shape}")
 
     return backbone_with_fpn
 
@@ -279,14 +289,16 @@ def get_swin_transformer_fpn(type='t'):
     # print(f'out:{out}')
     return  backbone_with_fpn,roi_pooler,anchor_generator
 if __name__ == '__main__':
-    backbone_with_fpn, roi_pooler, anchor_generator=get_swin_transformer_fpn(type='t')
-    model=FasterRCNN(backbone=backbone_with_fpn,num_classes=3,box_roi_pool=roi_pooler,rpn_anchor_generator=anchor_generator)
-    input=torch.randn(3,3,512,512,device='cuda')
-    model.eval()
-    model.to('cuda')
-    out=model(input)
-    out=backbone_with_fpn(input)
-    print(f'out:{out.shape}')
+    # backbone_with_fpn, roi_pooler, anchor_generator=get_swin_transformer_fpn(type='t')
+    # model=FasterRCNN(backbone=backbone_with_fpn,num_classes=3,box_roi_pool=roi_pooler,rpn_anchor_generator=anchor_generator)
+    # input=torch.randn(3,3,512,512,device='cuda')
+    # model.eval()
+    # model.to('cuda')
+    # out=model(input)
+    # out=backbone_with_fpn(input)
+    # print(f'out:{out.shape}')
+    backbone=get_efficientnetv2_fpn(name='efficientnet_v2_l')
+    print(backbone)
 
     # # maxvit = models.maxvit_t(pretrained=True)
     # maxvit=MaxVitBackbone()

+ 3 - 3
models/line_detect/line_detect.py

@@ -432,7 +432,7 @@ def linedetect_newresnet101fpn(
     if num_points is None:
         num_points = 3
 
-    size=512
+    size=768
     backbone =resnet101fpn(out_channels=256)
     featmap_names=['0', '1', '2', '3','4','pool']
     # print(f'featmap_names:{featmap_names}')
@@ -474,7 +474,7 @@ def linedetect_newresnet152fpn(
     if num_points is None:
         num_points = 3
 
-    size=512
+    size=768
     backbone =resnet101fpn(out_channels=256)
     featmap_names=['0', '1', '2', '3','4','pool']
     # print(f'featmap_names:{featmap_names}')
@@ -659,7 +659,7 @@ def linedetect_swin_transformer_fpn(
         num_classes=3,  # COCO 数据集有 91 类
         rpn_anchor_generator=anchor_generator,
         box_roi_pool=roi_pooler,
-        detect_line=True,
+        detect_line=False,
         detect_point=False,
     )
     return model

+ 243 - 214
models/line_detect/loi_heads.py

@@ -961,105 +961,108 @@ class RoIHeads(nn.Module):
                 else:
                     pos_matched_idxs = None
 
-            feature_logits = self.line_forward3(features, image_shapes, line_proposals)
+            line_proposals_valid=self.check_proposals(line_proposals)
+            if line_proposals_valid:
 
-            loss_line = None
-            loss_line_iou =None
-
-            if self.training:
-
-                if targets is None or pos_matched_idxs is None:
-                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
-
-                gt_lines = [t["lines"] for t in targets if "lines" in t]
-
-
-                # print(f'gt_lines:{gt_lines[0].shape}')
-                h, w = targets[0]["img_size"]
-                img_size = h
-
-                gt_lines_tensor=torch.zeros(0,0)
-                if len(gt_lines)>0:
-                    gt_lines_tensor = torch.cat(gt_lines)
-                    print(f'gt_lines_tensor:{gt_lines_tensor.shape}')
-
-
-                if gt_lines_tensor.shape[0]>0 :
-                    print(f'start to lines_point_pair_loss')
-                    loss_line = lines_point_pair_loss(
-                        feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
-                    )
-                    loss_line_iou = line_iou_loss(feature_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
+                feature_logits = self.line_forward3(features, image_shapes, line_proposals)
 
+                loss_line = None
+                loss_line_iou =None
 
+                if self.training:
 
-                if  loss_line is None:
-                    print(f'loss_line is None111')
-                    loss_line = torch.tensor(0.0, device=device)
+                    if targets is None or pos_matched_idxs is None:
+                        raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
 
-                if loss_line_iou is None:
-                    print(f'loss_line_iou is None111')
-                    loss_line_iou = torch.tensor(0.0, device=device)
+                    gt_lines = [t["lines"] for t in targets if "lines" in t]
 
-                loss_line = {"loss_line": loss_line}
-                loss_line_iou = {'loss_line_iou': loss_line_iou}
 
-            else:
-                if targets is not None:
+                    # print(f'gt_lines:{gt_lines[0].shape}')
                     h, w = targets[0]["img_size"]
                     img_size = h
-                    gt_lines = [t["lines"] for t in targets if "lines" in t]
 
-                    gt_lines_tensor = torch.zeros(0, 0)
+                    gt_lines_tensor=torch.zeros(0,0)
                     if len(gt_lines)>0:
                         gt_lines_tensor = torch.cat(gt_lines)
+                        print(f'gt_lines_tensor:{gt_lines_tensor.shape}')
 
 
-
-                    if gt_lines_tensor.shape[0] > 0 and feature_logits is not None:
+                    if gt_lines_tensor.shape[0]>0 :
+                        print(f'start to lines_point_pair_loss')
                         loss_line = lines_point_pair_loss(
                             feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
                         )
-                        print(f'compute_line_loss:{loss_line}')
-                        loss_line_iou = line_iou_loss(feature_logits , line_proposals, gt_lines, line_pos_matched_idxs,
-                                                      img_size)
+                        loss_line_iou = line_iou_loss(feature_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
 
 
-                    if  loss_line is None:
-                        print(f'loss_line is None')
-                        loss_line=torch.tensor(0.0,device=device)
 
-                    if  loss_line_iou is None:
-                        print(f'loss_line_iou is None')
-                        loss_line_iou=torch.tensor(0.0,device=device)
+                    if  loss_line is None:
+                        print(f'loss_line is None111')
+                        loss_line = torch.tensor(0.0, device=device)
 
+                    if loss_line_iou is None:
+                        print(f'loss_line_iou is None111')
+                        loss_line_iou = torch.tensor(0.0, device=device)
 
                     loss_line = {"loss_line": loss_line}
                     loss_line_iou = {'loss_line_iou': loss_line_iou}
-                    
-
 
                 else:
-                    loss_line = {}
-                    loss_line_iou = {}
-                    if feature_logits is None or line_proposals is None:
-                        raise ValueError(
-                            "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
-                        )
+                    if targets is not None:
+                        h, w = targets[0]["img_size"]
+                        img_size = h
+                        gt_lines = [t["lines"] for t in targets if "lines" in t]
+
+                        gt_lines_tensor = torch.zeros(0, 0)
+                        if len(gt_lines)>0:
+                            gt_lines_tensor = torch.cat(gt_lines)
+
+
 
-                    if feature_logits is not None:
-                        lines_probs, lines_scores = line_inference(feature_logits,line_proposals)
-                        for keypoint_prob, kps, r in zip(lines_probs, lines_scores, result):
-                            r["lines"] = keypoint_prob
-                            r["lines_scores"] = kps
+                        if gt_lines_tensor.shape[0] > 0 and feature_logits is not None:
+                            loss_line = lines_point_pair_loss(
+                                feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
+                            )
+                            print(f'compute_line_loss:{loss_line}')
+                            loss_line_iou = line_iou_loss(feature_logits , line_proposals, gt_lines, line_pos_matched_idxs,
+                                                          img_size)
 
 
+                        if  loss_line is None:
+                            print(f'loss_line is None')
+                            loss_line=torch.tensor(0.0,device=device)
 
+                        if  loss_line_iou is None:
+                            print(f'loss_line_iou is None')
+                            loss_line_iou=torch.tensor(0.0,device=device)
 
-            print(f'loss_line11111:{loss_line}')
-            losses.update(loss_line)
-            losses.update(loss_line_iou)
-            print(f'losses:{losses}')
+
+                        loss_line = {"loss_line": loss_line}
+                        loss_line_iou = {'loss_line_iou': loss_line_iou}
+
+
+
+                    else:
+                        loss_line = {}
+                        loss_line_iou = {}
+                        if feature_logits is None or line_proposals is None:
+                            raise ValueError(
+                                "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                            )
+
+                        if feature_logits is not None:
+                            lines_probs, lines_scores = line_inference(feature_logits,line_proposals)
+                            for keypoint_prob, kps, r in zip(lines_probs, lines_scores, result):
+                                r["lines"] = keypoint_prob
+                                r["lines_scores"] = kps
+
+
+
+
+                print(f'loss_line11111:{loss_line}')
+                losses.update(loss_line)
+                losses.update(loss_line_iou)
+                print(f'losses:{losses}')
         if self.has_point() and self.detect_point:
             print(f'roi_heads forward has_point()!!!!')
             # print(f'labels:{labels}')
@@ -1101,42 +1104,24 @@ class RoIHeads(nn.Module):
                 else:
                     pos_matched_idxs = None
 
-            feature_logits = self.point_forward1(features, image_shapes, point_proposals)
+            point_proposals_valid = self.check_proposals(point_proposals)
 
-            loss_point=None
+            if point_proposals_valid:
 
-            if self.training:
-
-                if targets is None or point_pos_matched_idxs is None:
-                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+                feature_logits = self.point_forward1(features, image_shapes, point_proposals)
 
-                gt_points = [t["points"] for t in targets if "points" in t]
+                loss_point=None
 
-                print(f'gt_points:{gt_points[0].shape}')
-                h, w = targets[0]["img_size"]
-                img_size = h
+                if self.training:
 
-                gt_points_tensor = torch.zeros(0, 0)
-                if len(gt_points) > 0:
-                    gt_points_tensor = torch.cat(gt_points)
-                    print(f'gt_points_tensor:{gt_points_tensor.shape}')
+                    if targets is None or point_pos_matched_idxs is None:
+                        raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
 
-                if gt_points_tensor.shape[0] > 0:
-                    print(f'start to compute point_loss')
-
-                    loss_point=compute_point_loss(feature_logits,point_proposals,gt_points,point_pos_matched_idxs)
-
-                if loss_point is None:
-                    print(f'loss_point is None111')
-                    loss_point = torch.tensor(0.0, device=device)
-
-                loss_point = {"loss_point": loss_point}
+                    gt_points = [t["points"] for t in targets if "points" in t]
 
-            else:
-                if targets is not None:
+                    print(f'gt_points:{gt_points[0].shape}')
                     h, w = targets[0]["img_size"]
                     img_size = h
-                    gt_points = [t["points"] for t in targets if "points" in t]
 
                     gt_points_tensor = torch.zeros(0, 0)
                     if len(gt_points) > 0:
@@ -1146,8 +1131,7 @@ class RoIHeads(nn.Module):
                     if gt_points_tensor.shape[0] > 0:
                         print(f'start to compute point_loss')
 
-                        loss_point = compute_point_loss(feature_logits, point_proposals, gt_points,
-                                                        point_pos_matched_idxs)
+                        loss_point=compute_point_loss(feature_logits,point_proposals,gt_points,point_pos_matched_idxs)
 
                     if loss_point is None:
                         print(f'loss_point is None111')
@@ -1155,25 +1139,48 @@ class RoIHeads(nn.Module):
 
                     loss_point = {"loss_point": loss_point}
 
+                else:
+                    if targets is not None:
+                        h, w = targets[0]["img_size"]
+                        img_size = h
+                        gt_points = [t["points"] for t in targets if "points" in t]
 
+                        gt_points_tensor = torch.zeros(0, 0)
+                        if len(gt_points) > 0:
+                            gt_points_tensor = torch.cat(gt_points)
+                            print(f'gt_points_tensor:{gt_points_tensor.shape}')
 
-                else:
-                    loss_point = {}
-                    if feature_logits is None or point_proposals is None:
-                        raise ValueError(
-                            "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
-                        )
+                        if gt_points_tensor.shape[0] > 0:
+                            print(f'start to compute point_loss')
+
+                            loss_point = compute_point_loss(feature_logits, point_proposals, gt_points,
+                                                            point_pos_matched_idxs)
 
-                    if feature_logits is not None:
+                        if loss_point is None:
+                            print(f'loss_point is None111')
+                            loss_point = torch.tensor(0.0, device=device)
 
-                        points_probs, points_scores = point_inference(feature_logits,point_proposals)
-                        for keypoint_prob, kps, r in zip(points_probs, points_scores, result):
-                            r["points"] = keypoint_prob
-                            r["points_scores"] = kps
+                        loss_point = {"loss_point": loss_point}
 
-            print(f'loss_point:{loss_point}')
-            losses.update(loss_point)
-            print(f'losses:{losses}')
+
+
+                    else:
+                        loss_point = {}
+                        if feature_logits is None or point_proposals is None:
+                            raise ValueError(
+                                "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                            )
+
+                        if feature_logits is not None:
+
+                            points_probs, points_scores = point_inference(feature_logits,point_proposals)
+                            for keypoint_prob, kps, r in zip(points_probs, points_scores, result):
+                                r["points"] = keypoint_prob
+                                r["points_scores"] = kps
+
+                print(f'loss_point:{loss_point}')
+                losses.update(loss_point)
+                print(f'losses:{losses}')
 
 
         if self.has_arc() and self.detect_arc:
@@ -1218,41 +1225,20 @@ class RoIHeads(nn.Module):
                 else:
                     arc_pos_matched_idxs = None
 
-            feature_logits = self.arc_forward1(features, image_shapes, arc_proposals)
-
-            loss_arc=None
+            arc_proposals_valid=self.check_proposals(arc_proposals)
 
-            if self.training:
-
-                if targets is None or arc_pos_matched_idxs is None:
-                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+            if arc_proposals_valid:
 
-                gt_arcs = [t["arc_mask"] for t in targets if "arc_mask" in t]
 
-                print(f'gt_arcs:{gt_arcs[0].shape}')
-                h, w = targets[0]["img_size"]
-                img_size = h
+                feature_logits = self.arc_forward1(features, image_shapes, arc_proposals)
 
-                # gt_arcs_tensor = torch.zeros(0, 0)
-                # if len(gt_arcs) > 0:
-                #     gt_arcs_tensor = torch.cat(gt_arcs)
-                #     print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
-                #
-                # if gt_arcs_tensor.shape[0] > 0:
-                #     print(f'start to compute point_loss')
-                if len(gt_arcs) > 0 and feature_logits is not None:
-                    loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
+                loss_arc=None
 
-                if loss_arc is None:
-                    print(f'loss_arc is None111')
-                    loss_arc = torch.tensor(0.0, device=device)
+                if self.training:
 
-                loss_arc = {"loss_arc": loss_arc}
+                    if targets is None or arc_pos_matched_idxs is None:
+                        raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
 
-            else:
-                if targets is not None:
-                    h, w = targets[0]["img_size"]
-                    img_size = h
                     gt_arcs = [t["arc_mask"] for t in targets if "arc_mask" in t]
 
                     print(f'gt_arcs:{gt_arcs[0].shape}')
@@ -1263,46 +1249,72 @@ class RoIHeads(nn.Module):
                     # if len(gt_arcs) > 0:
                     #     gt_arcs_tensor = torch.cat(gt_arcs)
                     #     print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
-
-                    # if gt_arcs_tensor.shape[0] > 0 and feature_logits is not None:
-                    #     print(f'start to compute arc_loss')
-
+                    #
+                    # if gt_arcs_tensor.shape[0] > 0:
+                    #     print(f'start to compute point_loss')
                     if len(gt_arcs) > 0 and feature_logits is not None:
-                        print(f'start to compute arc_loss')
                         loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
 
-
                     if loss_arc is None:
                         print(f'loss_arc is None111')
                         loss_arc = torch.tensor(0.0, device=device)
 
                     loss_arc = {"loss_arc": loss_arc}
 
+                else:
+                    if targets is not None:
+                        h, w = targets[0]["img_size"]
+                        img_size = h
+                        gt_arcs = [t["arc_mask"] for t in targets if "arc_mask" in t]
 
+                        print(f'gt_arcs:{gt_arcs[0].shape}')
+                        h, w = targets[0]["img_size"]
+                        img_size = h
 
-                else:
-                    loss_arc = {}
-                    if feature_logits is None or arc_proposals is None:
-                        # raise ValueError(
-                        #     "both arc_feature_logits and arc_proposals should not be None when not in training mode"
-                        # )
+                        # gt_arcs_tensor = torch.zeros(0, 0)
+                        # if len(gt_arcs) > 0:
+                        #     gt_arcs_tensor = torch.cat(gt_arcs)
+                        #     print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
+
+                        # if gt_arcs_tensor.shape[0] > 0 and feature_logits is not None:
+                        #     print(f'start to compute arc_loss')
+
+                        if len(gt_arcs) > 0 and feature_logits is not None:
+                            print(f'start to compute arc_loss')
+                            loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
+
+
+                        if loss_arc is None:
+                            print(f'loss_arc is None111')
+                            loss_arc = torch.tensor(0.0, device=device)
+
+                        loss_arc = {"loss_arc": loss_arc}
 
-                        print(f'error :both arc_feature_logits and arc_proposals should not be None when not in training mode"')
-                        pass
 
-                    if feature_logits is not None and arc_proposals is not None:
 
-                        arcs_probs, arcs_scores, arcs_point = arc_inference(feature_logits,arc_proposals, th=0)
-                        for keypoint_prob, kps, kp, r in zip(arcs_probs, arcs_scores, arcs_point, result):
-                            # r["arcs"] = keypoint_prob
-                            r["arcs"] = feature_logits
-                            r["arcs_scores"] = kps
-                            r["arcs_point"] = feature_logits
+                    else:
+                        loss_arc = {}
+                        if feature_logits is None or arc_proposals is None:
+                            # raise ValueError(
+                            #     "both arc_feature_logits and arc_proposals should not be None when not in training mode"
+                            # )
 
+                            print(f'error: both arc_feature_logits and arc_proposals should not be None when not in training mode')
+                            pass
 
-            # print(f'loss_point:{loss_point}')
-            losses.update(loss_arc)
-            print(f'losses:{losses}')
+                        if feature_logits is not None and arc_proposals is not None:
+
+                            arcs_probs, arcs_scores, arcs_point = arc_inference(feature_logits,arc_proposals, th=0)
+                            for keypoint_prob, kps, kp, r in zip(arcs_probs, arcs_scores, arcs_point, result):
+                                # r["arcs"] = keypoint_prob
+                                r["arcs"] = feature_logits
+                                r["arcs_scores"] = kps
+                                r["arcs_point"] = feature_logits
+
+
+                # print(f'loss_point:{loss_point}')
+                losses.update(loss_arc)
+                print(f'losses:{losses}')
 
         if self.has_circle and self.detect_circle:
             print(f'roi_heads forward has_circle()!!!!')
@@ -1345,50 +1357,30 @@ class RoIHeads(nn.Module):
                 else:
                     pos_matched_idxs = None
 
-            feature_logits = self.circle_forward1(features, image_shapes, circle_proposals)
-
-            loss_circle = None
-            loss_circle_extra=None
-
-            if self.training:
+            # circle_proposals_tensor=torch.cat(circle_proposals)
 
-                if targets is None or circle_pos_matched_idxs is None:
-                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+            circle_proposals_valid = self.check_proposals(circle_proposals)
 
-                gt_circles = [t["circles"] for t in targets if "circles" in t]
+            if  circle_proposals_valid:
 
-                print(f'gt_circle:{gt_circles[0].shape}')
-                h, w = targets[0]["img_size"]
-                img_size = h
 
-                gt_circles_tensor = torch.zeros(0, 0)
-                if len(gt_circles) > 0:
-                    gt_circles_tensor = torch.cat(gt_circles)
-                    print(f'gt_circles_tensor:{gt_circles_tensor.shape}')
 
-                if gt_circles_tensor.shape[0] > 0:
-                    print(f'start to compute circle_loss')
 
-                    loss_circle = compute_circle_loss(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
+                feature_logits = self.circle_forward1(features, image_shapes, circle_proposals)
 
-                    # loss_circle_extra=compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
+                loss_circle = None
+                loss_circle_extra=None
 
-                if loss_circle is None:
-                    print(f'loss_circle is None111')
-                    loss_circle = torch.tensor(0.0, device=device)
+                if self.training:
 
-                if loss_circle_extra is None:
-                    print(f'loss_circle_extra is None111')
-                    loss_circle_extra = torch.tensor(0.0, device=device)
+                    if targets is None or circle_pos_matched_idxs is None:
+                        raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
 
-                loss_circle = {"loss_circle": loss_circle}
-                loss_circle_extra = {"loss_circle_extra": loss_circle_extra}
+                    gt_circles = [t["circles"] for t in targets if "circles" in t]
 
-            else:
-                if targets is not None:
+                    print(f'gt_circle:{gt_circles[0].shape}')
                     h, w = targets[0]["img_size"]
                     img_size = h
-                    gt_circles = [t["circles"] for t in targets if "circles" in t]
 
                     gt_circles_tensor = torch.zeros(0, 0)
                     if len(gt_circles) > 0:
@@ -1398,10 +1390,9 @@ class RoIHeads(nn.Module):
                     if gt_circles_tensor.shape[0] > 0:
                         print(f'start to compute circle_loss')
 
-                        loss_circle = compute_circle_loss(feature_logits, circle_proposals, gt_circles,
-                                                        circle_pos_matched_idxs)
+                        loss_circle = compute_circle_loss(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
 
-                        # loss_circle_extra = compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles,circle_pos_matched_idxs)
+                        # loss_circle_extra=compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
 
                     if loss_circle is None:
                         print(f'loss_circle is None111')
@@ -1414,28 +1405,58 @@ class RoIHeads(nn.Module):
                     loss_circle = {"loss_circle": loss_circle}
                     loss_circle_extra = {"loss_circle_extra": loss_circle_extra}
 
+                else:
+                    if targets is not None:
+                        h, w = targets[0]["img_size"]
+                        img_size = h
+                        gt_circles = [t["circles"] for t in targets if "circles" in t]
 
+                        gt_circles_tensor = torch.zeros(0, 0)
+                        if len(gt_circles) > 0:
+                            gt_circles_tensor = torch.cat(gt_circles)
+                            print(f'gt_circles_tensor:{gt_circles_tensor.shape}')
 
-                else:
-                    loss_circle = {}
-                    loss_circle_extra = {}
-                    if feature_logits is None or circle_proposals is None:
-                        raise ValueError(
-                            "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
-                        )
+                        if gt_circles_tensor.shape[0] > 0:
+                            print(f'start to compute circle_loss')
+
+                            loss_circle = compute_circle_loss(feature_logits, circle_proposals, gt_circles,
+                                                            circle_pos_matched_idxs)
+
+                            # loss_circle_extra = compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles,circle_pos_matched_idxs)
+
+                        if loss_circle is None:
+                            print(f'loss_circle is None111')
+                            loss_circle = torch.tensor(0.0, device=device)
+
+                        if loss_circle_extra is None:
+                            print(f'loss_circle_extra is None111')
+                            loss_circle_extra = torch.tensor(0.0, device=device)
 
-                    if feature_logits is not None:
+                        loss_circle = {"loss_circle": loss_circle}
+                        loss_circle_extra = {"loss_circle_extra": loss_circle_extra}
 
-                        circles_probs, circles_scores = circle_inference(feature_logits, circle_proposals)
-                        for keypoint_prob, kps, r in zip(circles_probs, circles_scores, result):
-                            r["circles"] = keypoint_prob
-                            r["circles_scores"] = kps
 
-            print(f'loss_circle:{loss_circle}')
-            print(f'loss_circle_extra:{loss_circle_extra}')
-            losses.update(loss_circle)
-            losses.update(loss_circle_extra)
-            print(f'losses:{losses}')
+
+                    else:
+                        loss_circle = {}
+                        loss_circle_extra = {}
+                        if feature_logits is None or circle_proposals is None:
+                            raise ValueError(
+                                "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                            )
+
+                        if feature_logits is not None:
+
+                            circles_probs, circles_scores = circle_inference(feature_logits, circle_proposals)
+                            for keypoint_prob, kps, r in zip(circles_probs, circles_scores, result):
+                                r["circles"] = keypoint_prob
+                                r["circles_scores"] = kps
+
+                print(f'loss_circle:{loss_circle}')
+                print(f'loss_circle_extra:{loss_circle_extra}')
+                losses.update(loss_circle)
+                losses.update(loss_circle_extra)
+                print(f'losses:{losses}')
 
 
         if self.has_mask():
@@ -1527,6 +1548,14 @@ class RoIHeads(nn.Module):
 
         return result, losses
 
+    def check_proposals(self, proposals):
+        valid = True
+        for proposal in proposals:
+            # print(f'per circle_proposal:{circle_proposal.shape}')
+            if proposal.shape[0] == 0:
+                valid = False
+        return valid
+
     def line_forward1(self, features, image_shapes, line_proposals):
         print(f'line_proposals:{len(line_proposals)}')
         # cs_features= features['0']

+ 4 - 4
models/line_detect/train_demo.py

@@ -18,11 +18,11 @@ if __name__ == '__main__':
 
     # model=linedetect_resnet18_fpn()
     # model=linedetect_newresnet18fpn(num_points=4)
-    model=linedetect_newresnet50fpn(num_points=4)
-    # model = linedetect_newresnet101fpn(num_points=3)
-    # model = linedetect_newresnet152fpn(num_points=3)
+    # model=linedetect_newresnet50fpn(num_points=4)
+    # model = linedetect_newresnet101fpn(num_points=4)
+    # model = linedetect_newresnet152fpn(num_points=4)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
-    # model=linedetect_maxvitfpn()
+    model=linedetect_maxvitfpn()
     # model=linedetect_efficientnet(name='efficientnet_v2_l')
     # model=linedetect_high_maxvitfpn()