Browse source

Fuse line and point training and inference

RenLiqiang 5 months ago
parent
commit
6a7559b676

+ 1 - 1
models/base/backbone_factory.py

@@ -262,7 +262,7 @@ def get_swin_transformer_fpn(type='t'):
         backbone,
         return_layers={'layer1': '0', 'layer3': '1', 'layer5': '2', 'layer7': '3'},
         in_channels_list=channels_list,
-        out_channels=256
+        out_channels=128
     )
     featmap_names = ['0', '1', '2', '3', 'pool']
     # print(f'featmap_names:{featmap_names}')
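Note on this channel change: the FPN now projects every pyramid level to 128 channels instead of 256, so any head consuming these features must be built with a matching in_channels (the new PointPredictor below uses in_channels=128). A minimal sketch of that contract, using the public torchvision BackboneWithFPN that libs.vision_libs mirrors; the resnet18 body and layer names here are illustrative stand-ins, not the repo's Swin backbone:

import torch
import torchvision
from torchvision.models.detection.backbone_utils import BackboneWithFPN

# Illustrative backbone; the commit itself wraps a Swin transformer.
body = torchvision.models.resnet18(weights=None)
fpn = BackboneWithFPN(
    body,
    return_layers={'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'},
    in_channels_list=[64, 128, 256, 512],
    out_channels=128,  # every level, including 'pool', comes out at 128 channels
)
feats = fpn(torch.randn(1, 3, 224, 224))
assert all(f.shape[1] == 128 for f in feats.values())  # heads must expect 128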

+ 4 - 3
models/line_detect/heads/point_heads.py

@@ -37,6 +37,7 @@ class PointPredictor(nn.Module):
         # print(f'before kps_score_lowres x:{x.shape}')
         x = self.kps_score_lowres(x)
         # print(f'kps_score_lowres x:{x.shape}')
-        return torch.nn.functional.interpolate(
-            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
-        )
+        return x
+        # return torch.nn.functional.interpolate(
+        #     x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
+        # )
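With the interpolate call commented out, PointPredictor now returns the raw low-resolution heatmap; any consumer that previously received an up_scale-times larger map must resample it itself. A small sketch of the removed step, with illustrative shapes (up_scale=2, 14x14 RoI maps assumed):

import torch
import torch.nn.functional as F

x = torch.randn(8, 3, 14, 14)  # [num_rois, num_points, H, W] heatmap logits
up = F.interpolate(x, scale_factor=2.0, mode="bilinear",
                   align_corners=False, recompute_scale_factor=False)
print(x.shape, up.shape)  # the old forward returned `up`; it now returns `x`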

+ 27 - 4
models/line_detect/line_detect.py

@@ -23,6 +23,7 @@ from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extract
     BackboneWithFPN, resnet_fpn_backbone
 from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
 from .heads.line_heads import LinePredictor
+from .heads.point_heads import PointHeads, PointPredictor
 from .loi_heads import RoIHeads
 
 from .trainer import Trainer
@@ -101,7 +102,11 @@ class LineDetect(BaseDetectionNet):
             arc_head=None,
             arc_predictor=None,
             num_points=3,
+            detect_point=True,
+            detect_line=True,
+            detect_arc=False,
             **kwargs,
+
     ):
 
         out_channels = backbone.out_channels
@@ -152,6 +157,9 @@ class LineDetect(BaseDetectionNet):
             box_score_thresh,
             box_nms_thresh,
             box_detections_per_img,
+            detect_point=detect_point,
+            detect_line=detect_line,
+            detect_arc=detect_arc,
         )
 
         if image_mean is None:
@@ -165,18 +173,30 @@ class LineDetect(BaseDetectionNet):
 
 
         if line_head is None:
-            keypoint_layers = tuple(num_points for _ in range(8))
-            line_head = LineHeads(8, keypoint_layers)
+            layers = tuple(num_points for _ in range(8))
+            line_head = LineHeads(8, layers)
 
         if line_predictor is None:
         #     keypoint_dim_reduced = 512  # == keypoint_layers[-1]
             line_predictor = LinePredictor(in_channels=256)
 
+        if point_head is None:
+            layers = tuple(num_points for _ in range(8))
+            point_head = PointHeads(8, layers)
+
+        if point_predictor is None:
+        #     keypoint_dim_reduced = 512  # == keypoint_layers[-1]
+            point_predictor = PointPredictor(in_channels=128)
+
 
         self.roi_heads.line_roi_pool = line_roi_pool
         self.roi_heads.line_head = line_head
         self.roi_heads.line_predictor = line_predictor
 
+        self.roi_heads.point_roi_pool = point_roi_pool
+        self.roi_heads.point_head = point_head
+        self.roi_heads.point_predictor = point_predictor
+
     def start_train(self, cfg):
         # cfg = read_yaml(cfg)
         self.trainer = Trainer()
@@ -354,11 +374,12 @@ def linedetect_newresnet50fpn(
     aspect_ratios = ((0.5, 1.0, 2.0),) * num_features
     # print(f'aspect_ratios:{aspect_ratios}')
 
-
     anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
 
     model = LineDetect(backbone, num_classes, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
 
+
+
     return model
 
 def linedetect_newresnet101fpn(
@@ -507,7 +528,9 @@ def linedetect_swin_transformer_fpn(
         max_size=size,
         num_classes=3,  # the COCO dataset has 91 classes
         rpn_anchor_generator=anchor_generator,
-        box_roi_pool=roi_pooler
+        box_roi_pool=roi_pooler,
+        detect_line=False,
+        detect_point=True,
     )
     return model
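For reference, this is the assembly pattern the linedetect_* factories follow, sketched with the public torchvision equivalents that libs.vision_libs mirrors. FasterRCNN stands in for LineDetect, which additionally accepts the new detect_point/detect_line/detect_arc flags and forwards them to RoIHeads:

import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FasterRCNN
from torchvision.ops import MultiScaleRoIAlign

backbone = resnet_fpn_backbone(backbone_name="resnet50", weights=None)
num_features = 5  # levels '0'..'3' plus 'pool'
anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
aspect_ratios = ((0.5, 1.0, 2.0),) * num_features
anchor_generator = AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
roi_pooler = MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'],
                                output_size=7, sampling_ratio=2)
model = FasterRCNN(backbone, num_classes=3,
                   rpn_anchor_generator=anchor_generator,
                   box_roi_pool=roi_pooler)
model.eval()
with torch.no_grad():
    detections = model([torch.randn(3, 512, 512)])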
 

+ 172 - 6
models/line_detect/loi_heads.py

@@ -525,6 +525,12 @@ class RoIHeads(nn.Module):
             line_roi_pool=None,
             line_head=None,
             line_predictor=None,
+
+            # point parameters
+            point_roi_pool=None,
+            point_head=None,
+            point_predictor=None,
+
             # Mask
             mask_roi_pool=None,
             mask_head=None,
@@ -532,6 +538,10 @@ class RoIHeads(nn.Module):
             keypoint_roi_pool=None,
             keypoint_head=None,
             keypoint_predictor=None,
+
+            detect_point=True,
+            detect_line=True,
+            detect_arc=False,
     ):
         super().__init__()
 
@@ -557,6 +567,12 @@ class RoIHeads(nn.Module):
         self.line_head = line_head
         self.line_predictor = line_predictor
 
+        self.point_roi_pool = point_roi_pool
+        self.point_head = point_head
+        self.point_predictor = point_predictor
+
+
+
         self.mask_roi_pool = mask_roi_pool
         self.mask_head = mask_head
         self.mask_predictor = mask_predictor
@@ -565,6 +581,10 @@ class RoIHeads(nn.Module):
         self.keypoint_head = keypoint_head
         self.keypoint_predictor = keypoint_predictor
 
+        self.detect_point = detect_point
+        self.detect_line = detect_line
+        self.detect_arc = detect_arc
+
         self.channel_compress = nn.Sequential(
             nn.Conv2d(256, 8, kernel_size=1),
             nn.BatchNorm2d(8),
@@ -598,6 +618,15 @@ class RoIHeads(nn.Module):
         #     return False
         return True
 
+    def has_point(self):
+        # if self.line_roi_pool is None:
+        #     return False
+        if self.point_head is None:
+            return False
+        # if self.line_predictor is None:
+        #     return False
+        return True
+
     def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
         # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
         matched_idxs = []
@@ -831,7 +860,7 @@ class RoIHeads(nn.Module):
                     }
                 )
 
-        if  self.has_line():
+        if  self.has_line() and self.detect_line:
             print(f'roi_heads forward has_line()!!!!')
             # print(f'labels:{labels}')
             line_proposals = [p["boxes"] for p in result]
@@ -894,7 +923,7 @@ class RoIHeads(nn.Module):
                 else:
                     pos_matched_idxs = None
 
-            feature_logits = self.head_forward3(features, image_shapes, line_proposals)
+            feature_logits = self.line_forward3(features, image_shapes, line_proposals)
 
             loss_line = None
             loss_line_iou =None
@@ -984,7 +1013,7 @@ class RoIHeads(nn.Module):
                         lines_probs, lines_scores = line_inference(feature_logits,line_proposals)
                         for keypoint_prob, kps, r in zip(lines_probs, lines_scores, result):
                             r["lines"] = keypoint_prob
-                            r["liness_scores"] = kps
+                            r["lines_scores"] = kps
 
 
 
@@ -993,6 +1022,120 @@ class RoIHeads(nn.Module):
             losses.update(loss_line)
             losses.update(loss_line_iou)
             print(f'losses:{losses}')
+        if self.has_point() and self.detect_point:
+            print(f'roi_heads forward has_point()!!!!')
+            # print(f'labels:{labels}')
+            point_proposals = [p["boxes"] for p in result]
+            print(f'point_proposals:{len(point_proposals)}')
+
+            # if line_proposals is None or len(line_proposals) == 0:
+            #     # return empty features or skip this part of the computation
+            #     return torch.empty(0, C, H, W).to(features['0'].device)
+
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                print(f'num_images:{num_images}')
+                point_proposals = []
+                point_pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+                for img_id in range(num_images):
+                    point_pos = torch.where(labels[img_id] == 1)[0]
+                    point_proposals.append(proposals[img_id][point_pos])
+                    point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
+            else:
+                if targets is not None:
+
+                    num_images = len(proposals)
+                    point_proposals = []
+
+                    point_pos_matched_idxs = []
+                    print(f'val num_images:{num_images}')
+                    if matched_idxs is None:
+                        raise ValueError("if in training, matched_idxs should not be None")
+
+                    for img_id in range(num_images):
+                        point_pos = torch.where(labels[img_id] == 1)[0]
+                        point_proposals.append(proposals[img_id][point_pos])
+                        point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
+
+                else:
+                    pos_matched_idxs = None
+
+            feature_logits = self.point_forward1(features, image_shapes, point_proposals)
+
+            loss_point = None
+
+            if self.training:
+
+                if targets is None or point_pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_points = [t["points"] for t in targets if "points" in t]
+
+                print(f'gt_points:{gt_points[0].shape}')
+                h, w = targets[0]["img_size"]
+                img_size = h
+
+                gt_points_tensor = torch.zeros(0, 0)
+                if len(gt_points) > 0:
+                    gt_points_tensor = torch.cat(gt_points)
+                    print(f'gt_points_tensor:{gt_points_tensor.shape}')
+
+                if gt_points_tensor.shape[0] > 0:
+                    print(f'start to compute point_loss')
+
+                    loss_point = compute_point_loss(feature_logits, point_proposals, gt_points, point_pos_matched_idxs)
+
+                if loss_point is None:
+                    print(f'loss_point is None111')
+                    loss_point = torch.tensor(0.0, device=device)
+
+                loss_point = {"loss_point": loss_point}
+
+            else:
+                if targets is not None:
+                    h, w = targets[0]["img_size"]
+                    img_size = h
+                    gt_points = [t["points"] for t in targets if "points" in t]
+
+                    gt_points_tensor = torch.zeros(0, 0)
+                    if len(gt_points) > 0:
+                        gt_points_tensor = torch.cat(gt_points)
+                        print(f'gt_points_tensor:{gt_points_tensor.shape}')
+
+                    if gt_points_tensor.shape[0] > 0:
+                        print(f'start to compute point_loss')
+
+                        loss_point = compute_point_loss(feature_logits, point_proposals, gt_points,
+                                                        point_pos_matched_idxs, img_size)
+
+                    if loss_point is None:
+                        print(f'loss_point is None111')
+                        loss_point = torch.tensor(0.0, device=device)
+
+                    loss_point = {"loss_point": loss_point}
+
+
+
+                else:
+                    loss_point = {}
+                    if feature_logits is None or point_proposals is None:
+                        raise ValueError(
+                            "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                        )
+
+                    if feature_logits is not None:
+
+                        points_probs, points_scores = point_inference(feature_logits, point_proposals)
+                        for keypoint_prob, kps, r in zip(points_probs, points_scores, result):
+                            r["points"] = keypoint_prob
+                            r["points_scores"] = kps
+
+            print(f'loss_point:{loss_point}')
+            losses.update(loss_point)
+            print(f'losses:{losses}')
 
         if self.has_mask():
             mask_proposals = [p["boxes"] for p in result]
@@ -1083,7 +1226,7 @@ class RoIHeads(nn.Module):
 
         return result, losses
 
-    def head_forward1(self, features, image_shapes, line_proposals):
+    def line_forward1(self, features, image_shapes, line_proposals):
         print(f'line_proposals:{len(line_proposals)}')
         # cs_features= features['0']
         print(f'features-0:{features['0'].shape}')
@@ -1101,7 +1244,7 @@ class RoIHeads(nn.Module):
         print(f'feature_logits from line_head:{feature_logits.shape}')
         return feature_logits
 
-    def head_forward2(self, features, image_shapes, line_proposals):
+    def line_forward2(self, features, image_shapes, line_proposals):
         print(f'line_proposals:{len(line_proposals)}')
         # cs_features= features['0']
         print(f'features-0:{features['0'].shape}')
@@ -1124,7 +1267,7 @@ class RoIHeads(nn.Module):
             print(f'roi_features from align:{roi_features.shape}')
         return roi_features
 
-    def head_forward3(self, features, image_shapes, line_proposals):
+    def line_forward3(self, features, image_shapes, line_proposals):
         print(f'line_proposals:{len(line_proposals)}')
         # cs_features= features['0']
         print(f'features-0:{features['0'].shape}')
@@ -1146,3 +1289,26 @@ class RoIHeads(nn.Module):
         if roi_features is not None:
             print(f'roi_features from align:{roi_features.shape}')
         return roi_features
+
+    def point_forward1(self, features, image_shapes, proposals):
+        print(f'point_proposals:{len(proposals)}')
+        # cs_features= features['0']
+        print(f"features-0:{features['0'].shape}")
+        # cs_features = self.channel_compress(features['0'])
+        cs_features = features['0']
+        filtered_proposals = [proposal for proposal in proposals if proposal.shape[0] > 0]
+
+        if len(filtered_proposals) > 0:
+            filtered_proposals_tensor = torch.cat(filtered_proposals)
+            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
+            proposals = filtered_proposals
+        point_proposals_tensor = torch.cat(proposals)
+        print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
+
+        feature_logits = self.point_predictor(cs_features)
+        print(f'feature_logits from point_predictor:{feature_logits.shape}')
+
+        roi_features = features_align(cs_features, proposals, image_shapes)
+        if roi_features is not None:
+            print(f'roi_features from align:{roi_features.shape}')
+        return roi_features
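The point branch reuses the keypoint-style recipe: in both training and validation it keeps only proposals whose matched class label is 1 and remembers the ground-truth index each kept box was matched to. A self-contained sketch of that selection, with made-up tensors:

import torch

labels = [torch.tensor([0, 1, 1, 0, 1])]        # per-image proposal labels
proposals = [torch.randn(5, 4)]                  # per-image boxes, [N, 4]
matched_idxs = [torch.tensor([0, 2, 1, 0, 2])]   # per-image GT assignments

point_proposals, point_pos_matched_idxs = [], []
for img_id in range(len(proposals)):
    point_pos = torch.where(labels[img_id] == 1)[0]  # positives of class 1
    point_proposals.append(proposals[img_id][point_pos])
    point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])

print(point_proposals[0].shape)  # torch.Size([3, 4])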

+ 2 - 1
models/line_detect/train.yaml

@@ -1,7 +1,8 @@
 io:
   logdir: train_results
 #  datadir: /data/share/zjh/Dataset_correct_xanylabel
-  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
+#  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
+  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000

+ 6 - 6
models/line_detect/trainer.py

@@ -191,7 +191,7 @@ class Trainer(BaseTrainer):
 
 
 
-    def writer_predict_result(self, img, result, epoch, typ=1):
+    def writer_predict_result(self, img, result, epoch):
         img = img.cpu().detach()
         im = img.permute(1, 2, 0)  # [512, 512, 3]
         self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
@@ -205,18 +205,18 @@ class Trainer(BaseTrainer):
         self.writer.add_image("z-obj", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
 
 
-        if typ==1 and 'points' in result:
+        if 'points' in result:
             keypoint_img = draw_keypoints(boxed_image, result['points'], colors='red', width=3)
 
             self.writer.add_image("z-output", keypoint_img, epoch)
         # print("lines shape:", result['lines'].shape)
 
 
-        if typ==2 and 'lines' in result:
+        if 'lines' in result:
             # draw the line segments with our own helper function
             # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
             print(f"shape of linescore:{result['liness_scores'].shape}")
-            scores = result['liness_scores'].mean(dim=1)  # shape: [31]
+            scores = result['lines_scores'].mean(dim=1)  # shape: [31]
 
             line_image = draw_lines_with_scores((img * 255).to(torch.uint8),  result['lines'],scores, width=3, cmap='jet')
 
@@ -341,7 +341,7 @@ class Trainer(BaseTrainer):
                 # print(f'result:{result}')
                 t_end = time.time()
                 print(f'predict used:{t_end - t_start}')
-                self.writer_predict_result(img=imgs[0], result=result[0], typ=2, epoch=epoch)
+                self.writer_predict_result(img=imgs[0], result=result[0], epoch=epoch)
                 epoch_step+=1
 
         avg_loss = total_loss / len(data_loader)
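writer_predict_result no longer takes a typ switch; it now draws whichever of 'points' and 'lines' the result dict carries. A sketch of the points path using torchvision's draw_keypoints (draw_lines_with_scores is this repo's own helper and is omitted here; the shapes are illustrative):

import torch
from torchvision.utils import draw_keypoints

img = (torch.rand(3, 512, 512) * 255).to(torch.uint8)
result = {"points": torch.rand(4, 3, 2) * 512}  # [instances, K, (x, y)]

if "points" in result:
    keypoint_img = draw_keypoints(img, result["points"], colors="red", width=3)
    print(keypoint_img.shape)  # same HWC size as the input image, uint8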