zhaoyinghan 1 hónapja
szülő
commit
e4904275c4

+ 75 - 1
models/line_detect/heads/arc/arc_heads.py

@@ -40,4 +40,78 @@ class ArcPredictor(nn.Module):
         return x
         # return torch.nn.functional.interpolate(
         #     x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
-        # )
+        # )
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ArcEquationHead(nn.Module):
+    def __init__(self, num_outputs=7):
+        super().__init__()
+
+        # --------------------------------------------------
+        # Convolution layers - no fixed H,W assumptions
+        # Automatically downsamples using stride=2
+        # --------------------------------------------------
+        self.conv = nn.Sequential(
+            nn.Conv2d(1, 32, 3, stride=2, padding=1),
+            nn.ReLU(inplace=True),
+
+            nn.Conv2d(32, 64, 3, stride=2, padding=1),
+            nn.ReLU(inplace=True),
+
+            nn.Conv2d(64, 128, 3, stride=2, padding=1),
+            nn.ReLU(inplace=True),
+
+            nn.Conv2d(128, 256, 3, stride=2, padding=1),
+            nn.ReLU(inplace=True),
+        )
+
+        # --------------------------------------------------
+        # Global pooling ¡ú no H,W dependency
+        # --------------------------------------------------
+        self.gap = nn.AdaptiveAvgPool2d((1, 1))
+
+        # --------------------------------------------------
+        # MLP
+        # --------------------------------------------------
+        self.mlp = nn.Sequential(
+            nn.Linear(256, 256),
+            nn.ReLU(inplace=True),
+            nn.Linear(256, num_outputs)
+        )
+
+
+    def forward(self, feature_logits):
+        """
+        Args:
+            feature_logits: Tensor [N, 1, H, W]
+        """
+
+        # CNN
+        x = self.conv(feature_logits)
+
+        # Global pool
+        x = self.gap(x).view(x.size(0), -1)
+
+        # Predict params
+        arc_params = self.mlp(x)   # -> [N, 7]
+
+        N, _, H, W = feature_logits.shape
+
+        # --------------------------------------------
+        # Apply constraints
+        # --------------------------------------------
+        arc_params[..., 0] = torch.sigmoid(arc_params[..., 0]) * W   # cx
+        arc_params[..., 1] = torch.sigmoid(arc_params[..., 1]) * H   # cy
+
+        arc_params[..., 2] = F.relu(arc_params[..., 2]) + 1e-6        # long axis
+        arc_params[..., 3] = F.relu(arc_params[..., 3]) + 1e-6        # short axis
+
+        # angles 0~2¦Ð
+        arc_params[..., 4:7] = torch.sigmoid(arc_params[..., 4:7]) * (2 * 3.1415926535)
+
+        return arc_params

+ 6 - 2
models/line_detect/line_detect.py

@@ -9,7 +9,7 @@ from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
 from libs.vision_libs.ops import misc as misc_nn_ops, MultiScaleRoIAlign
 from libs.vision_libs.models.detection.backbone_utils import BackboneWithFPN, resnet_fpn_backbone
 from libs.vision_libs.models.detection.faster_rcnn import TwoMLPHead
-from models.line_detect.heads.arc.arc_heads import ArcHeads
+from models.line_detect.heads.arc.arc_heads import ArcHeads, ArcEquationHead
 from models.line_detect.heads.circle.circle_heads import CircleHeads, CirclePredictor
 from .heads.decoder import FPNDecoder
 from models.line_detect.heads.line.line_heads import LinePredictor
@@ -90,7 +90,7 @@ class LineDetect(BaseDetectionNet):
             ins_head=None,
             ins_predictor=None,
             circle_roi_pool=None,
-
+            arc_equation_head=None,
     # arc parameters
             arc_roi_pool=None,
             arc_head=None,
@@ -193,6 +193,7 @@ class LineDetect(BaseDetectionNet):
             arc_predictor=FPNDecoder(Bottleneck)
 
         if detect_ins and ins_head is None:
+
             layers = tuple(num_points for _ in range(8))
             ins_head = FPNDecoder(Bottleneck)
 
@@ -201,6 +202,7 @@ class LineDetect(BaseDetectionNet):
             # arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
             # circle_predictor = CirclePredictor(in_channels=256,out_channels=4)
             ins_predictor=ArcEquationPredictor()
+            arc_equation_head = ArcEquationHead()
 
 
 
@@ -220,6 +222,8 @@ class LineDetect(BaseDetectionNet):
         self.roi_heads.ins_roi_pool = circle_roi_pool
         self.roi_heads.ins_head = ins_head
         self.roi_heads.ins_predictor = ins_predictor
