Kaynağa Gözat

Merge branch 'keypoint' of https://dev.lstznkj.com/DevLibs/MultiVisionModels into keypoint

# Conflicts:
#	models/line_detect/line_dataset.py
lstrlq 5 ay önce
ebeveyn
işleme
d650405319

+ 1 - 1
models/line_detect/line_detect.py

@@ -22,7 +22,7 @@ from libs.vision_libs.models.detection._utils import overwrite_eps
 from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, \
     BackboneWithFPN, resnet_fpn_backbone
 from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
-from .roi_heads import RoIHeads
+from .loi_heads import RoIHeads
 
 from .trainer import Trainer
 from ..base import backbone_factory

+ 42 - 0
models/line_detect/line_heads.py

@@ -0,0 +1,42 @@
+import torch
+from torch import nn
+
+class LineHeads(nn.Sequential):
+    def __init__(self, in_channels, layers):
+        d = []
+        next_feature = in_channels
+        for out_channels in layers:
+            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
+            d.append(nn.ReLU(inplace=True))
+            next_feature = out_channels
+        super().__init__(*d)
+        for m in self.children():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                nn.init.constant_(m.bias, 0)
+
+
+class LinePredictor(nn.Module):
+    def __init__(self, in_channels, out_channels=1 ):
+        super().__init__()
+        input_features = in_channels
+        deconv_kernel = 4
+        self.kps_score_lowres = nn.ConvTranspose2d(
+            input_features,
+            out_channels,
+            deconv_kernel,
+            stride=2,
+            padding=deconv_kernel // 2 - 1,
+        )
+        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
+        nn.init.constant_(self.kps_score_lowres.bias, 0)
+        self.up_scale = 2
+        self.out_channels = out_channels
+
+    def forward(self, x):
+        print(f'before kps_score_lowres x:{x.shape}')
+        x = self.kps_score_lowres(x)
+        print(f'kps_score_lowres x:{x.shape}')
+        return torch.nn.functional.interpolate(
+            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
+        )

+ 106 - 6
models/line_detect/roi_heads.py → models/line_detect/loi_heads.py

@@ -543,21 +543,104 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     line_loss=F.cross_entropy(line_logits,gs_heatmaps)
 
     return line_loss
