Browse Source

重构代码,分离损失计算相关函数

RenLiqiang 5 months ago
parent
commit
f8de422289

+ 0 - 0
models/line_detect/heads/__init__.py


+ 601 - 0
models/line_detect/heads/head_losses.py

@@ -0,0 +1,601 @@
+import torch
+from matplotlib import pyplot as plt
+
+import torch.nn.functional as F
+
+def features_align(features, proposals, img_size):
+    print(f'lines_features_align features:{features.shape},proposals:{len(proposals)}')
+
+    align_feat_list = []
+
+    for feat, proposals_per_img in zip(features, proposals):
+        print(f'lines_features_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
+        if proposals_per_img.shape[0]>0:
+            feat = feat.unsqueeze(0)
+            for proposal in proposals_per_img:
+                align_feat = torch.zeros_like(feat)
+                # print(f'align_feat:{align_feat.shape}')
+                x1, y1, x2, y2 = map(lambda v: int(v.item()), proposal)
+                # 将每个proposal框内的部分赋值到align_feats对应位置
+                align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]
+                align_feat_list.append(align_feat)
+
+    # print(f'align_feat_list:{align_feat_list}')
+    if len(align_feat_list) > 0:
+        feats_tensor = torch.cat(align_feat_list)
+
+        print(f'align features :{feats_tensor.shape}')
+    else:
+        feats_tensor = None
+
+    return feats_tensor
+def normalize_tensor(t):
+    return (t - t.min()) / (t.max() - t.min() + 1e-6)
+
+def line_length(lines):
+    """
+    计算每条线段的长度
+    lines: [N, 2, 2] 表示 N 条线段,每条线段由两个点组成
+    返回: [N]
+    """
+    return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
+
+def line_direction(lines):
+    """
+    计算每条线段的单位方向向量
+    lines: [N, 2, 2]
+    返回: [N, 2] 单位方向向量
+    """
+    vec = lines[:, 1] - lines[:, 0]
+    return F.normalize(vec, dim=-1)
+
+def angle_loss_cosine(pred_dir, gt_dir):
+    """
+    使用 cosine similarity 计算方向差异
+    pred_dir: [N, 2]
+    gt_dir: [N, 2]
+    返回: [N]
+    """
+    cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
+    return 1.0 - cos_sim  # 或者 torch.acos(cos_sim) / pi 也可
+
+
+def line_length(lines):
+        """
+        计算每条线段的长度
+        lines: [N, 2, 2] 表示 N 条线段,每条线段由两个点组成
+        返回: [N]
+        """
+        return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
+
+def line_direction(lines):
+        """
+        计算每条线段的单位方向向量
+        lines: [N, 2, 2]
+        返回: [N, 2] 单位方向向量
+        """
+        vec = lines[:, 1] - lines[:, 0]
+        return F.normalize(vec, dim=-1)
+
+def angle_loss_cosine(pred_dir, gt_dir):
+        """
+        使用 cosine similarity 计算方向差异
+        pred_dir: [N, 2]
+        gt_dir: [N, 2]
+        返回: [N]
+        """
+        cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
+        return 1.0 - cos_sim  # 或者 torch.acos(cos_sim) / pi 也可
+
+
+def single_point_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tensor
+    print(f'rois:{rois.shape}')
+    print(f'heatmap_size:{heatmap_size}')
+
+
+    print(f'keypoints.shape:{keypoints.shape}')
+    # batch_size, num_keypoints, _ = keypoints.shape
+
+    x = keypoints[..., 0].unsqueeze(1)
+    y = keypoints[..., 1].unsqueeze(1)
+
+
+    gs = generate_gaussian_heatmaps(x, y,num_points=1, heatmap_size=heatmap_size, sigma=1.0)
+    # show_heatmap(gs[0],'target')
+    all_roi_heatmap = []
+    for roi, heatmap in zip(rois, gs):
+        # show_heatmap(heatmap, 'target')
+        # print(f'heatmap:{heatmap.shape}')
+        heatmap = heatmap.unsqueeze(0)
+        x1, y1, x2, y2 = map(int, roi)
+        roi_heatmap = torch.zeros_like(heatmap)
+        roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
+        # show_heatmap(roi_heatmap[0],'roi_heatmap')
+        all_roi_heatmap.append(roi_heatmap)
+
+    all_roi_heatmap = torch.cat(all_roi_heatmap)
+    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
+
+    return all_roi_heatmap
+
+def line_points_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tensor
+    print(f'rois:{rois.shape}')
+    print(f'heatmap_size:{heatmap_size}')
+
+
+    print(f'keypoints.shape:{keypoints.shape}')
+    # batch_size, num_keypoints, _ = keypoints.shape
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    gs = generate_gaussian_heatmaps(x, y, heatmap_size, 1.0)
+    # show_heatmap(gs[0],'target')
+    all_roi_heatmap = []
+    for roi, heatmap in zip(rois, gs):
+        # print(f'heatmap:{heatmap.shape}')
+        heatmap = heatmap.unsqueeze(0)
+        x1, y1, x2, y2 = map(int, roi)
+        roi_heatmap = torch.zeros_like(heatmap)
+        roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
+        # show_heatmap(roi_heatmap,'roi_heatmap')
+        all_roi_heatmap.append(roi_heatmap)
+
+    all_roi_heatmap = torch.cat(all_roi_heatmap)
+    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
+
+    return all_roi_heatmap
+
+
+"""
+修改适配的原结构的点 转热图,适用于带roi_pool版本的
+"""
+
+
+def line_points_to_heatmap_(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    print(f'rois:{rois.shape}')
+    print(f'heatmap_size:{heatmap_size}')
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    print(f'keypoints.shape:{keypoints.shape}')
+    # batch_size, num_keypoints, _ = keypoints.shape
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    # gs=generate_gaussian_heatmaps(x,y,512,1.0)
+
+    # print(f'gs_heatmap shape:{gs.shape}')
+    #
+    # show_heatmap(gs[0],'target')
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+    # print(f'heatmaps x:{x}')
+    # print(f'heatmaps y:{y}')
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
+    gs_heatmap = generate_gaussian_heatmaps(x, y, heatmap_size, 1.0)
+
+    show_heatmap(gs_heatmap[0], 'feature')
+
+    # print(f'gs_heatmap:{gs_heatmap.shape}')
+    #
+    # lin_ind = y * heatmap_size + x
+    # print(f'lin_ind:{lin_ind.shape}')
+    # heatmaps = lin_ind * valid
+
+    return gs_heatmap
+
+
+def generate_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
+    """
+    为一组点生成并合并高斯热图。
+
+    Args:
+        xs (Tensor): 形状为 (N, 2) 的所有点的 x 坐标
+        ys (Tensor): 形状为 (N, 2) 的所有点的 y 坐标
+        heatmap_size (int): 热图大小 H=W
+        sigma (float): 高斯核标准差
+        device (str): 设备类型 ('cpu' or 'cuda')
+
+    Returns:
+        Tensor: 形状为 (H, W) 的合并后的热图
+    """
+
+    assert xs.shape == ys.shape, "x and y must have the same shape"
+    print(f'xs:{xs.shape}')
+    N = xs.shape[0]
+    print(f'N:{N},num_points:{num_points}')
+
+    # 创建网格
+    grid_y, grid_x = torch.meshgrid(
+        torch.arange(heatmap_size, device=device),
+        torch.arange(heatmap_size, device=device),
+        indexing='ij'
+    )
+
+    # print(f'heatmap_size:{heatmap_size}')
+    # 初始化输出热图
+    combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
+
+    for i in range(N):
+        heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
+        for j in range(num_points):
+            mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
+            mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
+            # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
+
+            # 计算距离平方
+            dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
+
+            # 计算高斯分布
+            heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
+
+            heatmap+=heatmap1
+
+
+        # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
+        # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
+        #
+        # # 计算距离平方
+        # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
+        #
+        # # 计算高斯分布
+        # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
+        #
+        # heatmap = heatmap1 + heatmap2
+
+        # 将当前热图累加到结果中
+        combined_heatmap[i] = heatmap
+
+    return combined_heatmap
+
+def non_maximum_suppression(a):
+    ap = F.max_pool2d(a, 3, stride=1, padding=1)
+    mask = (a == ap).float().clamp(min=0.0)
+    return a * mask
+
+
+def heatmaps_to_points(maps, rois):
+
+
+    point_preds = torch.zeros((len(rois),  2), dtype=torch.float32, device=maps.device)
+    point_end_scores = torch.zeros((len(rois), 1), dtype=torch.float32, device=maps.device)
+
+    print(f'heatmaps_to_lines:{maps.shape}')
+    point_maps=maps[:,0]
+    print(f'point_map:{point_maps.shape}')
+    for i in range(len(rois)):
+
+        point_roi_map = point_maps[i].unsqueeze(0)
+        print(f'point_roi_map:{point_roi_map.shape}')
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = point_roi_map.shape[2]
+        flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
+        point_score, point_index = torch.topk(flatten_point_roi_map, k=1)
+        print(f'point index:{point_index}')
+        # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        point_x =point_index % w
+        point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
+        point_preds[i, 0,] = point_x
+        point_preds[i, 1,] = point_y
+
+        point_end_scores[i, :] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
+
+
+    return point_preds,point_end_scores
+
+def heatmaps_to_lines(maps, rois):
+    line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
+    line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
+
+    line_maps=maps[:,1]
+
+
+    for i in range(len(rois)):
+        line_roi_map = line_maps[i].unsqueeze(0)
+
+        print(f'line_roi_map:{line_roi_map.shape}')
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = line_roi_map.shape[1]
+        flatten_line_roi_map = non_maximum_suppression(line_roi_map).reshape(1, -1)
+        line_score, line_index = torch.topk(flatten_line_roi_map, k=2)
+        print(f'line index:{line_index}')
+        # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+        pos = line_index
+        line_x = pos % w
+        line_y = torch.div(pos - line_x, w, rounding_mode="floor")
+        line_preds[i, 0, :] = line_x
+        line_preds[i, 1, :] = line_y
+        line_preds[i, 2, :] = 1
+        line_end_scores[i, :] = line_roi_map[torch.arange(1, device=line_roi_map.device), line_y, line_x]
+
+
+
+
+    return line_preds.permute(0, 2, 1), line_end_scores
+
+
+# 显示热图的函数
+def show_heatmap(heatmap, title="Heatmap"):
+    """
+    使用 matplotlib 显示热图。
+
+    Args:
+        heatmap (Tensor): 要显示的热图张量
+        title (str): 图表标题
+    """
+    # 如果在 GPU 上,首先将其移动到 CPU 并转换为 numpy 数组
+    if heatmap.is_cuda:
+        heatmap = heatmap.cpu().numpy()
+    else:
+        heatmap = heatmap.numpy()
+
+    plt.imshow(heatmap, cmap='hot', interpolation='nearest')
+    plt.colorbar()
+    plt.title(title)
+    plt.show()
+
+def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = line_logits.shape
+    len_proposals = len(proposals)
+    print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
+    if H != W:
+        raise ValueError(
+            f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+    heatmaps = []
+    gs_heatmaps = []
+    valid = []
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
+        print(f'line_proposals_per_image:{proposals_per_image.shape}')
+        print(f'gt_lines:{gt_lines}')
+        kp = gt_kp_in_image[midx]
+        gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
+        gs_heatmaps.append(gs_heatmaps_per_img)
+        # print(f'heatmaps_per_image:{heatmaps_per_image.shape}')
+
+        # heatmaps.append(heatmaps_per_image.view(-1))
+
+        # valid.append(valid_per_image.view(-1))
+
+    # line_targets = torch.cat(heatmaps, dim=0)
+    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
+    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
+    # print(f'line_targets:{line_targets.shape},{line_targets}')
+
+    # valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
+    # valid = torch.where(valid)[0]
+
+    # print(f' line_targets[valid]:{line_targets[valid]}')
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it sepaartely
+    # if line_targets.numel() == 0 or len(valid) == 0:
+    #     return line_logits.sum() * 0
+
+    # line_logits = line_logits.view(N * K, H * W)
+    # print(f'line_logits[valid]:{line_logits[valid].shape}')
+    line_logits = line_logits.squeeze(1)
+
+    # line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
+    line_loss = F.cross_entropy(line_logits, gs_heatmaps)
+
+    return line_loss
+
+
+def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = line_logits.shape
+    len_proposals = len(proposals)
+
+    empty_count = 0
+    non_empty_count = 0
+
+    for prop in proposals:
+        if prop.shape[0] == 0:
+            empty_count += 1
+        else:
+            non_empty_count += 1
+
+    print(f"Empty proposals count: {empty_count}")
+    print(f"Non-empty proposals count: {non_empty_count}")
+
+    print(f'starte to compute_point_loss')
+    print(f'compute_point_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
+    if H != W:
+        raise ValueError(
+            f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+
+    gs_heatmaps = []
+    # print(f'point_matched_idxs:{point_matched_idxs}')
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_points, point_matched_idxs):
+        print(f'proposals_per_image:{proposals_per_image.shape}')
+        kp = gt_kp_in_image[midx]
+        # print(f'gt_kp_in_image:{gt_kp_in_image}')
+        gs_heatmaps_per_img = single_point_to_heatmap(kp, proposals_per_image, discretization_size)
+        gs_heatmaps.append(gs_heatmaps_per_img)
+
+    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
+    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
+
+    line_logits = line_logits[:,0]
+    print(f'single_point_logits:{line_logits.shape}')
+
+    line_loss = F.cross_entropy(line_logits, gs_heatmaps)
+
+    return line_loss
+
+def lines_to_boxes(lines, img_size=511):
+    """
+    输入:
+        lines: Tensor of shape (N, 2, 2),表示 N 条线段,每个线段有两个端点 (x, y)
+        img_size: int,图像尺寸,用于 clamp 边界
+
+    输出:
+        boxes: Tensor of shape (N, 4),表示 N 个包围盒 [x_min, y_min, x_max, y_max]
+    """
+    # 提取所有线段的两个端点
+    p1 = lines[:, 0]  # (N, 2)
+    p2 = lines[:, 1]  # (N, 2)
+
+    # 每条线段的 x 和 y 坐标
+    x_coords = torch.stack([p1[:, 0], p2[:, 0]], dim=1)  # (N, 2)
+    y_coords = torch.stack([p1[:, 1], p2[:, 1]], dim=1)  # (N, 2)
+
+    # 计算包围盒边界
+    x_min = x_coords.min(dim=1).values
+    y_min = y_coords.min(dim=1).values
+    x_max = x_coords.max(dim=1).values
+    y_max = y_coords.max(dim=1).values
+
+    # 扩展边界并限制在图像范围内
+    x_min = (x_min - 1).clamp(min=0, max=img_size)
+    y_min = (y_min - 1).clamp(min=0, max=img_size)
+    x_max = (x_max + 1).clamp(min=0, max=img_size)
+    y_max = (y_max + 1).clamp(min=0, max=img_size)
+
+    # 合成包围盒
+    boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1)  # (N, 4)
+    return boxes
+
+
+def box_iou_pairwise(box1, box2):
+    """
+    输入:
+        box1: shape (N, 4)
+        box2: shape (M, 4)
+    输出:
+        ious: shape (min(N, M), ), 只计算 i = j 的配对
+    """
+    N = min(len(box1), len(box2))
+    lt = torch.max(box1[:N, :2], box2[:N, :2])  # 左上角
+    rb = torch.min(box1[:N, 2:], box2[:N, 2:])  # 右下角
+
+    wh = (rb - lt).clamp(min=0)  # 宽高
+    inter_area = wh[:, 0] * wh[:, 1]  # 交集面积
+
+    area1 = (box1[:N, 2] - box1[:N, 0]) * (box1[:N, 3] - box1[:N, 1])
+    area2 = (box2[:N, 2] - box2[:N, 0]) * (box2[:N, 3] - box2[:N, 1])
+
+    union_area = area1 + area2 - inter_area
+    ious = inter_area / (union_area + 1e-6)
+
+    return ious
+
+
+def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta=1.0, gamma=1.0):
+    """
+    Args:
+        x: [N,1,H,W] 热力图
+        boxes: [N,4] 框坐标
+        gt_lines: [N,2,3] GT线段(含可见性)
+        matched_idx: 匹配 index
+        img_size: 图像尺寸
+        alpha: IoU 损失权重
+        beta: 长度损失权重
+        gamma: 方向角度损失权重
+    """
+    losses = []
+    boxes_per_image = [box.size(0) for box in boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb, gt_line, mid in zip(x2, boxes, gt_lines, matched_idx):
+        p_prob, _ = heatmaps_to_lines(xx, bb)
+        pred_lines = p_prob
+        gt_line_points = gt_line[mid]
+
+        if len(pred_lines) == 0 or len(gt_line_points) == 0:
+            continue
+
+        # IoU 损失
+        pred_boxes = lines_to_boxes(pred_lines, img_size)
+        gt_boxes = lines_to_boxes(gt_line_points, img_size)
+        ious = box_iou_pairwise(pred_boxes, gt_boxes)
+        iou_loss = 1.0 - ious  # [N]
+
+        # 长度损失
+        pred_len = line_length(pred_lines)
+        gt_len = line_length(gt_line_points)
+        length_diff = F.l1_loss(pred_len, gt_len, reduction='none')  # [N]
+
+        # 方向角度损失
+        pred_dir = line_direction(pred_lines)
+        gt_dir = line_direction(gt_line_points)
+        ang_loss = angle_loss_cosine(pred_dir, gt_dir)  # [N]
+
+        # 归一化每一项损失
+        norm_iou = normalize_tensor(iou_loss)
+        norm_len = normalize_tensor(length_diff)
+        norm_ang = normalize_tensor(ang_loss)
+
+        total = alpha * norm_iou + beta * norm_len + gamma * norm_ang
+        losses.append(total)
+
+
+
+    if not losses:
+        return None
+
+    return torch.mean(torch.cat(losses))
+
+
+def point_inference(x, point_boxes):
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+
+    points_probs = []
+    points_scores = []
+
+    boxes_per_image = [box.size(0) for box in point_boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb in zip(x2, point_boxes):
+        point_prob,point_scores = heatmaps_to_points(xx, bb)
+
+        points_probs.append(point_prob.unsqueeze(1))
+        points_scores.append(point_scores)
+
+    return points_probs,points_scores
+
+def line_inference(x, line_boxes):
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    lines_probs = []
+    lines_scores = []
+
+    boxes_per_image = [box.size(0) for box in line_boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb in zip(x2, line_boxes):
+        line_prob, line_scores, = heatmaps_to_lines(xx, bb)
+        lines_probs.append(line_prob)
+        lines_scores.append(line_scores)
+
+    return lines_probs, lines_scores

+ 0 - 0
models/line_detect/line_heads.py → models/line_detect/heads/line_heads.py


+ 17 - 643
models/line_detect/loi_heads.py

@@ -12,6 +12,9 @@ import libs.vision_libs.models.detection._utils as det_utils
 
 from collections import OrderedDict
 
+from models.line_detect.heads.head_losses import point_inference, compute_point_loss, line_iou_loss, \
+    lines_point_pair_loss, features_align
+
 
 def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
     # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
@@ -129,270 +132,6 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
     )
     return mask_loss
 
-def normalize_tensor(t):
-    return (t - t.min()) / (t.max() - t.min() + 1e-6)
-
-def line_length(lines):
-    """
-    计算每条线段的长度
-    lines: [N, 2, 2] 表示 N 条线段,每条线段由两个点组成
-    返回: [N]
-    """
-    return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
-
-def line_direction(lines):
-    """
-    计算每条线段的单位方向向量
-    lines: [N, 2, 2]
-    返回: [N, 2] 单位方向向量
-    """
-    vec = lines[:, 1] - lines[:, 0]
-    return F.normalize(vec, dim=-1)
-
-def angle_loss_cosine(pred_dir, gt_dir):
-    """
-    使用 cosine similarity 计算方向差异
-    pred_dir: [N, 2]
-    gt_dir: [N, 2]
-    返回: [N]
-    """
-    cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
-    return 1.0 - cos_sim  # 或者 torch.acos(cos_sim) / pi 也可
-
-
-def line_length(lines):
-        """
-        计算每条线段的长度
-        lines: [N, 2, 2] 表示 N 条线段,每条线段由两个点组成
-        返回: [N]
-        """
-        return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
-
-def line_direction(lines):
-        """
-        计算每条线段的单位方向向量
-        lines: [N, 2, 2]
-        返回: [N, 2] 单位方向向量
-        """
-        vec = lines[:, 1] - lines[:, 0]
-        return F.normalize(vec, dim=-1)
-
-def angle_loss_cosine(pred_dir, gt_dir):
-        """
-        使用 cosine similarity 计算方向差异
-        pred_dir: [N, 2]
-        gt_dir: [N, 2]
-        返回: [N]
-        """
-        cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
-        return 1.0 - cos_sim  # 或者 torch.acos(cos_sim) / pi 也可
-
-
-def single_point_to_heatmap(keypoints, rois, heatmap_size):
-    # type: (Tensor, Tensor, int) -> Tensor
-    print(f'rois:{rois.shape}')
-    print(f'heatmap_size:{heatmap_size}')
-
-
-    print(f'keypoints.shape:{keypoints.shape}')
-    # batch_size, num_keypoints, _ = keypoints.shape
-
-    x = keypoints[..., 0].unsqueeze(1)
-    y = keypoints[..., 1].unsqueeze(1)
-
-
-    gs = generate_gaussian_heatmaps(x, y,num_points=1, heatmap_size=heatmap_size, sigma=1.0)
-    # show_heatmap(gs[0],'target')
-    all_roi_heatmap = []
-    for roi, heatmap in zip(rois, gs):
-        # show_heatmap(heatmap, 'target')
-        # print(f'heatmap:{heatmap.shape}')
-        heatmap = heatmap.unsqueeze(0)
-        x1, y1, x2, y2 = map(int, roi)
-        roi_heatmap = torch.zeros_like(heatmap)
-        roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
-        # show_heatmap(roi_heatmap[0],'roi_heatmap')
-        all_roi_heatmap.append(roi_heatmap)
-
-    all_roi_heatmap = torch.cat(all_roi_heatmap)
-    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
-
-    return all_roi_heatmap
-
-def line_points_to_heatmap(keypoints, rois, heatmap_size):
-    # type: (Tensor, Tensor, int) -> Tensor
-    print(f'rois:{rois.shape}')
-    print(f'heatmap_size:{heatmap_size}')
-
-
-    print(f'keypoints.shape:{keypoints.shape}')
-    # batch_size, num_keypoints, _ = keypoints.shape
-
-    x = keypoints[..., 0]
-    y = keypoints[..., 1]
-
-    gs = generate_gaussian_heatmaps(x, y, heatmap_size, 1.0)
-    # show_heatmap(gs[0],'target')
-    all_roi_heatmap = []
-    for roi, heatmap in zip(rois, gs):
-        # print(f'heatmap:{heatmap.shape}')
-        heatmap = heatmap.unsqueeze(0)
-        x1, y1, x2, y2 = map(int, roi)
-        roi_heatmap = torch.zeros_like(heatmap)
-        roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
-        # show_heatmap(roi_heatmap,'roi_heatmap')
-        all_roi_heatmap.append(roi_heatmap)
-
-    all_roi_heatmap = torch.cat(all_roi_heatmap)
-    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
-
-    return all_roi_heatmap
-
-
-"""
-修改适配的原结构的点 转热图,适用于带roi_pool版本的
-"""
-
-
-def line_points_to_heatmap_(keypoints, rois, heatmap_size):
-    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
-    print(f'rois:{rois.shape}')
-    print(f'heatmap_size:{heatmap_size}')
-    offset_x = rois[:, 0]
-    offset_y = rois[:, 1]
-    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
-    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
-
-    offset_x = offset_x[:, None]
-    offset_y = offset_y[:, None]
-    scale_x = scale_x[:, None]
-    scale_y = scale_y[:, None]
-
-    print(f'keypoints.shape:{keypoints.shape}')
-    # batch_size, num_keypoints, _ = keypoints.shape
-
-    x = keypoints[..., 0]
-    y = keypoints[..., 1]
-
-    # gs=generate_gaussian_heatmaps(x,y,512,1.0)
-
-    # print(f'gs_heatmap shape:{gs.shape}')
-    #
-    # show_heatmap(gs[0],'target')
-
-    x_boundary_inds = x == rois[:, 2][:, None]
-    y_boundary_inds = y == rois[:, 3][:, None]
-
-    x = (x - offset_x) * scale_x
-    x = x.floor().long()
-    y = (y - offset_y) * scale_y
-    y = y.floor().long()
-
-    x[x_boundary_inds] = heatmap_size - 1
-    y[y_boundary_inds] = heatmap_size - 1
-    # print(f'heatmaps x:{x}')
-    # print(f'heatmaps y:{y}')
-
-    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
-    vis = keypoints[..., 2] > 0
-    valid = (valid_loc & vis).long()
-
-    gs_heatmap = generate_gaussian_heatmaps(x, y, heatmap_size, 1.0)
-
-    show_heatmap(gs_heatmap[0], 'feature')
-
-    # print(f'gs_heatmap:{gs_heatmap.shape}')
-    #
-    # lin_ind = y * heatmap_size + x
-    # print(f'lin_ind:{lin_ind.shape}')
-    # heatmaps = lin_ind * valid
-
-    return gs_heatmap
-
-
-def generate_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
-    """
-    为一组点生成并合并高斯热图。
-
-    Args:
-        xs (Tensor): 形状为 (N, 2) 的所有点的 x 坐标
-        ys (Tensor): 形状为 (N, 2) 的所有点的 y 坐标
-        heatmap_size (int): 热图大小 H=W
-        sigma (float): 高斯核标准差
-        device (str): 设备类型 ('cpu' or 'cuda')
-
-    Returns:
-        Tensor: 形状为 (H, W) 的合并后的热图
-    """
-
-    assert xs.shape == ys.shape, "x and y must have the same shape"
-    print(f'xs:{xs.shape}')
-    N = xs.shape[0]
-    print(f'N:{N},num_points:{num_points}')
-
-    # 创建网格
-    grid_y, grid_x = torch.meshgrid(
-        torch.arange(heatmap_size, device=device),
-        torch.arange(heatmap_size, device=device),
-        indexing='ij'
-    )
-
-    # print(f'heatmap_size:{heatmap_size}')
-    # 初始化输出热图
-    combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
-
-    for i in range(N):
-        heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
-        for j in range(num_points):
-            mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
-            mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
-            # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
-
-            # 计算距离平方
-            dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
-
-            # 计算高斯分布
-            heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
-
-            heatmap+=heatmap1
-
-
-        # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
-        # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
-        #
-        # # 计算距离平方
-        # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
-        #
-        # # 计算高斯分布
-        # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
-        #
-        # heatmap = heatmap1 + heatmap2
-
-        # 将当前热图累加到结果中
-        combined_heatmap[i] = heatmap
-
-    return combined_heatmap
-
-
-# 显示热图的函数
-def show_heatmap(heatmap, title="Heatmap"):
-    """
-    使用 matplotlib 显示热图。
-
-    Args:
-        heatmap (Tensor): 要显示的热图张量
-        title (str): 图表标题
-    """
-    # 如果在 GPU 上,首先将其移动到 CPU 并转换为 numpy 数组
-    if heatmap.is_cuda:
-        heatmap = heatmap.cpu().numpy()
-    else:
-        heatmap = heatmap.numpy()
-
-    plt.imshow(heatmap, cmap='hot', interpolation='nearest')
-    plt.colorbar()
-    plt.title(title)
-    plt.show()
 
 
 def keypoints_to_heatmap(keypoints, rois, heatmap_size):
@@ -564,340 +303,6 @@ def heatmaps_to_keypoints(maps, rois):
     return xy_preds.permute(0, 2, 1), end_scores
 
 
-def non_maximum_suppression(a):
-    ap = F.max_pool2d(a, 3, stride=1, padding=1)
-    mask = (a == ap).float().clamp(min=0.0)
-    return a * mask
-
-def heatmaps_to_points(maps, rois):
-
-
-    point_preds = torch.zeros((len(rois),  2), dtype=torch.float32, device=maps.device)
-    point_end_scores = torch.zeros((len(rois), 1), dtype=torch.float32, device=maps.device)
-
-    print(f'heatmaps_to_lines:{maps.shape}')
-    point_maps=maps[:,0]
-    print(f'point_map:{point_maps.shape}')
-    for i in range(len(rois)):
-
-        point_roi_map = point_maps[i].unsqueeze(0)
-        print(f'point_roi_map:{point_roi_map.shape}')
-        # roi_map_probs = scores_to_probs(roi_map.copy())
-        w = point_roi_map.shape[2]
-        flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
-        point_score, point_index = torch.topk(flatten_point_roi_map, k=1)
-        print(f'point index:{point_index}')
-        # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
-
-        point_x =point_index % w
-        point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
-        point_preds[i, 0,] = point_x
-        point_preds[i, 1,] = point_y
-
-        point_end_scores[i, :] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
-
-
-    return point_preds,point_end_scores
-
-def heatmaps_to_lines(maps, rois):
-    line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
-    line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
-
-    line_maps=maps[:,1]
-
-
-    for i in range(len(rois)):
-        line_roi_map = line_maps[i].unsqueeze(0)
-
-        print(f'line_roi_map:{line_roi_map.shape}')
-        # roi_map_probs = scores_to_probs(roi_map.copy())
-        w = line_roi_map.shape[1]
-        flatten_line_roi_map = non_maximum_suppression(line_roi_map).reshape(1, -1)
-        line_score, line_index = torch.topk(flatten_line_roi_map, k=2)
-        print(f'line index:{line_index}')
-        # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
-        pos = line_index
-        line_x = pos % w
-        line_y = torch.div(pos - line_x, w, rounding_mode="floor")
-        line_preds[i, 0, :] = line_x
-        line_preds[i, 1, :] = line_y
-        line_preds[i, 2, :] = 1
-        line_end_scores[i, :] = line_roi_map[torch.arange(1, device=line_roi_map.device), line_y, line_x]
-
-
-
-
-    return line_preds.permute(0, 2, 1), line_end_scores
-
-
-def features_align(features, proposals, img_size):
-    print(f'lines_features_align features:{features.shape},proposals:{len(proposals)}')
-
-    align_feat_list = []
-
-    for feat, proposals_per_img in zip(features, proposals):
-        print(f'lines_features_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
-        if proposals_per_img.shape[0]>0:
-            feat = feat.unsqueeze(0)
-            for proposal in proposals_per_img:
-                align_feat = torch.zeros_like(feat)
-                # print(f'align_feat:{align_feat.shape}')
-                x1, y1, x2, y2 = map(lambda v: int(v.item()), proposal)
-                # 将每个proposal框内的部分赋值到align_feats对应位置
-                align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]
-                align_feat_list.append(align_feat)
-
-    # print(f'align_feat_list:{align_feat_list}')
-    if len(align_feat_list) > 0:
-        feats_tensor = torch.cat(align_feat_list)
-
-        print(f'align features :{feats_tensor.shape}')
-    else:
-        feats_tensor = None
-
-    return feats_tensor
-
-
-def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
-    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
-    N, K, H, W = line_logits.shape
-    len_proposals = len(proposals)
-    print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
-    if H != W:
-        raise ValueError(
-            f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
-        )
-    discretization_size = H
-    heatmaps = []
-    gs_heatmaps = []
-    valid = []
-    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
-        print(f'line_proposals_per_image:{proposals_per_image.shape}')
-        print(f'gt_lines:{gt_lines}')
-        kp = gt_kp_in_image[midx]
-        gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
-        gs_heatmaps.append(gs_heatmaps_per_img)
-        # print(f'heatmaps_per_image:{heatmaps_per_image.shape}')
-
-        # heatmaps.append(heatmaps_per_image.view(-1))
-
-        # valid.append(valid_per_image.view(-1))
-
-    # line_targets = torch.cat(heatmaps, dim=0)
-    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
-    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
-    # print(f'line_targets:{line_targets.shape},{line_targets}')
-
-    # valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
-    # valid = torch.where(valid)[0]
-
-    # print(f' line_targets[valid]:{line_targets[valid]}')
-
-    # torch.mean (in binary_cross_entropy_with_logits) doesn't
-    # accept empty tensors, so handle it sepaartely
-    # if line_targets.numel() == 0 or len(valid) == 0:
-    #     return line_logits.sum() * 0
-
-    # line_logits = line_logits.view(N * K, H * W)
-    # print(f'line_logits[valid]:{line_logits[valid].shape}')
-    line_logits = line_logits.squeeze(1)
-
-    # line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
-    line_loss = F.cross_entropy(line_logits, gs_heatmaps)
-
-    return line_loss
-
-
-def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
-    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
-    N, K, H, W = line_logits.shape
-    len_proposals = len(proposals)
-
-    empty_count = 0
-    non_empty_count = 0
-
-    for prop in proposals:
-        if prop.shape[0] == 0:
-            empty_count += 1
-        else:
-            non_empty_count += 1
-
-    print(f"Empty proposals count: {empty_count}")
-    print(f"Non-empty proposals count: {non_empty_count}")
-
-    print(f'starte to compute_point_loss')
-    print(f'compute_point_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
-    if H != W:
-        raise ValueError(
-            f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
-        )
-    discretization_size = H
-
-    gs_heatmaps = []
-    # print(f'point_matched_idxs:{point_matched_idxs}')
-    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_points, point_matched_idxs):
-        print(f'proposals_per_image:{proposals_per_image.shape}')
-        kp = gt_kp_in_image[midx]
-        # print(f'gt_kp_in_image:{gt_kp_in_image}')
-        gs_heatmaps_per_img = single_point_to_heatmap(kp, proposals_per_image, discretization_size)
-        gs_heatmaps.append(gs_heatmaps_per_img)
-
-    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
-    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
-
-    line_logits = line_logits[:,0]
-    print(f'single_point_logits:{line_logits.shape}')
-
-    line_loss = F.cross_entropy(line_logits, gs_heatmaps)
-
-    return line_loss
-
-def lines_to_boxes(lines, img_size=511):
-    """
-    输入:
-        lines: Tensor of shape (N, 2, 2),表示 N 条线段,每个线段有两个端点 (x, y)
-        img_size: int,图像尺寸,用于 clamp 边界
-
-    输出:
-        boxes: Tensor of shape (N, 4),表示 N 个包围盒 [x_min, y_min, x_max, y_max]
-    """
-    # 提取所有线段的两个端点
-    p1 = lines[:, 0]  # (N, 2)
-    p2 = lines[:, 1]  # (N, 2)
-
-    # 每条线段的 x 和 y 坐标
-    x_coords = torch.stack([p1[:, 0], p2[:, 0]], dim=1)  # (N, 2)
-    y_coords = torch.stack([p1[:, 1], p2[:, 1]], dim=1)  # (N, 2)
-
-    # 计算包围盒边界
-    x_min = x_coords.min(dim=1).values
-    y_min = y_coords.min(dim=1).values
-    x_max = x_coords.max(dim=1).values
-    y_max = y_coords.max(dim=1).values
-
-    # 扩展边界并限制在图像范围内
-    x_min = (x_min - 1).clamp(min=0, max=img_size)
-    y_min = (y_min - 1).clamp(min=0, max=img_size)
-    x_max = (x_max + 1).clamp(min=0, max=img_size)
-    y_max = (y_max + 1).clamp(min=0, max=img_size)
-
-    # 合成包围盒
-    boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1)  # (N, 4)
-    return boxes
-
-
-def box_iou_pairwise(box1, box2):
-    """
-    输入:
-        box1: shape (N, 4)
-        box2: shape (M, 4)
-    输出:
-        ious: shape (min(N, M), ), 只计算 i = j 的配对
-    """
-    N = min(len(box1), len(box2))
-    lt = torch.max(box1[:N, :2], box2[:N, :2])  # 左上角
-    rb = torch.min(box1[:N, 2:], box2[:N, 2:])  # 右下角
-
-    wh = (rb - lt).clamp(min=0)  # 宽高
-    inter_area = wh[:, 0] * wh[:, 1]  # 交集面积
-
-    area1 = (box1[:N, 2] - box1[:N, 0]) * (box1[:N, 3] - box1[:N, 1])
-    area2 = (box2[:N, 2] - box2[:N, 0]) * (box2[:N, 3] - box2[:N, 1])
-
-    union_area = area1 + area2 - inter_area
-    ious = inter_area / (union_area + 1e-6)
-
-    return ious
-
-
-def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta=1.0, gamma=1.0):
-    """
-    Args:
-        x: [N,1,H,W] 热力图
-        boxes: [N,4] 框坐标
-        gt_lines: [N,2,3] GT线段(含可见性)
-        matched_idx: 匹配 index
-        img_size: 图像尺寸
-        alpha: IoU 损失权重
-        beta: 长度损失权重
-        gamma: 方向角度损失权重
-    """
-    losses = []
-    boxes_per_image = [box.size(0) for box in boxes]
-    x2 = x.split(boxes_per_image, dim=0)
-
-    for xx, bb, gt_line, mid in zip(x2, boxes, gt_lines, matched_idx):
-        p_prob, _ = heatmaps_to_lines(xx, bb)
-        pred_lines = p_prob
-        gt_line_points = gt_line[mid]
-
-        if len(pred_lines) == 0 or len(gt_line_points) == 0:
-            continue
-
-        # IoU 损失
-        pred_boxes = lines_to_boxes(pred_lines, img_size)
-        gt_boxes = lines_to_boxes(gt_line_points, img_size)
-        ious = box_iou_pairwise(pred_boxes, gt_boxes)
-        iou_loss = 1.0 - ious  # [N]
-
-        # 长度损失
-        pred_len = line_length(pred_lines)
-        gt_len = line_length(gt_line_points)
-        length_diff = F.l1_loss(pred_len, gt_len, reduction='none')  # [N]
-
-        # 方向角度损失
-        pred_dir = line_direction(pred_lines)
-        gt_dir = line_direction(gt_line_points)
-        ang_loss = angle_loss_cosine(pred_dir, gt_dir)  # [N]
-
-        # 归一化每一项损失
-        norm_iou = normalize_tensor(iou_loss)
-        norm_len = normalize_tensor(length_diff)
-        norm_ang = normalize_tensor(ang_loss)
-
-        total = alpha * norm_iou + beta * norm_len + gamma * norm_ang
-        losses.append(total)
-
-
-
-    if not losses:
-        return None
-
-    return torch.mean(torch.cat(losses))
-
-
-def point_inference(x, point_boxes):
-    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
-
-    points_probs = []
-    points_scores = []
-
-    boxes_per_image = [box.size(0) for box in point_boxes]
-    x2 = x.split(boxes_per_image, dim=0)
-
-    for xx, bb in zip(x2, point_boxes):
-        point_prob,point_scores = heatmaps_to_points(xx, bb)
-
-        points_probs.append(point_prob.unsqueeze(1))
-        points_scores.append(point_scores)
-
-    return points_probs,points_scores
-
-def line_inference(x, line_boxes):
-    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
-    lines_probs = []
-    lines_scores = []
-
-    boxes_per_image = [box.size(0) for box in line_boxes]
-    x2 = x.split(boxes_per_image, dim=0)
-
-    for xx, bb in zip(x2, line_boxes):
-        line_prob, line_scores, = heatmaps_to_lines(xx, bb)
-        lines_probs.append(line_prob)
-        lines_scores.append(line_scores)
-
-    return lines_probs, lines_scores
-
 
 def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
     # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
@@ -1528,42 +933,14 @@ class RoIHeads(nn.Module):
             point_proposals_tensor=torch.cat(point_proposals)
             print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
 
-
-            # line_features = lines_features_align(line_features, filtered_proposals, image_shapes)
-
             line_features=None
-            # line_features = features_align(cs_features, line_proposals, image_shapes)
-            # if line_features is not None:
-            #     print(f'line_features:{line_features.shape}')
-
-
-
-            # if line_features is not None and point_features is not None:
-            #     combine_features = torch.cat((point_features, line_features), dim=0)
-            # elif line_features  is not None:
-            #     combine_features =line_features
-            # elif point_features is not None:
-            #     combine_features =point_features
 
-            # combine_features = point_features
-            # print(f'line_features from features_align:{combine_features.shape}')
+            feature_logits = self.line_predictor(cs_features)
+            print(f'feature_logits from line_predictor:{feature_logits.shape}')
 
-            # combine_features = self.line_head(cs_features)
-
-
-
-            # if point_features is not None:
-            #     print(f'point_features:{point_features.shape}')
-
-            #(N,1,512,512)
-            # print(f'combine_features from line_head:{combine_features.shape}')
-
-            combine_features = self.line_predictor(cs_features )
-            print(f'combine_features from line_predictor:{combine_features.shape}')
-
-            point_features = features_align(combine_features, point_proposals, image_shapes)
-            print(f'point_features from  features_align:{point_features.shape}')
-            combine_features=point_features
+            point_features = features_align(feature_logits, point_proposals, image_shapes)
+            print(f'feature_logits  features_align:{point_features.shape}')
+            feature_logits=point_features
 
             # line_logits = combine_features
             # print(f'line_logits:{line_logits.shape}')
@@ -1581,10 +958,6 @@ class RoIHeads(nn.Module):
                 print(f'gt_lines:{gt_lines[0].shape}')
                 h, w = targets[0]["img_size"]
                 img_size = h
-                # rcnn_loss_line = lines_point_pair_loss(
-                #     line_logits, line_proposals, gt_lines, pos_matched_idxs
-                # )
-                # iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs, img_size)
 
                 gt_lines_tensor=torch.cat(gt_lines)
                 gt_points_tensor = torch.cat(gt_points)
@@ -1592,13 +965,13 @@ class RoIHeads(nn.Module):
                 print(f'gt_points_tensor:{gt_points_tensor.shape}')
                 if gt_lines_tensor.shape[0]>0  and line_features is not None:
                     loss_line = lines_point_pair_loss(
-                        combine_features, line_proposals, gt_lines, line_pos_matched_idxs
+                        feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
                     )
-                    loss_line_iou = line_iou_loss(combine_features, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
+                    loss_line_iou = line_iou_loss(feature_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
 
                 if gt_points_tensor.shape[0]>0 and point_features is not None:
                     loss_point = compute_point_loss(
-                        combine_features, point_proposals, gt_points, point_pos_matched_idxs
+                        feature_logits, point_proposals, gt_points, point_pos_matched_idxs
                     )
 
                 if not loss_line:
@@ -1625,14 +998,14 @@ class RoIHeads(nn.Module):
 
                     if gt_lines_tensor.shape[0] > 0 and line_features is not None:
                         loss_line = lines_point_pair_loss(
-                            combine_features, line_proposals, gt_lines, line_pos_matched_idxs
+                            feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
                         )
-                        loss_line_iou = line_iou_loss(combine_features, line_proposals, gt_lines, line_pos_matched_idxs,
+                        loss_line_iou = line_iou_loss(feature_logits, line_proposals, gt_lines, line_pos_matched_idxs,
                                                       img_size)
 
                     if gt_points_tensor.shape[0] > 0 and point_features is not None:
                         loss_point = compute_point_loss(
-                            combine_features, point_proposals, gt_points, point_pos_matched_idxs
+                            feature_logits, point_proposals, gt_points, point_pos_matched_idxs
                         )
 
                     if not loss_line :
@@ -1651,7 +1024,7 @@ class RoIHeads(nn.Module):
 
 
                 else:
-                    if combine_features is None or line_proposals is None:
+                    if feature_logits is None or line_proposals is None:
                         raise ValueError(
                             "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                         )
@@ -1661,8 +1034,9 @@ class RoIHeads(nn.Module):
                     #     for keypoint_prob, kps, r in zip(lines_probs, lines_scores, result):
                     #         r["lines"] = keypoint_prob
                     #         r["liness_scores"] = kps
+
                     if point_features is not None:
-                        point_probs, points_scores=point_inference(combine_features, point_proposals,)
+                        point_probs, points_scores=point_inference(feature_logits, point_proposals, )
                         for  points, ps, r in zip(point_probs,points_scores, result):
                             print(f'points_prob :{points.shape}')
 

+ 25 - 0
models/line_detect/test.py

@@ -0,0 +1,25 @@
+import torch
+
+from models.line_detect.line_detect import linedetect_newresnet18fpn, linedetect_resnet50_fpn, linedetect_resnet18_fpn
+
+
+from models.line_net.trainer import Trainer
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if __name__ == '__main__':
+
+    # model = LineNet('line_net.yaml')
+    # model=linenet_resnet50_fpn()
+    # model = linedetect_resnet50_fpn()
+    # model=get_line_net_convnext_fpn(num_classes=2).to(device)
+    # model=linenet_newresnet50fpn()
+    # model = lineDetect_resnet18_fpn()
+
+    # model=linedetect_resnet18_fpn()
+    model=linedetect_newresnet18fpn(num_points=3)
+    model.eval()
+
+    input=torch.zeros((3,3,512,512))
+
+    out=model(input)
+    # model.start_train(cfg='train.yaml')

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
+  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000