+        self.roi_heads.arc_equation_head = arc_equation_head
+
 
     def start_train(self, cfg):
         # cfg = read_yaml(cfg)

+ 160 - 4
models/line_detect/loi_heads.py

@@ -1372,9 +1372,12 @@ class RoIHeads(nn.Module):
 
                 print(f'features from backbone:{features['0'].shape}')
                 feature_logits = self.ins_forward1(features, image_shapes, ins_proposals)
+                arc_equation = self.arc_equation_head(feature_logits)  # [proposal和,7]
 
                 loss_ins = None
                 loss_ins_extra=None
+                loss_arc_equation = None
+                loss_arc_ends = None
 
                 if self.training:
 
@@ -1384,6 +1387,10 @@ class RoIHeads(nn.Module):
                     gt_inses = [t["circle_masks"] for t in targets if "circle_masks" in t]
                     gt_labels = [t["labels"] for t in targets]
 
+                    gt_arcs = [t["arc_mask"] for t in targets if "arc_mask" in t]
+                    gt_mask_ends = [t["mask_ends"] for t in targets if "mask_ends" in t]
+                    gt_mask_params = [t["mask_params"] for t in targets if "mask_params" in t]
+
                     print(f'gt_ins:{gt_inses[0].shape}')
                     h, w = targets[0]["img_size"]
                     img_size = h
@@ -1396,9 +1403,21 @@ class RoIHeads(nn.Module):
                     if gt_ins_tensor.shape[0] > 0:
                         print(f'start to compute circle_loss')
 
-                        loss_ins = compute_ins_loss(feature_logits, ins_proposals, gt_inses, ins_pos_matched_idxs)
-
-                        # loss_ins_extra=compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
+                        loss_ins = compute_ins_loss(feature_logits, ins_proposals, gt_inses,ins_pos_matched_idxs)
+                        total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(arc_equation,
+                                                                                                 ins_proposals,
+                                                                                                 gt_mask_ends,
+                                                                                                 gt_mask_params,
+                                                                                                 ins_pos_matched_idxs,
+                                                                                                 labels)
+                        loss_arc_ends = loss_arc_ends
+                    if loss_arc_equation is None:
+                        print(f'loss_arc_equation is None')
+                        loss_arc_equation = torch.tensor(0.0, device=device)
+
+                    if loss_arc_ends is None:
+                        print(f'loss_arc_ends is None')
+                        loss_arc_ends = torch.tensor(0.0, device=device)
 
                     if loss_ins is None:
                         print(f'loss_ins is None111')
@@ -1408,7 +1427,7 @@ class RoIHeads(nn.Module):
                         print(f'loss_ins_extra is None111')
                         loss_ins_extra = torch.tensor(0.0, device=device)
 
-                    loss_ins = {"loss_ins": loss_ins}
+                    loss_ins = {"loss_ins": loss_ins,"loss_arc_equation": loss_arc_equation,"loss_arc_ends": loss_arc_ends}
                     loss_ins_extra = {"loss_ins_extra": loss_ins_extra}
 
                 else:
@@ -1715,3 +1734,140 @@ class RoIHeads(nn.Module):
         if roi_features is not None:
             print(f'roi_features from align:{roi_features.shape}')
         return roi_features