+def is_collinear(p1, p2, q1, q2, eps=1e-6):
+    v1 = p2 - p1
+    v2 = q1 - p1
+    cross_z = v1[0] * v2[1] - v1[1] * v2[0]
+    return abs(cross_z) < eps
+
+
+def segment_intersection_length(line1, line2):
+    p1, p2 = line1
+    q1, q2 = line2
+
+    if not is_collinear(p1, p2, q1, q2):
+        return 0.0
+
+    dir_vec = p2 - p1
+    if torch.norm(dir_vec) == 0:
+        return 0.0
+
+    def project(point):
+        return torch.dot(point - p1, dir_vec)
+
+    t_p1 = 0.0
+    t_p2 = 1.0
+    t_q1 = project(q1)
+    t_q2 = project(q2)
+
+    t_min = max(t_p1, min(t_q1, t_q2))
+    t_max = min(t_p2, max(t_q1, t_q2))
+
+    if t_min >= t_max:
+        return 0.0
+
+    length = torch.norm(dir_vec) * (t_max - t_min)
+    return length.item()
+
+
+def line_iou(pred_line, target_line):
+    pred_line_coords = pred_line[:, :2]
+    target_line_coords = target_line[:, :2]
+
+    l1_len = torch.norm(pred_line_coords[1] - pred_line_coords[0])
+    l2_len = torch.norm(target_line_coords[1] - target_line_coords[0])
+
+    inter_len = segment_intersection_length(pred_line_coords, target_line_coords)
+    union_len = l1_len + l2_len - inter_len
+
+    if union_len <= 0:
+        return 0.0
+
+    return inter_len / union_len
+
+
+def line_iou_loss(x, boxes,gt_lines,matched_idx):
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    points_probs = []
+    points_scores = []
+    losses = []
+
+    boxes_per_image = [box.size(0) for box in boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb,gt_line,mid in zip(x2, boxes,gt_lines,matched_idx):
+        p_prob, scores = heatmaps_to_lines(xx, bb)
+        points_probs.append(p_prob)
+        points_scores.append(scores)
+        gt_line_points=gt_line[mid]
+        print(f'gt_line_points:{gt_line_points.shape}')
+        # 匹配预测线段和真实线段(例如匈牙利匹配)
+        # 这里假设一对一匹配
+        pred_lines = p_prob  # shape: (num_pred_lines, 2, 2)
+        print(f'pred_lines:{pred_lines.shape}')
+
+        for j in range(min(len(pred_lines), len(gt_line_points))):
+            iou = line_iou(pred_lines[j], gt_line_points[j])
+            losses.append(1.0 - iou)  # 损失为 1 - IoU
+
+    total_loss = torch.mean(torch.stack(losses)) if losses else None
+
+    return total_loss
 
 def line_inference(x, boxes):
     # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
-    kp_probs = []
-    kp_scores = []
+    points_probs = []
+    points_scores = []
 
     boxes_per_image = [box.size(0) for box in boxes]
     x2 = x.split(boxes_per_image, dim=0)
 
     for xx, bb in zip(x2, boxes):
-        kp_prob, scores = heatmaps_to_lines(xx, bb)
-        kp_probs.append(kp_prob)
-        kp_scores.append(scores)
+        p_prob, scores = heatmaps_to_lines(xx, bb)
+        points_probs.append(p_prob)
+        points_scores.append(scores)
 
-    return kp_probs, kp_scores
+
+
+
+
+    return points_probs, points_scores
 
 def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
     # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
@@ -1117,7 +1200,11 @@ class RoIHeads(nn.Module):
                     pos_matched_idxs = None
 
             print(f'line_proposals:{len(line_proposals)}')
+
+
             line_features = self.line_roi_pool(features, line_proposals, image_shapes)
+
+
             print(f'line_features from line_roi_pool:{line_features.shape}')
             line_features = self.line_head(line_features)
             print(f'line_features from line_head:{line_features.shape}')
@@ -1126,6 +1213,7 @@ class RoIHeads(nn.Module):
             print(f'line_logits:{line_logits.shape}')
 
             loss_line = {}
+            loss_line_iou={}
             if self.training:
                 if targets is None or pos_matched_idxs is None:
                     raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
@@ -1134,7 +1222,11 @@ class RoIHeads(nn.Module):
                 rcnn_loss_line = lines_point_pair_loss(
                     line_logits, line_proposals, gt_lines, pos_matched_idxs
                 )
+                iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs)
+
                 loss_line = {"loss_line": rcnn_loss_line}
+                loss_line_iou = {'loss_line_iou': iou_loss}
+
             else:
                 if targets is not None:
                     gt_lines = [t["lines"] for t in targets]
@@ -1142,6 +1234,11 @@ class RoIHeads(nn.Module):
                         line_logits, line_proposals, gt_lines, pos_matched_idxs
                     )
                     loss_line = {"loss_line": rcnn_loss_lines}
+
+                    iou_loss =line_iou_loss(line_logits, line_proposals,gt_lines,pos_matched_idxs)
+                    loss_line_iou={'loss_line_iou':iou_loss}
+
+
                 else:
                     if line_logits is None or line_proposals is None:
                         raise ValueError(
@@ -1149,10 +1246,13 @@ class RoIHeads(nn.Module):
                         )
 
                     lines_probs, kp_scores = line_inference(line_logits, line_proposals)
+
                     for keypoint_prob, kps, r in zip(lines_probs, kp_scores, result):
                         r["lines"] = keypoint_prob
                         r["liness_scores"] = kps
+
             losses.update(loss_line)
+            losses.update(loss_line_iou)