import logging

import torch
import torch.nn.functional as F

try:
    # matplotlib is only needed by show_heatmap(); keep this module importable
    # on machines (e.g. headless training nodes) where it is not installed.
    from matplotlib import pyplot as plt
except ImportError:
    plt = None

logger = logging.getLogger(__name__)


def features_align(features, proposals, img_size):
    """Build one masked copy of an image's feature map per proposal box.

    For every proposal a zero tensor shaped like the image's feature map is
    created, and only the pixels inside the (inclusive) box
    ``[x1, y1, x2, y2]`` are copied from the feature map.

    Args:
        features: ``[B, C, H, W]`` feature maps, one per image.
        proposals: list of per-image proposal tensors ``[P_i, 4]`` given in
            feature-map pixel coordinates.
        img_size: unused; kept for interface compatibility.

    Returns:
        ``[sum(P_i), C, H, W]`` masked features, or ``None`` when no image has
        any proposal.
    """
    logger.debug("lines_features_align features:%s,proposals:%d", features.shape, len(proposals))
    align_feat_list = []
    for feat, proposals_per_img in zip(features, proposals):
        logger.debug(
            "lines_features_align feat:%s, proposals_per_img:%s",
            feat.shape, proposals_per_img.shape,
        )
        if proposals_per_img.shape[0] == 0:
            continue
        feat = feat.unsqueeze(0)  # [1, C, H, W]
        for proposal in proposals_per_img:
            x1, y1, x2, y2 = (int(v.item()) for v in proposal)
            align_feat = torch.zeros_like(feat)
            # Copy only the part of the feature map inside the proposal box.
            align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]
            align_feat_list.append(align_feat)

    if not align_feat_list:
        return None
    feats_tensor = torch.cat(align_feat_list)
    logger.debug("align features :%s", feats_tensor.shape)
    return feats_tensor


def normalize_tensor(t):
    """Min-max normalize ``t`` to roughly ``[0, 1]`` (epsilon avoids 0/0)."""
    return (t - t.min()) / (t.max() - t.min() + 1e-6)


def line_length(lines):
    """Return the length of each line segment.

    Args:
        lines: ``[N, 2, D]``; ``lines[:, 0]`` and ``lines[:, 1]`` are the two
            endpoints. Every coordinate in the last dim participates, so an
            extra trailing component (e.g. a constant 1) must match between
            endpoints to be neutral.

    Returns:
        ``[N]`` Euclidean distances between the endpoints.
    """
    return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)


def line_direction(lines):
    """Return the unit direction vector (endpoint 1 minus endpoint 0) of each segment.

    Args:
        lines: ``[N, 2, D]`` segments, as in :func:`line_length`.

    Returns:
        ``[N, D]`` unit vectors; zero-length segments give a zero vector
        (``F.normalize`` clamps the norm with a small epsilon).
    """
    vec = lines[:, 1] - lines[:, 0]
    return F.normalize(vec, dim=-1)


def angle_loss_cosine(pred_dir, gt_dir):
    """Return ``1 - cos(angle)`` between paired direction vectors.

    Args:
        pred_dir: ``[N, D]`` unit direction vectors.
        gt_dir: ``[N, D]`` unit direction vectors.

    Returns:
        ``[N]`` losses in ``[0, 2]`` (0 for identical directions). An
        alternative would be ``torch.acos(cos_sim) / pi``.
    """
    cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
    return 1.0 - cos_sim


def _crop_heatmaps_to_rois(heatmaps, rois):
    """Zero each heatmap outside its ROI box; return a list of ``[1, H, W]`` maps.

    ``heatmaps[i]`` is paired with ``rois[i]`` = ``(x1, y1, x2, y2)``, used as
    inclusive pixel bounds. Pairs beyond the shorter input are dropped (zip).
    """
    cropped = []
    for roi, heatmap in zip(rois, heatmaps):
        x1, y1, x2, y2 = map(int, roi)
        heatmap = heatmap.unsqueeze(0)
        roi_heatmap = torch.zeros_like(heatmap)
        roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
        cropped.append(roi_heatmap)
    return cropped