+
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+
+
+def compute_arc_equation_loss(arc_equation, proposals, gt_mask_ends, gt_mask_params, arc_pos_matched_idxs,
+                              gt_labels_all):
+    """
+    Compute loss between predicted arc equations and ground truth.
+
+    Args:
+        arc_equation: list of length B, each Tensor (N_i, 7)
+        gt_mask_ends: GT arc end masks (for angle calculation)
+        gt_mask_params: list of length B, each numpy array (num_gt, 5)
+        arc_pos_matched_idxs: list of length B, each Tensor of indices matching predictions to GT
+        gt_labels_all: list of length B, GT labels
+    """
+    len_proposals = len(proposals)  # batch
+    device = arc_equation[0].device
+    print(
+        f'compute_arc_equation_loss line_logits.shape:{arc_equation.shape},len_proposals:{len_proposals},line_matched_idxs:{arc_pos_matched_idxs}')
+    print(f'gt_mask_ends:{gt_mask_ends}, gt_mask_params:{gt_mask_params}')
+
+    gt_angles = []
+    # for gt_mask_end,gt_mask_param in zip(gt_mask_ends, gt_mask_params):
+    #     print(f'gt_mask_end:{gt_mask_end}, gt_mask_param:{gt_mask_param}')
+    #     gt_angles.append(compute_arc_angles(gt_mask_end,gt_mask_param))
+    for i in range(len(gt_mask_ends)):
+        print(f'gt_mask_end:{gt_mask_ends[i]}, gt_mask_param:{gt_mask_params[i]}')
+        gt_angles.append(compute_arc_angles(gt_mask_ends[i], gt_mask_params[i]))
+
+    print(f'gt_angles:{gt_angles}')
+    print(f'gt_mask_params:{gt_mask_params}')
+    print(f'gt_labels_all:{gt_labels_all}')
+    print(f'arc_pos_matched_idxs:{arc_pos_matched_idxs}')
+
+    gt_sel_params = []
+    gt_sel_angles = []
+    for proposals_per_image, gt_angle, gt_params, gt_label, midx in zip(proposals, gt_angles, gt_mask_params,
+                                                                        gt_labels_all, arc_pos_matched_idxs):
+        print(f'line_proposals_per_image:{proposals_per_image.shape}')
+        # gt_angle = torch.tensor(gt_angle)
+        gt_angle = torch.stack(gt_angle, dim=0)
+        gt_params = torch.tensor(gt_params)
+        if gt_angle.shape[0] > 0:
+            # positions = (gt_label == 3).nonzero()[0].item()
+
+            po = gt_angle[midx.cpu()]
+            pa = gt_params[midx.cpu()]
+            print(f'po:{po},pa:{pa}')
+
+            gt_sel_angles.append(po)
+            gt_sel_params.append(pa)
+
+    print(f'gt_sel_angles:{gt_sel_angles}')
+    print(f'gt_sel_params:{gt_sel_params}')
+
+    gt_sel_angles = torch.cat(gt_sel_angles, dim=0)
+    gt_sel_params = torch.cat(gt_sel_params, dim=0)
+
+    print(f'gt_sel_angles:{gt_sel_angles}')
+    print(f'gt_sel_params:{gt_sel_params}')
+
+    pred_angles = arc_equation[:, 5:7]
+    pred_params = arc_equation[:, :5]
+
+    angle_loss = F.mse_loss(pred_angles, gt_sel_angles)
+    param_loss = F.mse_loss(pred_params.cpu(), gt_sel_params) / 10000
+    print(f'angle_loss:{angle_loss}, param_loss:{param_loss}')
+
+    count = sum(len(sublist) for sublist in proposals)
+
+    total_loss = (param_loss + angle_loss) / count if count > 0 else torch.tensor(0.0)
+
+    # 确保 dtype 和 device
+    total_loss = total_loss.float().to(device)
+    angle_loss = angle_loss.float().to(device)
+    param_loss = param_loss.float().to(device)
+
+    # if count > 0:
+    #     total_loss = (param_loss + angle_loss) / count
+    #     total_loss = torch.tensor(total_loss, dtype=torch.float32, device=device)
+    # else:
+    #     total_loss = torch.tensor(0.0, dtype=torch.float32, device=device)
+
+    print(f'total_loss, param_loss, angle_loss:{total_loss, param_loss, angle_loss}')
+
+    return total_loss, param_loss, angle_loss
+
+
+def compute_arc_angles(gt_mask_ends, gt_mask_params):
+    """
+    给定椭圆上的一个点,计算其对应的参数角 phi(弧度)。
+
+    Parameters:
+        point: tuple or array-like, (x, y)
+        ellipse_param: tuple or array-like, (xc, yc, a, b, theta)
+
+    Returns:
+        phi: float, in [0, 2*pi)
+    """
+    results = []
+    gt_mask_params_tensor = torch.tensor(gt_mask_params,
+                                         dtype=gt_mask_ends.dtype,
+                                         device=gt_mask_ends.device)
+    for ends_img, params_img in zip(gt_mask_ends, gt_mask_params_tensor):
+        # print(f'params_img:{params_img}')
+        if torch.norm(params_img) < 1e-6:  # L2 norm near zero
+            results.append(torch.zeros(2, device=params_img.device, dtype=params_img.dtype))
+            continue
+        x, y = ends_img
+        xc, yc, a, b, theta = params_img
+
+        # 1. 平移到中心
+        dx = x - xc
+        dy = y - yc
+
+        # 2. 逆旋转(旋转 -theta)
+        cos_t = torch.cos(theta)
+        sin_t = torch.sin(theta)
+        X = dx * cos_t + dy * sin_t
+        Y = -dx * sin_t + dy * cos_t
+
+        # 3. 归一化到单位圆(除以 a, b)
+        cos_phi = X / a
+        sin_phi = Y / b
+
+        # 4. 用 atan2 求角度(自动处理象限)
+        phi = torch.atan2(sin_phi, cos_phi)
+
+        # 5. 转换到 [0, 2π)
+        phi = torch.where(phi < 0, phi + 2 * torch.pi, phi)
+
+        results.append(phi)
+    return results