Jelajahi Sumber

修改分离line_head 和point_head

RenLiqiang 5 bulan lalu
induk
melakukan
b8ae0774f3

+ 4 - 40
models/line_detect/line_detect.py

@@ -22,6 +22,7 @@ from libs.vision_libs.models.detection._utils import overwrite_eps
 from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, \
     BackboneWithFPN, resnet_fpn_backbone
 from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
+from .heads.line_heads import LinePredictor
 from .loi_heads import RoIHeads
 
 from .trainer import Trainer
@@ -39,6 +40,9 @@ __all__ = [
     "linedetect_resnet50_fpn",
 ]
 
+from ..line_net.line_detect import LineHeads
+
+
 def _default_anchorgen():
     anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
     aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
@@ -272,46 +276,6 @@ class ObjectionPredictor(nn.Module):
 
         return scores, bbox_deltas
 
-class LineHeads(nn.Sequential):
-    def __init__(self, in_channels, layers):
-        d = []
-        next_feature = in_channels
-        for out_channels in layers:
-            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
-            d.append(nn.ReLU(inplace=True))
-            next_feature = out_channels
-        super().__init__(*d)
-        for m in self.children():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
-                nn.init.constant_(m.bias, 0)
-
-class LinePredictor(nn.Module):
-    def __init__(self, in_channels, out_channels=3 ):
-        super().__init__()
-        input_features = in_channels
-        deconv_kernel = 4
-        self.kps_score_lowres = nn.ConvTranspose2d(
-            input_features,
-            out_channels,
-            deconv_kernel,
-            stride=2,
-            padding=deconv_kernel // 2 - 1,
-        )
-        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
-        nn.init.constant_(self.kps_score_lowres.bias, 0)
-        self.up_scale = 2
-        self.out_channels = out_channels
-
-    def forward(self, x):
-        print(f'before kps_score_lowres x:{x.shape}')
-        x = self.kps_score_lowres(x)
-        print(f'kps_score_lowres x:{x.shape}')
-        return  x
-        # return torch.nn.functional.interpolate(
-        #     x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
-        # )
-
 
 def linedetect_newresnet18fpn(
         *,

+ 22 - 98
models/line_detect/loi_heads.py

@@ -859,13 +859,13 @@ class RoIHeads(nn.Module):
                     pos = torch.where(labels[img_id] > 0)[0]
 
                     line_pos=torch.where(labels[img_id] ==2)[0]
-                    point_pos=torch.where(labels[img_id] ==1)[0]
+                    # point_pos=torch.where(labels[img_id] ==1)[0]
 
                     line_proposals.append(proposals[img_id][line_pos])
-                    point_proposals.append(proposals[img_id][point_pos])
+                    # point_proposals.append(proposals[img_id][point_pos])
 
                     line_pos_matched_idxs.append(matched_idxs[img_id][line_pos])
-                    point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
+                    # point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
 
                     # pos_matched_idxs.append(matched_idxs[img_id][pos])
             else:
@@ -874,47 +874,33 @@ class RoIHeads(nn.Module):
                     pos_matched_idxs = []
                     num_images = len(proposals)
                     line_proposals = []
-                    point_proposals=[]
-                    arc_proposals=[]
+
 
                     line_pos_matched_idxs = []
-                    point_pos_matched_idxs = []
                     print(f'val num_images:{num_images}')
                     if matched_idxs is None:
                         raise ValueError("if in trainning, matched_idxs should not be None")
 
                     for img_id in range(num_images):
-                        pos = torch.where(labels[img_id] > 0)[0]
-                        # line_proposals.append(proposals[img_id][pos])
-                        # pos_matched_idxs.append(matched_idxs[img_id][pos])
+                        # pos = torch.where(labels[img_id] > 0)[0]
 
                         line_pos = torch.where(labels[img_id] == 2)[0]
-                        point_pos = torch.where(labels[img_id] == 1)[0]
 
                         line_proposals.append(proposals[img_id][line_pos])
-                        point_proposals.append(proposals[img_id][point_pos])
 
                         line_pos_matched_idxs.append(matched_idxs[img_id][line_pos])
-                        point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
 
                 else:
                     pos_matched_idxs = None
 
             print(f'line_proposals:{len(line_proposals)}')
 
-            cs_features= features['0']
-
+            # cs_features= features['0']
+            cs_features = self.channel_compress(features['0'])
 
-            all_proposals=line_proposals+point_proposals
-            # print(f'point_proposals:{point_proposals}')
-            # print(f'all_proposals:{all_proposals}')
-            for p in point_proposals:
-                print(f'point_proposal:{p.shape}')
 
-            for ap in all_proposals:
-                print(f'ap_proposal:{ap.shape}')
 
-            filtered_proposals = [proposal for proposal in all_proposals if proposal.shape[0] > 0]
+            filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
             if len(filtered_proposals) > 0:
                 filtered_proposals_tensor=torch.cat(filtered_proposals)
                 print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
@@ -923,35 +909,17 @@ class RoIHeads(nn.Module):
 
             print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
 
+            roi_features = features_align(cs_features, line_proposals, image_shapes)
 
-            point_proposals_tensor=torch.cat(point_proposals)
-            print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
-
-            # line_features=None
-
-            feature_logits = self.line_predictor(cs_features)
-            print(f'feature_logits from line_predictor:{feature_logits.shape}')
-
-            point_features = features_align(feature_logits, point_proposals, image_shapes)
+            if roi_features is not None:
+                print(f'line_features from align:{roi_features.shape}')
 
+            feature_logits = self.line_head(roi_features)
+            print(f'feature_logits from line_head:{feature_logits.shape}')
 
 
-            line_features = features_align(feature_logits, line_proposals, image_shapes)
-
-            if line_features is not None:
-                print(f'line_features from align:{line_features.shape}')
-
-            if point_features is not None:
-                print(f'feature_logits  features_align:{point_features.shape}')
-            # feature_logits=point_features
-
-            # line_logits = combine_features
-            # print(f'line_logits:{line_logits.shape}')
-
             loss_line = None
             loss_line_iou =None
-            loss_point = None
-
 
             if self.training:
 
@@ -959,11 +927,6 @@ class RoIHeads(nn.Module):
                     raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
 
                 gt_lines = [t["lines"] for t in targets if "lines" in t]
-                gt_points = [t["points"] for t in targets if "points" in t]
-                #
-                # line_pos_matched_idxs = []
-                # point_pos_matched_idxs = []
-
 
 
 
@@ -972,29 +935,19 @@ class RoIHeads(nn.Module):
                 img_size = h
 
                 gt_lines_tensor=torch.zeros(0,0)
-                gt_points_tensor=torch.zeros(0,0)
                 if len(gt_lines)>0:
                     gt_lines_tensor = torch.cat(gt_lines)
                     print(f'gt_lines_tensor:{gt_lines_tensor.shape}')
 
-                if len(gt_points)>0:
-                    gt_points_tensor = torch.cat(gt_points)
-                    print(f'gt_points_tensor:{gt_points_tensor.shape}')
-
-
 
                 if gt_lines_tensor.shape[0]>0 :
                     print(f'start to lines_point_pair_loss')
                     loss_line = lines_point_pair_loss(
-                        line_features, line_proposals, gt_lines, line_pos_matched_idxs
+                        feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
                     )
-                    loss_line_iou = line_iou_loss(line_features, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
+                    loss_line_iou = line_iou_loss(feature_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
+
 
-                if gt_points_tensor.shape[0]>0 :
-                    print(f'start to compute_point_loss ')
-                    loss_point = compute_point_loss(
-                        point_features, point_proposals, gt_points, point_pos_matched_idxs
-                    )
 
                 if  loss_line is None:
                     print(f'loss_line is None111')
@@ -1007,41 +960,27 @@ class RoIHeads(nn.Module):
                 loss_line = {"loss_line": loss_line}
                 loss_line_iou = {'loss_line_iou': loss_line_iou}
 
-                if loss_point is None:
-                    loss_point = {"loss_point": torch.tensor(0.0,device=feature_logits.device)}
-                else:
-                    loss_point = {"loss_point": loss_point}
-
             else:
                 if targets is not None:
                     h, w = targets[0]["img_size"]
                     img_size = h
                     gt_lines = [t["lines"] for t in targets if "lines" in t]
-                    gt_points = [t["points"] for t in targets if "points" in t]
 
                     gt_lines_tensor = torch.zeros(0, 0)
-                    gt_points_tensor = torch.zeros(0, 0)
                     if len(gt_lines)>0:
                         gt_lines_tensor = torch.cat(gt_lines)
-                    if len(gt_points)>0:
-                        gt_points_tensor = torch.cat(gt_points)
 
-                    # line_pos_matched_idxs = []
-                    # point_pos_matched_idxs = []
 
 
-                    if gt_lines_tensor.shape[0] > 0 and line_features is not None:
+                    if gt_lines_tensor.shape[0] > 0 and feature_logits is not None:
                         loss_line = lines_point_pair_loss(
-                            line_features, line_proposals, gt_lines, line_pos_matched_idxs
+                            feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
                         )
                         print(f'compute_line_loss:{loss_line}')
-                        loss_line_iou = line_iou_loss(line_features , line_proposals, gt_lines, line_pos_matched_idxs,
+                        loss_line_iou = line_iou_loss(feature_logits , line_proposals, gt_lines, line_pos_matched_idxs,
                                                       img_size)
 
-                    if gt_points_tensor.shape[0] > 0 and point_features is not None:
-                        loss_point = compute_point_loss(
-                            point_features, point_proposals, gt_points, point_pos_matched_idxs
-                        )
+
 
                     if  loss_line is None:
                         print(f'loss_line is None')
@@ -1051,18 +990,10 @@ class RoIHeads(nn.Module):
                         print(f'loss_line_iou is None')
                         loss_line_iou=torch.tensor(0.0,device=cs_features.device)
 
-                    # if  loss_point is None:
-                    #     print(f'loss_point is None')
-                    #     loss_point=torch.tensor(0.0,device=cs_features.device)
 
                     loss_line = {"loss_line": loss_line}
                     loss_line_iou = {'loss_line_iou': loss_line_iou}
                     
-                    if loss_point is None:
-                        loss_point = {"loss_point": torch.tensor(0.0, device=feature_logits.device)}
-                    else:
-                        loss_point = {"loss_point": loss_point}
-
 
 
                 else:
@@ -1074,25 +1005,18 @@ class RoIHeads(nn.Module):
                             "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                         )
 
-                    if line_features is not None:
-                        lines_probs, lines_scores = line_inference(line_features,line_proposals)
+                    if feature_logits is not None:
+                        lines_probs, lines_scores = line_inference(feature_logits,line_proposals)
                         for keypoint_prob, kps, r in zip(lines_probs, lines_scores, result):
                             r["lines"] = keypoint_prob
                             r["liness_scores"] = kps
 
-                    if point_features is not None:
-                        point_probs, points_scores=point_inference(point_features, point_proposals, )
-                        for  points, ps, r in zip(point_probs,points_scores, result):
-                            print(f'points_prob :{points.shape}')
 
-                            r["points"] = points
-                            r["points_scores"] = ps
 
 
             print(f'loss_line11111:{loss_line}')
             losses.update(loss_line)
             losses.update(loss_line_iou)
-            losses.update(loss_point)
             print(f'losses:{losses}')
 
         if self.has_mask():

+ 2 - 2
models/line_detect/train_demo.py

@@ -19,6 +19,6 @@ if __name__ == '__main__':
     # model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet50fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
-    model=linedetect_maxvitfpn()
-    # model=linedetect_high_maxvitfpn()
+    # model=linedetect_maxvitfpn()
+    model=linedetect_high_maxvitfpn()
     model.start_train(cfg='train.yaml')