def single_point_to_heatmap(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tensor
    """Render one Gaussian per keypoint and keep only the part inside its ROI.

    Args:
        keypoints: one point per ROI; ``keypoints[..., 0]`` / ``[..., 1]`` are
            the x / y coordinates in heatmap pixels.
        rois: ``[N, 4]`` boxes ``(x1, y1, x2, y2)`` in heatmap pixels.
        heatmap_size: heatmap side length.

    Returns:
        ``[N, heatmap_size, heatmap_size]`` targets; a tensor with zero rows
        when there are no ROIs.
    """
    logger.debug("rois:%s heatmap_size:%s keypoints.shape:%s", rois.shape, heatmap_size, keypoints.shape)
    x = keypoints[..., 0].unsqueeze(1)
    y = keypoints[..., 1].unsqueeze(1)
    gs = generate_gaussian_heatmaps(x, y, num_points=1, heatmap_size=heatmap_size, sigma=1.0)
    cropped = _crop_heatmaps_to_rois(gs, rois)
    # torch.cat rejects an empty list: an image without proposals must
    # contribute zero rows instead of crashing the whole batch.
    all_roi_heatmap = torch.cat(cropped) if cropped else gs[:0]
    logger.debug("all_roi_heatmap:%s", all_roi_heatmap.shape)
    return all_roi_heatmap


def line_points_to_heatmap(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tensor
    """Render both endpoints of each line as Gaussians, masked to the line's ROI.

    Args:
        keypoints: ``[N, 2, D]`` line endpoints; ``[..., 0]`` / ``[..., 1]``
            are x / y in heatmap pixels.
        rois: ``[N, 4]`` boxes ``(x1, y1, x2, y2)`` in heatmap pixels.
        heatmap_size: heatmap side length.

    Returns:
        ``[N, heatmap_size, heatmap_size]`` targets, or ``None`` when there are
        no ROIs.
    """
    logger.debug("rois:%s heatmap_size:%s keypoints.shape:%s", rois.shape, heatmap_size, keypoints.shape)
    x = keypoints[..., 0]
    y = keypoints[..., 1]
    gs = generate_gaussian_heatmaps(x, y, heatmap_size, num_points=2, sigma=1.0)
    cropped = _crop_heatmaps_to_rois(gs, rois)
    if not cropped:
        return None
    all_roi_heatmap = torch.cat(cropped)
    logger.debug("all_roi_heatmap:%s", all_roi_heatmap.shape)
    return all_roi_heatmap


def line_points_to_heatmap_(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tensor
    """Render keypoints as Gaussians in ROI-local heatmap coordinates.

    Variant adapted from the original keypoint-target code, for the version of
    the head that uses ROI pooling. Each keypoint is mapped into its ROI's
    ``heatmap_size x heatmap_size`` grid and floored to an integer cell before
    rendering.

    Args:
        keypoints: ``[N, 2, D]`` points in image coordinates; ``[..., 0]`` /
            ``[..., 1]`` are x / y.
        rois: ``[N, 4]`` boxes ``(x1, y1, x2, y2)`` in image coordinates.
        heatmap_size: heatmap side length.

    Returns:
        ``[N, heatmap_size, heatmap_size]`` heatmaps.
    """
    logger.debug("rois:%s heatmap_size:%s keypoints.shape:%s", rois.shape, heatmap_size, keypoints.shape)
    offset_x = rois[:, 0][:, None]
    offset_y = rois[:, 1][:, None]
    scale_x = (heatmap_size / (rois[:, 2] - rois[:, 0]))[:, None]
    scale_y = (heatmap_size / (rois[:, 3] - rois[:, 1]))[:, None]

    x = keypoints[..., 0]
    y = keypoints[..., 1]
    # A point lying exactly on the right/bottom ROI edge would map to cell
    # ``heatmap_size`` (out of range); snap it to the last cell instead.
    x_boundary_inds = x == rois[:, 2][:, None]
    y_boundary_inds = y == rois[:, 3][:, None]
    x = ((x - offset_x) * scale_x).floor().long()
    y = ((y - offset_y) * scale_y).floor().long()
    x[x_boundary_inds] = heatmap_size - 1
    y[y_boundary_inds] = heatmap_size - 1

    # NOTE: points outside the ROI or with non-positive visibility are not
    # masked here; generate_gaussian_heatmaps only clamps them onto the grid.
    # sigma must be passed by keyword: the 4th positional slot is num_points.
    return generate_gaussian_heatmaps(x, y, heatmap_size, sigma=1.0)


def generate_gaussian_heatmaps(xs, ys, heatmap_size, num_points=2, sigma=2.0, device='cuda'):
    """Render and sum Gaussian blobs for a set of points, one heatmap per row.

    Args:
        xs (Tensor): x coordinates, shape ``(N, P)`` with ``P >= num_points``.
        ys (Tensor): y coordinates, same shape as ``xs``.
        heatmap_size (int): heatmap side length (H = W).
        num_points (int): number of leading points per row to render.
        sigma (float): Gaussian standard deviation, in pixels.
        device (str): device on which the heatmaps are created ('cpu' or 'cuda').

    Returns:
        Tensor: ``(N, heatmap_size, heatmap_size)``. Row ``i`` is the sum of the
        unnormalized Gaussians centred on ``(xs[i, j], ys[i, j])`` for
        ``j < num_points``. Centres are clamped to ``[0, heatmap_size - 1]``.
    """
    assert xs.shape == ys.shape, "x and y must have the same shape"
    n = xs.shape[0]
    logger.debug("xs:%s N:%s,num_points:%s", xs.shape, n, num_points)

    grid_y, grid_x = torch.meshgrid(
        torch.arange(heatmap_size, device=device),
        torch.arange(heatmap_size, device=device),
        indexing='ij',
    )
    combined_heatmap = torch.zeros((n, heatmap_size, heatmap_size), device=device)
    two_sigma_sq = 2 * sigma ** 2

    # Loop over the (few) points and vectorize over ROIs. This avoids one
    # host/device sync per point that a per-element .item() would cost.
    for j in range(num_points):
        # Targets are constants: detach so no autograd graph is built.
        mu_x = xs[:, j].detach().clamp(0, heatmap_size - 1)
        mu_y = ys[:, j].detach().clamp(0, heatmap_size - 1)
        mu_x = mu_x.to(device=device, dtype=combined_heatmap.dtype).reshape(n, 1, 1)
        mu_y = mu_y.to(device=device, dtype=combined_heatmap.dtype).reshape(n, 1, 1)
        dist = (grid_x - mu_x) ** 2 + (grid_y - mu_y) ** 2
        combined_heatmap += torch.exp(-dist / two_sigma_sq)
    return combined_heatmap


def generate_mask_gaussian_heatmaps(xs, ys, heatmap_size, num_points=2, sigma=2.0, device='cuda'):
    """Same as :func:`generate_gaussian_heatmaps`, but for ``(N, 1, P)`` inputs.

    Dim 1 of ``xs`` / ``ys`` is squeezed away before rendering.

    Returns:
        Tensor: ``(N, heatmap_size, heatmap_size)`` summed Gaussian heatmaps.
    """
    assert xs.shape == ys.shape, "x and y must have the same shape"
    return generate_gaussian_heatmaps(
        xs.squeeze(1), ys.squeeze(1), heatmap_size,
        num_points=num_points, sigma=sigma, device=device,
    )


def non_maximum_suppression(a):
    """Keep values equal to the maximum of their 3x3 neighbourhood; zero the rest.

    ``a`` must be a valid ``F.max_pool2d`` input (3-D or 4-D). Plateaus of
    equal values are all kept.
    """
    ap = F.max_pool2d(a, 3, stride=1, padding=1)
    return a * (a == ap).float()


def heatmaps_to_points(maps, rois):
    """Decode one point per ROI as the strongest peak of channel 0 of ``maps``.

    Args:
        maps: ``[N, C, H, W]`` heatmaps; only channel 0 is used.
        rois: sequence of ROIs; only its length is used.

    Returns:
        point_preds: ``[len(rois), 2]`` peak ``(x, y)`` in heatmap pixels.
        point_end_scores: ``[len(rois), 1]`` heatmap value at the peak.
    """
    num_rois = len(rois)
    point_preds = torch.zeros((num_rois, 2), dtype=torch.float32, device=maps.device)
    point_end_scores = torch.zeros((num_rois, 1), dtype=torch.float32, device=maps.device)
    logger.debug("heatmaps_to_points:%s", maps.shape)
    point_maps = maps[:, 0]
    for i in range(num_rois):
        point_roi_map = point_maps[i].unsqueeze(0)  # [1, H, W]
        # Flattened index = y * W + x, so the row stride is the map width.
        w = point_roi_map.shape[-1]
        flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
        _, point_index = torch.topk(flatten_point_roi_map, k=1)
        logger.debug("point index:%s", point_index)
        point_x = point_index % w
        point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
        point_preds[i, 0] = point_x
        point_preds[i, 1] = point_y
        point_end_scores[i, :] = point_roi_map[
            torch.arange(1, device=point_roi_map.device), point_y, point_x
        ]
    return point_preds, point_end_scores


def heatmaps_to_lines(maps, rois):
    """Decode one segment per ROI from the two strongest peaks of channel 1.

    Args:
        maps: ``[N, C, H, W]`` heatmaps with ``C >= 2``; channel 1 is used.
        rois: sequence of ROIs; only its length is used.

    Returns:
        line_preds: ``[len(rois), 2, 3]``; ``line_preds[i, k] = (x, y, 1)`` for
            the k-th strongest peak, in heatmap pixels.
        line_end_scores: ``[len(rois), 2]`` heatmap values at the two peaks.
    """
    num_rois = len(rois)
    line_preds = torch.zeros((num_rois, 3, 2), dtype=torch.float32, device=maps.device)
    line_end_scores = torch.zeros((num_rois, 2), dtype=torch.float32, device=maps.device)
    line_maps = maps[:, 1]
    for i in range(num_rois):
        line_roi_map = line_maps[i].unsqueeze(0)  # [1, H, W]
        logger.debug("line_roi_map:%s", line_roi_map.shape)
        # Flattened index = y * W + x, so the row stride is the map *width*
        # (shape[-1]); shape[1] would be the height.
        w = line_roi_map.shape[-1]
        flatten_line_roi_map = non_maximum_suppression(line_roi_map).reshape(1, -1)
        _, line_index = torch.topk(flatten_line_roi_map, k=2)
        logger.debug("line index:%s", line_index)
        line_x = line_index % w
        line_y = torch.div(line_index - line_x, w, rounding_mode="floor")
        line_preds[i, 0, :] = line_x
        line_preds[i, 1, :] = line_y
        line_preds[i, 2, :] = 1
        line_end_scores[i, :] = line_roi_map[
            torch.arange(1, device=line_roi_map.device), line_y, line_x
        ]
    return line_preds.permute(0, 2, 1), line_end_scores


def show_heatmap(heatmap, title="Heatmap"):
    """Display a 2-D heatmap with matplotlib.

    Args:
        heatmap (Tensor): the heatmap to show; it may live on the GPU and may
            require grad.
        title (str): figure title.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    if plt is None:
        raise ImportError("matplotlib is required for show_heatmap()")
    # detach() so tensors that require grad can be converted to numpy.
    image = heatmap.detach().cpu().numpy()
    plt.imshow(image, cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.title(title)
    plt.show()


def _check_square_logits(logits):
    """Return H for ``[N, K, H, W]`` logits; raise ValueError unless H == W."""
    _, _, H, W = logits.shape
    if H != W:
        raise ValueError(
            f"line_logits height and width (last two elements of shape) should be equal. "
            f"Instead got H = {H} and W = {W}"
        )
    return H


def _log_proposal_counts(proposals):
    """Log how many images have zero versus at least one proposal."""
    empty_count = sum(1 for prop in proposals if prop.shape[0] == 0)
    logger.debug("Empty proposals count: %d", empty_count)
    logger.debug("Non-empty proposals count: %d", len(proposals) - empty_count)


def _heatmap_cross_entropy(logits, channel, target_list):
    """Return the cross-entropy of ``logits[:, channel]`` against concatenated soft targets.

    Returns a zero loss that stays connected to ``logits`` when there are no
    targets. This is needed because torch.cat rejects an empty list and the
    mean over an empty batch is NaN.
    """
    if not target_list:
        return logits.sum() * 0
    targets = torch.cat(target_list, dim=0)
    if targets.shape[0] == 0:
        return logits.sum() * 0
    logger.debug("gs_heatmaps:%s, logits.shape:%s", targets.shape, logits.shape)
    # NOTE(review): with [N, H, W] input, F.cross_entropy treats dim 1 (rows)
    # as the class dimension, i.e. a softmax over rows for every column.
    # Confirm this is the intended formulation.
    return F.cross_entropy(logits[:, channel], targets)


def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    """Compute the line-endpoint heatmap loss on channel 1 of ``line_logits``.

    Only images that have both proposals and GT lines contribute targets.

    Args:
        line_logits: ``[N, K, H, W]`` logits with H == W and K >= 2.
        proposals: per-image ``[P_i, 4]`` boxes in heatmap pixels.
        gt_lines: per-image GT lines, indexed by ``line_matched_idxs``.
        line_matched_idxs: per-image indices of the GT line matched to each proposal.

    Returns:
        Scalar loss (a graph-connected zero when there are no targets).

    Raises:
        ValueError: if H != W.
    """
    logger.debug(
        "lines_point_pair_loss line_logits.shape:%s,len_proposals:%d,line_matched_idxs:%s",
        line_logits.shape, len(proposals), line_matched_idxs,
    )
    discretization_size = _check_square_logits(line_logits)
    # NOTE(review): an image with proposals but no GT lines adds logits rows but
    # no target rows, which would misalign the two in the loss below.
    gs_heatmaps = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
        if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
            kp = gt_kp_in_image[midx]
            gs_heatmaps.append(line_points_to_heatmap(kp, proposals_per_image, discretization_size))
    return _heatmap_cross_entropy(line_logits, 1, gs_heatmaps)


def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
    """Compute the arc-point heatmap loss on channel 0 of ``feature_logits``.

    Args:
        feature_logits: ``[N, K, H, W]`` logits with H == W.
        proposals: per-image ``[P_i, 4]`` boxes in heatmap pixels.
        gt_: per-image GT arc points ``[M_i, P, D]``, indexed by ``pos_matched_idxs``.
        pos_matched_idxs: per-image indices of the GT matched to each proposal.

    Returns:
        Scalar loss (a graph-connected zero when there are no targets).

    Raises:
        ValueError: if H != W.
    """
    logger.debug("compute_arc_loss:%s", feature_logits.shape)
    _log_proposal_counts(proposals)
    logger.debug(
        "start to compute_arc_loss feature_logits.shape:%s,len_proposals:%d",
        feature_logits.shape, len(proposals),
    )
    discretization_size = _check_square_logits(feature_logits)
    gs_heatmaps = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
        logger.debug("proposals_per_image:%s", proposals_per_image.shape)
        kp = gt_kp_in_image[midx]
        gs_heatmaps.append(arc_points_to_heatmap(kp, proposals_per_image, discretization_size))
    return _heatmap_cross_entropy(feature_logits, 0, gs_heatmaps)


def arc_points_to_heatmap(keypoints, rois, heatmap_size):
    """Render all points of each arc as Gaussians, masked to the arc's ROI.

    Args:
        keypoints: ``[N, P, D]`` points; ``[..., 0]`` / ``[..., 1]`` are x / y.
        rois: ``[N, 4]`` boxes ``(x1, y1, x2, y2)`` in heatmap pixels.
        heatmap_size: heatmap side length.

    Returns:
        ``[N, heatmap_size, heatmap_size]`` targets; a tensor with zero rows
        when there are no ROIs.
    """
    logger.debug("rois:%s heatmap_size:%s keypoints.shape:%s", rois.shape, heatmap_size, keypoints.shape)
    x = keypoints[..., 0].unsqueeze(1)  # [N, 1, P]
    y = keypoints[..., 1].unsqueeze(1)
    num_points = x.shape[2]
    gs = generate_mask_gaussian_heatmaps(x, y, num_points=num_points, heatmap_size=heatmap_size, sigma=1.0)
    cropped = _crop_heatmaps_to_rois(gs, rois)
    # torch.cat rejects an empty list; an image without proposals gives zero rows.
    all_roi_heatmap = torch.cat(cropped) if cropped else gs[:0]
    logger.debug("all_roi_heatmap:%s", all_roi_heatmap.shape)
    return all_roi_heatmap


def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    """Compute the single-point heatmap loss on channel 0 of ``line_logits``.

    Args:
        line_logits: ``[N, K, H, W]`` logits with H == W.
        proposals: per-image ``[P_i, 4]`` boxes in heatmap pixels.
        gt_points: per-image GT points, indexed by ``point_matched_idxs``.
        point_matched_idxs: per-image indices of the GT matched to each proposal.

    Returns:
        Scalar loss (a graph-connected zero when there are no targets).

    Raises:
        ValueError: if H != W.
    """
    _log_proposal_counts(proposals)
    logger.debug(
        "start to compute_point_loss line_logits.shape:%s,len_proposals:%d",
        line_logits.shape, len(proposals),
    )
    discretization_size = _check_square_logits(line_logits)
    gs_heatmaps = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_points, point_matched_idxs):
        logger.debug("proposals_per_image:%s", proposals_per_image.shape)
        kp = gt_kp_in_image[midx]
        gs_heatmaps.append(single_point_to_heatmap(kp, proposals_per_image, discretization_size))
    return _heatmap_cross_entropy(line_logits, 0, gs_heatmaps)


def lines_to_boxes(lines, img_size=511):
    """Convert line segments to their bounding boxes, padded by one pixel.

    Args:
        lines: ``[N, 2, D]`` segments; only x (``[..., 0]``) and y
            (``[..., 1]``) of both endpoints are used.
        img_size: upper clamp bound for every box coordinate.

    Returns:
        ``[N, 4]`` boxes ``[x_min, y_min, x_max, y_max]``. Each side is
        expanded by 1 and clamped to ``[0, img_size]``.
    """
    p1 = lines[:, 0]
    p2 = lines[:, 1]
    x_coords = torch.stack([p1[:, 0], p2[:, 0]], dim=1)  # [N, 2]
    y_coords = torch.stack([p1[:, 1], p2[:, 1]], dim=1)  # [N, 2]

    x_min = (x_coords.min(dim=1).values - 1).clamp(min=0, max=img_size)
    y_min = (y_coords.min(dim=1).values - 1).clamp(min=0, max=img_size)
    x_max = (x_coords.max(dim=1).values + 1).clamp(min=0, max=img_size)
    y_max = (y_coords.max(dim=1).values + 1).clamp(min=0, max=img_size)
    return torch.stack([x_min, y_min, x_max, y_max], dim=1)


def box_iou_pairwise(box1, box2):
    """Return the IoU of row-aligned box pairs ``(box1[i], box2[i])``.

    Args:
        box1: ``[N, 4]`` boxes ``(x1, y1, x2, y2)``.
        box2: ``[M, 4]`` boxes.

    Returns:
        ``[min(N, M)]`` IoUs; extra rows of the longer input are ignored.
    """
    n = min(len(box1), len(box2))
    lt = torch.max(box1[:n, :2], box2[:n, :2])  # top-left of intersection
    rb = torch.min(box1[:n, 2:], box2[:n, 2:])  # bottom-right of intersection
    wh = (rb - lt).clamp(min=0)
    inter_area = wh[:, 0] * wh[:, 1]
    area1 = (box1[:n, 2] - box1[:n, 0]) * (box1[:n, 3] - box1[:n, 1])
    area2 = (box2[:n, 2] - box2[:n, 0]) * (box2[:n, 3] - box2[:n, 1])
    union_area = area1 + area2 - inter_area
    return inter_area / (union_area + 1e-6)


def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta=1.0, gamma=1.0):
    """Compute a combined box-IoU / length / direction loss between decoded and GT lines.

    Args:
        x: ``[N, C, H, W]`` heatmaps with ``C >= 2`` (channel 1 is decoded by
            :func:`heatmaps_to_lines`), stacked over all images.
        boxes: per-image ROI tensors; their row counts split ``x``.
        gt_lines: per-image GT lines ``[M_i, 2, 3]`` (coordinates + visibility).
        matched_idx: per-image indices of the GT line matched to each ROI.
        img_size: clamp bound used when converting lines to boxes.
        alpha: weight of the IoU term.
        beta: weight of the length term.
        gamma: weight of the direction (angle) term.

    Returns:
        Scalar mean loss, or ``None`` when no image contributed.
    """
    # NOTE(review): predicted lines come from topk indices written into a
    # fresh tensor, so this loss has no gradient path back to ``x``. The
    # predictions are also in heatmap pixels; confirm that gt_lines use the
    # same frame.
    losses = []
    boxes_per_image = [box.size(0) for box in boxes]
    for xx, bb, gt_line, mid in zip(x.split(boxes_per_image, dim=0), boxes, gt_lines, matched_idx):
        pred_lines, _ = heatmaps_to_lines(xx, bb)
        gt_line_points = gt_line[mid]
        if len(pred_lines) == 0 or len(gt_line_points) == 0:
            continue

        # IoU of the lines' bounding boxes.
        ious = box_iou_pairwise(lines_to_boxes(pred_lines, img_size), lines_to_boxes(gt_line_points, img_size))
        iou_loss = 1.0 - ious
        # Length difference.
        length_diff = F.l1_loss(line_length(pred_lines), line_length(gt_line_points), reduction='none')
        # Direction difference.
        ang_loss = angle_loss_cosine(line_direction(pred_lines), line_direction(gt_line_points))

        # Min-max normalize each term per image before weighting.
        total = (
            alpha * normalize_tensor(iou_loss)
            + beta * normalize_tensor(length_diff)
            + gamma * normalize_tensor(ang_loss)
        )
        losses.append(total)

    if not losses:
        return None
    return torch.mean(torch.cat(losses))


def point_inference(x, point_boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    """Decode point predictions image by image.

    ``x`` stacks the heatmaps of all images along dim 0 and is split by the
    number of boxes per image.

    Returns:
        Per-image lists: points ``[n_i, 1, 2]`` (x, y) and scores ``[n_i, 1]``.
    """
    points_probs = []
    points_scores = []
    boxes_per_image = [box.size(0) for box in point_boxes]
    for xx, bb in zip(x.split(boxes_per_image, dim=0), point_boxes):
        point_prob, point_scores = heatmaps_to_points(xx, bb)
        points_probs.append(point_prob.unsqueeze(1))
        points_scores.append(point_scores)
    return points_probs, points_scores


def line_inference(x, line_boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    """Decode line predictions image by image.

    ``x`` stacks the heatmaps of all images along dim 0 and is split by the
    number of boxes per image (e.g. Faster R-CNN detections).

    Returns:
        Per-image lists: lines ``[n_i, 2, 3]`` (x, y, 1) and scores ``[n_i, 2]``.
    """
    lines_probs = []
    lines_scores = []
    boxes_per_image = [box.size(0) for box in line_boxes]
    for xx, bb in zip(x.split(boxes_per_image, dim=0), line_boxes):
        line_prob, line_scores = heatmaps_to_lines(xx, bb)
        lines_probs.append(line_prob)
        lines_scores.append(line_scores)
    return lines_probs, lines_scores


def arc_inference(x, point_boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    """Decode arc masks image by image.

    ``x`` stacks the heatmaps of all images along dim 0 and is split by the
    number of boxes per image.

    Returns:
        Per-image lists: masks ``[n_i, 1, 1, H, W]`` and scores ``[n_i, 1]``.
    """
    points_probs = []
    points_scores = []
    boxes_per_image = [box.size(0) for box in point_boxes]
    for xx, bb in zip(x.split(boxes_per_image, dim=0), point_boxes):
        point_prob, point_scores = heatmaps_to_arc(xx, bb)
        points_probs.append(point_prob.unsqueeze(1))
        points_scores.append(point_scores)
    return points_probs, points_scores


def heatmaps_to_arc(maps, rois, threshold=0.1, output_size=(128, 128)):
    """Binarize channel 0 of ``maps`` inside each ROI into an arc mask.

    The ROI crop is resized to ``output_size``, peak-suppressed with
    :func:`non_maximum_suppression`, thresholded, and bilinearly resized back
    into place.

    Args:
        maps: ``[N, C, H, W]`` heatmaps; only channel 0 is used.
        rois: ``[N, 4]`` boxes ``(x1, y1, x2, y2)`` in map pixels. They are
            truncated to integers, clamped to ``[0, W-1]`` / ``[0, H-1]``, and
            used half-open (x2/y2 exclusive), so the last row/column of the map
            is never covered.
        threshold: binarization threshold applied after suppression.
        output_size: working resolution for the suppression step.

    Returns:
        masks: ``[N, 1, H, W]``, zero outside each ROI. Because the binary mask
            is resized back bilinearly, values lie in ``[0, 1]``.
        scores: ``[N, 1]`` sum of each ROI's mask values (0 for skipped ROIs).
    """
    N, _, H, W = maps.shape
    masks = torch.zeros((N, 1, H, W), dtype=torch.float32, device=maps.device)
    scores = torch.zeros((N, 1), dtype=torch.float32, device=maps.device)
    point_maps = maps[:, 0]  # take the first channel -> [N, H, W]
    logger.debug("==> heatmaps_to_arc: maps.shape = %s, rois.shape = %s", maps.shape, rois.shape)

    for i in range(N):
        x1, y1, x2, y2 = rois[i].long()
        x1 = x1.clamp(0, W - 1)
        x2 = x2.clamp(0, W - 1)
        y1 = y1.clamp(0, H - 1)
        y2 = y2.clamp(0, H - 1)
        if x2 <= x1 or y2 <= y1:
            logger.debug("  Skipped invalid ROI at index %d", i)
            continue

        roi_map = point_maps[i, y1:y2, x1:x2]  # [h, w], non-empty after the check above
        roi_map_resized = F.interpolate(
            roi_map[None, None], size=output_size, mode='bilinear', align_corners=False
        )  # [1, 1, *output_size]
        bin_mask = (non_maximum_suppression(roi_map_resized) > threshold).float()

        # Upsample back to the original ROI size.
        h = int((y2 - y1).item())
        w = int((x2 - x1).item())
        roi_mask = F.interpolate(bin_mask, size=(h, w), mode='bilinear', align_corners=False)[0, 0]
        # Index [0, 0] instead of squeeze(): squeeze() would also drop h or w
        # when either is 1 and break the slice assignment.
        masks[i, 0, y1:y2, x1:x2] = roi_mask
        scores[i] = roi_mask.sum()
        logger.debug("  [%d] roi mask shape: %s, sum: %s", i, roi_mask.shape, scores[i].item())

    logger.debug("==> Done. Total valid masks: %d / %d", int((scores > 0).sum().item()), N)
    return masks, scores