RenLiqiang 5 months ago
Parent
Commit
4c98a328dc

+ 5 - 0
libs/vision_libs/models/detection/transform.py

@@ -217,6 +217,11 @@ class GeneralizedRCNNTransform(nn.Module):
             points = target["points"]
             points = resize_keypoints(points, (h, w), image.shape[-2:])
             target["points"] = points
+
+        if "arc_mask" in target:
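+            # arc_mask stores point coordinates, so it is resized the same way as keypoints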
+            arc_mask = target["arc_mask"]
+            arc_mask = resize_keypoints(arc_mask, (h, w), image.shape[-2:])
+            target["arc_mask"] = arc_mask
         return image, target
 
     # _onnx_batch_images() is an implementation of

+ 43 - 0
models/line_detect/heads/arc_heads.py

@@ -0,0 +1,43 @@
+import torch
+from torch import nn
+
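+# Small stack of 3x3 conv + ReLU layers applied to features ahead of the arc predictor.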
+class ArcHeads(nn.Sequential):
+    def __init__(self, in_channels, layers):
+        d = []
+        next_feature = in_channels
+        for out_channels in layers:
+            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
+            d.append(nn.ReLU(inplace=True))
+            next_feature = out_channels
+        super().__init__(*d)
+        for m in self.children():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                nn.init.constant_(m.bias, 0)
+
+
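+# Stride-2 transposed conv that upsamples features into per-channel heatmap logits.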
+class ArcPredictor(nn.Module):
+    def __init__(self, in_channels, out_channels=3):
+        super().__init__()
+        input_features = in_channels
+        deconv_kernel = 4
+        self.kps_score_lowres = nn.ConvTranspose2d(
+            input_features,
+            out_channels,
+            deconv_kernel,
+            stride=2,
+            padding=deconv_kernel // 2 - 1,
+        )
+        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
+        nn.init.constant_(self.kps_score_lowres.bias, 0)
+        self.up_scale = 2
+        self.out_channels = out_channels
+
+    def forward(self, x):
+        print(f'before kps_score_lowres x:{x.shape}')
+        x = self.kps_score_lowres(x)
+        print(f'kps_score_lowres x:{x.shape}')
+        return x
+        # return torch.nn.functional.interpolate(
+        #     x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
+        # )

+ 71 - 0
models/line_detect/heads/head_losses.py

@@ -418,6 +418,77 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
 
     return line_loss
 
+def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
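+    # Arc loss: cross-entropy between heatmap logits and Gaussian target heatmaps built from matched GT arc points.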
+    print(f'compute_arc_loss:{feature_logits.shape}')
+    N, K, H, W = feature_logits.shape
+    len_proposals = len(proposals)
+
+    empty_count = 0
+    non_empty_count = 0
+
+    for prop in proposals:
+        if prop.shape[0] == 0:
+            empty_count += 1
+        else:
+            non_empty_count += 1
+
+    print(f"Empty proposals count: {empty_count}")
+    print(f"Non-empty proposals count: {non_empty_count}")
+
+    print('start to compute_arc_loss')
+    print(f'compute_arc_loss feature_logits.shape:{feature_logits.shape}, len_proposals:{len_proposals}')
+    if H != W:
+        raise ValueError(
+            f"feature_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
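+    # Build one Gaussian target heatmap per positive proposal from its matched GT arc points.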
+
+    gs_heatmaps = []
+    # print(f'point_matched_idxs:{point_matched_idxs}')
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
+        print(f'proposals_per_image:{proposals_per_image.shape}')
+        kp = gt_kp_in_image[midx]
+        # print(f'gt_kp_in_image:{gt_kp_in_image}')
+        gs_heatmaps_per_img = arc_points_to_heatmap(kp, proposals_per_image, discretization_size)
+        gs_heatmaps.append(gs_heatmaps_per_img)
+
+    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
+    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.squeeze(1).shape}')
+
+    line_logits = feature_logits[:, 0]
+    print(f'single_point_logits:{line_logits.shape}')
+
+    line_loss = F.cross_entropy(line_logits, gs_heatmaps)
+
+    return line_loss
+
+def arc_points_to_heatmap(keypoints, rois, heatmap_size):
+    print(f'rois:{rois.shape}')
+    print(f'heatmap_size:{heatmap_size}')
+
+    print(f'keypoints.shape:{keypoints.shape}')
+    # batch_size, num_keypoints, _ = keypoints.shape
+
+    x = keypoints[..., 0].unsqueeze(1)
+    y = keypoints[..., 1].unsqueeze(1)
+
+    gs = generate_gaussian_heatmaps(x, y, num_points=10, heatmap_size=heatmap_size, sigma=1.0)
+    # show_heatmap(gs[0],'target')
+    all_roi_heatmap = []
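+    # Keep the Gaussian response only inside each proposal box; everything outside the RoI is zeroed.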
+    for roi, heatmap in zip(rois, gs):
+        # show_heatmap(heatmap, 'target')  # debug visualization
+        print(f'heatmap:{heatmap.shape}')
+        heatmap = heatmap.unsqueeze(0)
+        x1, y1, x2, y2 = map(int, roi)
+        roi_heatmap = torch.zeros_like(heatmap)
+        roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
+        # show_heatmap(roi_heatmap[0], 'roi_heatmap')  # debug visualization
+        all_roi_heatmap.append(roi_heatmap)
+
+    all_roi_heatmap = torch.cat(all_roi_heatmap)
+    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
+
+    return all_roi_heatmap
 
 def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
     # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor

+ 28 - 3
models/line_detect/line_dataset.py

@@ -80,7 +80,7 @@ class LineDataset(BaseDataset):
         target = {}
 
         target["image_id"] = torch.tensor(item)
-        boxes, lines, points, labels = get_boxes_lines(objs, shape)
+        boxes, lines, points, arc_mask, labels = get_boxes_lines(objs, shape)
 
         if points is not None:
             target["points"]=points
@@ -89,6 +89,12 @@ class LineDataset(BaseDataset):
             lines = torch.cat((lines, a), dim=1)
             target["lines"] = lines.to(torch.float32).view(-1, 2, 3)
 
+        if arc_mask is not None:
+            target['arc_mask'] = arc_mask
+            print('arc_mask in dataset')
+        else:
+            print('no arc_mask in dataset')
+
         target["boxes"]=boxes
         target["labels"]=labels
         # target["boxes"], lines,target["points"], target["labels"] = get_boxes_lines(objs,shape)
@@ -155,6 +161,7 @@ def get_boxes_lines(objs,shape):
     h,w=shape
     line_point_pairs = []
     points=[]
+    line_mask=[]
 
     for obj in objs:
         # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed ordering
@@ -188,7 +195,20 @@ def get_boxes_lines(objs,shape):
 
 
 
-        elif label =='arc':
+        elif label == 'arc':
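+            # Arc objects keep their sampled point list (arc_mask) and get a box + label like the other classes.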
+            line_mask.append(obj['points'])
+            xmin = obj['xmin']
+            xmax = obj['xmax']
+            ymin = obj['ymin']
+            ymax = obj['ymax']
+            boxes.append([xmin, ymin, xmax, ymax])
 
             labels.append(torch.tensor(3))
 
@@ -207,7 +227,12 @@ def get_boxes_lines(objs,shape):
         line_point_pairs=torch.tensor(line_point_pairs)
 
     # print(f'boxes:{boxes.shape},line_point_pairs:{line_point_pairs.shape}')
-    return boxes,line_point_pairs,points,labels
+
+    if len(line_mask) == 0:
+        line_mask = None
+    else:
+        line_mask = torch.tensor(line_mask)
+    return boxes, line_point_pairs, points, line_mask, labels
 
 if __name__ == '__main__':
     path=r"\\192.168.50.222\share\rlq\datasets\Dataset0709_"

+ 18 - 4
models/line_detect/line_detect.py

@@ -22,6 +22,7 @@ from libs.vision_libs.models.detection._utils import overwrite_eps
 from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, \
     BackboneWithFPN, resnet_fpn_backbone
 from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
+from .heads.arc_heads import ArcHeads, ArcPredictor
 from .heads.line_heads import LinePredictor
 from .heads.point_heads import PointHeads, PointPredictor
 from .loi_heads import RoIHeads
@@ -188,6 +189,14 @@ class LineDetect(BaseDetectionNet):
         #     keypoint_dim_reduced = 512  # == keypoint_layers[-1]
             point_predictor = PointPredictor(in_channels=128)
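+        # If arc detection is enabled, build default arc modules, mirroring the point-branch defaults.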
 
+        if detect_arc and arc_head is None:
+            layers = tuple(num_points for _ in range(8))
+            arc_head = ArcHeads(8, layers)
+        if detect_arc and arc_predictor is None:
+            arc_predictor = ArcPredictor(in_channels=128)
+
 
         self.roi_heads.line_roi_pool = line_roi_pool
         self.roi_heads.line_head = line_head
@@ -197,6 +206,10 @@ class LineDetect(BaseDetectionNet):
         self.roi_heads.point_head = point_head
         self.roi_heads.point_predictor = point_predictor
 
+        self.roi_heads.arc_roi_pool = arc_roi_pool
+        self.roi_heads.arc_head = arc_head
+        self.roi_heads.arc_predictor = arc_predictor
+
     def start_train(self, cfg):
         # cfg = read_yaml(cfg)
         self.trainer = Trainer()
@@ -428,11 +441,11 @@ def linedetect_maxvitfpn(
     # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
     # weights_backbone = ResNet50_Weights.verify(weights_backbone)
     if num_classes is None:
-        num_classes = 3
+        num_classes = 4
     if num_points is None:
         num_points = 3
 
-    size=224*2
+    size=224*4
 
     maxvit = MaxVitBackbone(input_size=(size,size))
     # print(maxvit.named_children())
@@ -460,11 +473,12 @@ def linedetect_maxvitfpn(
         backbone=backbone_with_fpn,
         min_size=size,
         max_size=size,
-        num_classes=3,  # COCO 数据集有 91 类
+        num_classes=num_classes,  # 4 = background + line/point/arc
         rpn_anchor_generator=get_anchor_generator(backbone_with_fpn, test_input=test_input),
         box_roi_pool=roi_pooler,
-        detect_line=True,
+        detect_line=False,
         detect_point=False,
+        detect_arc=True,
     )
     return model
 

+ 165 - 1
models/line_detect/loi_heads.py

@@ -13,7 +13,7 @@ import libs.vision_libs.models.detection._utils as det_utils
 from collections import OrderedDict
 
 from models.line_detect.heads.head_losses import point_inference, compute_point_loss, line_iou_loss, \
-    lines_point_pair_loss, features_align, line_inference
+    lines_point_pair_loss, features_align, line_inference, compute_arc_loss
 
 
 def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
@@ -531,6 +531,11 @@ class RoIHeads(nn.Module):
             point_head=None,
             point_predictor=None,
 
+            # arc parameters
+            arc_roi_pool=None,
+            arc_head=None,
+            arc_predictor=None,
+
             # Mask
             mask_roi_pool=None,
             mask_head=None,
@@ -571,6 +576,10 @@ class RoIHeads(nn.Module):
         self.point_head = point_head
         self.point_predictor = point_predictor
 
+        self.arc_roi_pool = arc_roi_pool
+        self.arc_head = arc_head
+        self.arc_predictor = arc_predictor
+
 
 
         self.mask_roi_pool = mask_roi_pool
@@ -627,6 +636,15 @@ class RoIHeads(nn.Module):
         #     return False
         return True
 
+    def has_arc(self):
+        # if self.line_roi_pool is None:
+        #     return False
+        if self.arc_head is None:
+            return False
+        # if self.line_predictor is None:
+        #     return False
+        return True
+
     def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
         # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
         matched_idxs = []
@@ -1137,6 +1155,128 @@ class RoIHeads(nn.Module):
             losses.update(loss_point)
             print(f'losses:{losses}')
 
+
+        if self.has_arc() and self.detect_arc:
+            print('roi_heads forward has_arc()')
+            # print(f'labels:{labels}')
+            arc_proposals = [p["boxes"] for p in result]
+            print(f'arc_proposals:{len(arc_proposals)}')
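+            # Label 3 marks 'arc' objects (see get_boxes_lines), so only those proposals feed the arc branch.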
+
+            # if line_proposals is None or len(line_proposals) == 0:
+            #     # 返回空特征或者跳过该部分计算
+            #     return torch.empty(0, C, H, W).to(features['0'].device)
+
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                print(f'num_images:{num_images}')
+                arc_proposals = []
+                arc_pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+                for img_id in range(num_images):
+                    arc_pos = torch.where(labels[img_id] == 3)[0]
+                    arc_proposals.append(proposals[img_id][arc_pos])
+                    arc_pos_matched_idxs.append(matched_idxs[img_id][arc_pos])
+            else:
+                if targets is not None:
+                    num_images = len(proposals)
+                    arc_proposals = []
+                    arc_pos_matched_idxs = []
+                    print(f'val num_images:{num_images}')
+                    if matched_idxs is None:
+                        raise ValueError("matched_idxs should not be None when targets are provided")
+
+                    for img_id in range(num_images):
+                        arc_pos = torch.where(labels[img_id] == 3)[0]
+                        arc_proposals.append(proposals[img_id][arc_pos])
+                        arc_pos_matched_idxs.append(matched_idxs[img_id][arc_pos])
+
+                else:
+                    arc_pos_matched_idxs = None
+
+            feature_logits = self.arc_forward1(features, image_shapes, arc_proposals)
+
+            loss_arc = None
+
+            if self.training:
+                if targets is None or arc_pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_arcs = [t["arc_mask"] for t in targets if "arc_mask" in t]
+                h, w = targets[0]["img_size"]
+                img_size = h
+
+                gt_arcs_tensor = torch.zeros(0, 0)
+                if len(gt_arcs) > 0:
+                    gt_arcs_tensor = torch.cat(gt_arcs)
+                    print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
+
+                if gt_arcs_tensor.shape[0] > 0:
+                    print('start to compute arc loss')
+                    loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
+
+                if loss_arc is None:
+                    print('loss_arc is None, defaulting to 0')
+                    loss_arc = torch.tensor(0.0, device=device)
+
+                loss_arc = {"loss_arc": loss_arc}
+
+            else:
+                if targets is not None:
+                    gt_arcs = [t["arc_mask"] for t in targets if "arc_mask" in t]
+                    h, w = targets[0]["img_size"]
+                    img_size = h
+
+                    gt_arcs_tensor = torch.zeros(0, 0)
+                    if len(gt_arcs) > 0:
+                        gt_arcs_tensor = torch.cat(gt_arcs)
+                        print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
+
+                    if gt_arcs_tensor.shape[0] > 0:
+                        print('start to compute arc loss')
+                        loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
+
+                    if loss_arc is None:
+                        print('loss_arc is None, defaulting to 0')
+                        loss_arc = torch.tensor(0.0, device=device)
+
+                    loss_arc = {"loss_arc": loss_arc}
+
+                else:
+                    loss_arc = {}
+                    if feature_logits is None or arc_proposals is None:
+                        raise ValueError(
+                            "both feature_logits and arc_proposals should not be None when not in training mode"
+                        )
+
+                    arcs_probs, arcs_scores = arc_inference(feature_logits, arc_proposals)
+                    for arc_prob, arc_score, r in zip(arcs_probs, arcs_scores, result):
+                        r["arcs"] = arc_prob
+                        r["arcs_scores"] = arc_score
+
+            print(f'loss_arc:{loss_arc}')
+            losses.update(loss_arc)
+            print(f'losses:{losses}')
+
         if self.has_mask():
             mask_proposals = [p["boxes"] for p in result]
             if self.training:
@@ -1312,3 +1452,27 @@ class RoIHeads(nn.Module):
         if roi_features is not None:
             print(f'roi_features from align:{roi_features.shape}')
         return roi_features
+
+
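+    # Arc forward: the predictor runs directly on the '0' FPN feature map, then the logits are
+    # aligned/cropped per proposal. Note arc_head is registered on RoIHeads but not applied here.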
+    def arc_forward1(self, features, image_shapes, proposals):
+        print(f'arc_proposals:{len(proposals)}')
+        print(f"features-0:{features['0'].shape}")
+        # cs_features = self.channel_compress(features['0'])
+        cs_features = features['0']
+
+        feature_logits = self.arc_predictor(cs_features)
+        print(f'feature_logits from arc_predictor:{feature_logits.shape}')
+
+        roi_features = features_align(feature_logits, proposals, image_shapes)
+        if roi_features is not None:
+            print(f'roi_features from align:{roi_features.shape}')
+        return roi_features

+ 4 - 4
models/line_detect/train.yaml

@@ -1,8 +1,8 @@
 io:
   logdir: train_results
-  datadir: /data/share/rlq/datasets/250718caisegangban
+#  datadir: /data/share/rlq/datasets/250718caisegangban
 #  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
-#  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
+  datadir: \\192.168.50.222/share/zyh/arc/a_dataset
 #  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
@@ -16,8 +16,8 @@ train_params:
   num_workers: 8
   batch_size: 2
   max_epoch: 8000000
-  augmentation: True
-#  augmentation: False
+#  augmentation: True
+  augmentation: False
   optim:
     name: Adam
     lr: 4.0e-4

+ 1 - 1
models/line_detect/train_demo.py

@@ -23,6 +23,6 @@ if __name__ == '__main__':
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
     model=linedetect_maxvitfpn()
     # model=linedetect_high_maxvitfpn()
-    model.load_weights(r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250718_151833/weights/best_val.pth')
+    # model.load_weights(r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250718_151833/weights/best_val.pth')
     # model=linedetect_swin_transformer_fpn(type='t')
     model.start_train(cfg='train.yaml')