import torch
from matplotlib import pyplot as plt
import torch.nn.functional as F
from torch import nn
from torch.cuda import device

from utils.data_process.mask.show_mask import save_full_mask


class DiceLoss(nn.Module):
    def __init__(self, smooth=1.):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        probs = probs.view(-1)
        targets = targets.view(-1).float()
        intersection = (probs * targets).sum()
        dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
        return 1. - dice


bce_loss = nn.BCEWithLogitsLoss()
dice_loss = DiceLoss()


def combined_loss(preds, targets, alpha=0.5):
    bce = bce_loss(preds, targets)
    d = dice_loss(preds, targets)
    return alpha * bce + (1 - alpha) * d


def features_align(features, proposals, img_size):
    print(f'features_align features:{features.shape},proposals:{len(proposals)}')
    align_feat_list = []
    for feat, proposals_per_img in zip(features, proposals):
        print(f'feature_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
        if proposals_per_img.shape[0] > 0:
            feat = feat.unsqueeze(0)
            for proposal in proposals_per_img:
                align_feat = torch.zeros_like(feat)
                # print(f'align_feat:{align_feat.shape}')
                x1, y1, x2, y2 = map(lambda v: int(v.item()), proposal)
                # Copy the region inside each proposal box into the matching position of align_feat
                align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]
                align_feat_list.append(align_feat)
    # print(f'align_feat_list:{align_feat_list}')
    if len(align_feat_list) > 0:
        feats_tensor = torch.cat(align_feat_list)
        print(f'align features :{feats_tensor.shape}')
    else:
        feats_tensor = None
    return feats_tensor


def normalize_tensor(t):
    return (t - t.min()) / (t.max() - t.min() + 1e-6)


def line_length(lines):
    """
    Compute the length of each line segment.
    lines: [N, 2, 2], N segments, each given by two endpoints
    Returns: [N]
    """
    return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)


def line_direction(lines):
    """
    Compute the unit direction vector of each line segment.
    lines: [N, 2, 2]
    Returns: [N, 2] unit direction vectors
    """
    vec = lines[:, 1] - lines[:, 0]
    return F.normalize(vec, dim=-1)


def angle_loss_cosine(pred_dir, gt_dir):
    """
    Measure the direction difference with cosine similarity.
    pred_dir: [N, 2]
    gt_dir: [N, 2]
    Returns: [N]
    """
    cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
    return 1.0 - cos_sim  # torch.acos(cos_sim) / pi would also work
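
# Illustrative usage sketch (not called anywhere in this module): how the
# BCE + Dice `combined_loss` above is expected to be applied to raw mask logits.
# The tensor shapes below are assumptions chosen only for the example.
def _example_combined_loss_usage():
    logits = torch.randn(2, 128, 128)                   # raw, unnormalized mask logits
    targets = (torch.rand(2, 128, 128) > 0.5).float()   # binary ground-truth masks
    return combined_loss(logits, targets, alpha=0.5)    # 0.5 * BCE + 0.5 * Dice
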
def single_point_to_heatmap(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tensor
    print(f'rois:{rois.shape}')
    print(f'heatmap_size:{heatmap_size}')
    print(f'keypoints.shape:{keypoints.shape}')
    # batch_size, num_keypoints, _ = keypoints.shape
    x = keypoints[..., 0].unsqueeze(1)
    y = keypoints[..., 1].unsqueeze(1)
    gs = generate_gaussian_heatmaps(x, y, num_points=1, heatmap_size=heatmap_size, sigma=1.0)
    # show_heatmap(gs[0],'target')
    all_roi_heatmap = []
    for roi, heatmap in zip(rois, gs):
        # show_heatmap(heatmap, 'target')
        # print(f'heatmap:{heatmap.shape}')
        heatmap = heatmap.unsqueeze(0)
        x1, y1, x2, y2 = map(int, roi)
        roi_heatmap = torch.zeros_like(heatmap)
        roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
        # show_heatmap(roi_heatmap[0],'roi_heatmap')
        all_roi_heatmap.append(roi_heatmap)
    all_roi_heatmap = torch.cat(all_roi_heatmap)
    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
    return all_roi_heatmap


def arc_inference1(arc_equation, x, arc_boxes, th):
    # type: (Tensor, Tensor, List[Tensor], float) -> Tuple[List[Tensor], List[Tensor]]
    points_probs = []
    points_scores = []
    print(f'arc_boxes:{len(arc_boxes)}')
    boxes_per_image = [box.size(0) for box in arc_boxes]
    print(f'arc boxes_per_image:{boxes_per_image}')
    x2 = x.split(boxes_per_image, dim=0)
    arc7 = arc_equation.split(boxes_per_image, dim=0)
    # print(f'arc7:{arc7}')
    for xx, bb in zip(x2, arc_boxes):
        point_prob, point_scores = heatmaps_to_arc(xx, bb)
        points_probs.append(point_prob.unsqueeze(1))
        points_scores.append(point_scores)
    return arc7, points_scores


def points_to_heatmap(keypoints, rois, num_points=2, heatmap_size=(512, 512)):
    # type: (Tensor, Tensor, int, int) -> Tensor
    print(f'rois:{rois.shape}')
    print(f'heatmap_size:{heatmap_size}')
    print(f'keypoints.shape:{keypoints.shape}')
    # batch_size, num_keypoints, _ = keypoints.shape
    x = keypoints[..., 0].unsqueeze(1)
    y = keypoints[..., 1].unsqueeze(1)
    gs = generate_gaussian_heatmaps(x, y, num_points=num_points, heatmap_size=heatmap_size, sigma=2.0)
    # show_heatmap(gs[0],'target')
    all_roi_heatmap = []
    for roi, heatmap in zip(rois, gs):
        # show_heatmap(heatmap, 'target')
        # print(f'heatmap:{heatmap.shape}')
        heatmap = heatmap.unsqueeze(0)
        x1, y1, x2, y2 = map(int, roi)
        roi_heatmap = torch.zeros_like(heatmap)
        roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
        # show_heatmap(roi_heatmap[0],'roi_heatmap')
        all_roi_heatmap.append(roi_heatmap)
    all_roi_heatmap = torch.cat(all_roi_heatmap)
    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
    return all_roi_heatmap


def line_points_to_heatmap(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tensor
    print(f'rois:{rois.shape}')
    print(f'heatmap_size:{heatmap_size}')
    print(f'keypoints.shape:{keypoints.shape}')
    # batch_size, num_keypoints, _ = keypoints.shape
    x = keypoints[..., 0]
    y = keypoints[..., 1]
    gs = generate_gaussian_heatmaps(x, y, heatmap_size, num_points=2, sigma=1.0)
    # show_heatmap(gs[0],'target')
    all_roi_heatmap = []
    for roi, heatmap in zip(rois, gs):
        # print(f'heatmap:{heatmap.shape}')
        # show_heatmap(heatmap,'target')
        heatmap = heatmap.unsqueeze(0)
        x1, y1, x2, y2 = map(int, roi)
        roi_heatmap = torch.zeros_like(heatmap)
        roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
        # show_heatmap(roi_heatmap[0],'roi_heatmap')
        all_roi_heatmap.append(roi_heatmap)
    if len(all_roi_heatmap) > 0:
        all_roi_heatmap = torch.cat(all_roi_heatmap)
        print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
    else:
        all_roi_heatmap = None
    return all_roi_heatmap
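
# Illustrative sketch of building ROI-restricted Gaussian targets with
# `points_to_heatmap` above. The keypoint/ROI values are assumptions for the
# example, and because `generate_gaussian_heatmaps` (defined further below)
# defaults to device='cuda', this sketch assumes a CUDA-capable machine.
def _example_points_to_heatmap_usage():
    keypoints = torch.tensor([[[40.0, 60.0], [80.0, 90.0]]], device='cuda')  # [1, 2, 2] (x, y) pairs
    rois = torch.tensor([[30.0, 50.0, 100.0, 110.0]], device='cuda')         # [1, 4] boxes
    targets = points_to_heatmap(keypoints, rois, num_points=2, heatmap_size=128)
    return targets  # [1, 128, 128], Gaussians kept only inside the ROI
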
"""
Adapted point-to-heatmap conversion for the original structure; intended for the
variant that uses roi_pool.
"""
def line_points_to_heatmap_(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
    print(f'rois:{rois.shape}')
    print(f'heatmap_size:{heatmap_size}')
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]
    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])

    offset_x = offset_x[:, None]
    offset_y = offset_y[:, None]
    scale_x = scale_x[:, None]
    scale_y = scale_y[:, None]

    print(f'keypoints.shape:{keypoints.shape}')
    # batch_size, num_keypoints, _ = keypoints.shape
    x = keypoints[..., 0]
    y = keypoints[..., 1]

    # gs = generate_gaussian_heatmaps(x, y, 512, 1.0)
    # print(f'gs_heatmap shape:{gs.shape}')
    # # show_heatmap(gs[0],'target')

    x_boundary_inds = x == rois[:, 2][:, None]
    y_boundary_inds = y == rois[:, 3][:, None]

    x = (x - offset_x) * scale_x
    x = x.floor().long()
    y = (y - offset_y) * scale_y
    y = y.floor().long()

    x[x_boundary_inds] = heatmap_size - 1
    y[y_boundary_inds] = heatmap_size - 1
    # print(f'heatmaps x:{x}')
    # print(f'heatmaps y:{y}')

    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
    vis = keypoints[..., 2] > 0
    valid = (valid_loc & vis).long()

    gs_heatmap = generate_gaussian_heatmaps(x, y, heatmap_size, sigma=1.0)
    # show_heatmap(gs_heatmap[0], 'feature')
    # print(f'gs_heatmap:{gs_heatmap.shape}')
    #
    # lin_ind = y * heatmap_size + x
    # print(f'lin_ind:{lin_ind.shape}')
    # heatmaps = lin_ind * valid
    return gs_heatmap


def generate_gaussian_heatmaps(xs, ys, heatmap_size, num_points=2, sigma=2.0, device='cuda'):
    """
    Generate and merge Gaussian heatmaps for a group of points.

    Args:
        xs (Tensor): x coordinates of all points, shape (N, 2)
        ys (Tensor): y coordinates of all points, shape (N, 2)
        heatmap_size (int): heatmap size, H = W
        sigma (float): standard deviation of the Gaussian kernel
        device (str): device type ('cpu' or 'cuda')

    Returns:
        Tensor: merged heatmaps of shape (N, H, W)
    """
    assert xs.shape == ys.shape, "x and y must have the same shape"
    print(f'xs:{xs.shape}')
    xs = xs.squeeze(1)
    ys = ys.squeeze(1)
    print(f'xs1:{xs.shape}')
    N = xs.shape[0]
    print(f'N:{N},num_points:{num_points}')

    # Build the coordinate grid
    grid_y, grid_x = torch.meshgrid(
        torch.arange(heatmap_size, device=device),
        torch.arange(heatmap_size, device=device),
        indexing='ij'
    )
    # print(f'heatmap_size:{heatmap_size}')

    # Initialize the output heatmaps
    combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)

    for i in range(N):
        heatmap = torch.zeros((heatmap_size, heatmap_size), device=device)
        for j in range(num_points):
            mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
            mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
            # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
            # Squared distance to the point
            dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
            # Gaussian response
            heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
            heatmap += heatmap1
        # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
        # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
        #
        # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
        # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
        # heatmap = heatmap1 + heatmap2

        # Accumulate the current heatmap into the result
        combined_heatmap[i] = heatmap
    return combined_heatmap
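
# Minimal CPU sketch of `generate_gaussian_heatmaps` above: one sample with two
# points rendered as Gaussians on a 64x64 grid. The coordinates are assumptions
# for the example only.
def _example_generate_gaussian_heatmaps():
    xs = torch.tensor([[10.0, 50.0]]).unsqueeze(1)  # [1, 1, 2]; squeezed back to [1, 2] inside
    ys = torch.tensor([[20.0, 40.0]]).unsqueeze(1)  # [1, 1, 2]
    heatmaps = generate_gaussian_heatmaps(xs, ys, heatmap_size=64, num_points=2, sigma=2.0, device='cpu')
    return heatmaps  # [1, 64, 64] with peaks near (x=10, y=20) and (x=50, y=40)
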
def generate_mask_gaussian_heatmaps(xs, ys, heatmap_size, num_points=2, sigma=2.0, device='cuda'):
    """
    Generate and merge Gaussian heatmaps for a group of points.

    Args:
        xs (Tensor): x coordinates of all points, shape (N, 2)
        ys (Tensor): y coordinates of all points, shape (N, 2)
        heatmap_size (int): heatmap size, H = W
        sigma (float): standard deviation of the Gaussian kernel
        device (str): device type ('cpu' or 'cuda')

    Returns:
        Tensor: merged heatmaps of shape (N, H, W)
    """
    assert xs.shape == ys.shape, "x and y must have the same shape"
    print(f'xs:{xs.shape}')
    xs = xs.squeeze(1)
    ys = ys.squeeze(1)
    print(f'xs1:{xs.shape}')
    N = xs.shape[0]
    print(f'N:{N},num_points:{num_points}')

    # Build the coordinate grid
    grid_y, grid_x = torch.meshgrid(
        torch.arange(heatmap_size, device=device),
        torch.arange(heatmap_size, device=device),
        indexing='ij'
    )
    # print(f'heatmap_size:{heatmap_size}')

    # Initialize the output heatmaps
    combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)

    for i in range(N):
        heatmap = torch.zeros((heatmap_size, heatmap_size), device=device)
        for j in range(num_points):
            mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
            mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
            # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
            # Squared distance to the point
            dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
            # Gaussian response
            heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
            heatmap += heatmap1
        # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
        # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
        #
        # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
        # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
        # heatmap = heatmap1 + heatmap2

        # Accumulate the current heatmap into the result
        combined_heatmap[i] = heatmap
    return combined_heatmap


def non_maximum_suppression(a):
    ap = F.max_pool2d(a, 3, stride=1, padding=1)
    mask = (a == ap).float().clamp(min=0.0)
    return a * mask


def heatmaps_to_points(maps, rois, num_points=2):
    point_preds = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
    point_end_scores = torch.zeros((len(rois), 1), dtype=torch.float32, device=maps.device)
    print(f'heatmaps_to_lines:{maps.shape}')
    point_maps = maps[:, 0]
    print(f'point_map:{point_maps.shape}')
    for i in range(len(rois)):
        point_roi_map = point_maps[i].unsqueeze(0)
        print(f'point_roi_map:{point_roi_map.shape}')
        # roi_map_probs = scores_to_probs(roi_map.copy())
        w = point_roi_map.shape[2]
        flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
        point_score, point_index = torch.topk(flatten_point_roi_map, k=num_points)
        print(f'point index:{point_index}')
        # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
        point_x = point_index % w
        point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
        point_preds[i, 0] = point_x
        point_preds[i, 1] = point_y
        point_end_scores[i, :] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
    return point_preds, point_end_scores
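
# Small CPU sketch of the peak-style decoding used below: `non_maximum_suppression`
# keeps only values that equal their 3x3 local maximum, so the argmax of the result
# lands on a peak. The synthetic heatmap here is an assumption for the example.
def _example_nms_peak_decode():
    heatmap = torch.zeros(1, 32, 32)
    heatmap[0, 10, 12] = 1.0                  # plant a single peak at (x=12, y=10)
    peaks = non_maximum_suppression(heatmap)  # same shape, non-peaks suppressed
    idx = peaks.reshape(-1).argmax()
    y, x = divmod(idx.item(), 32)
    return x, y                               # -> (12, 10)
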
# Split into 4 quadrants
def find_max_heat_point_in_each_part(feature_map, box):
    """
    On the given feature map, shift the box center up by 3 and right by 3 to get a
    new center, split the map into 4 quadrants around it, and find the hottest
    point in each quadrant.

    Args:
        feature_map (torch.Tensor): feature map of shape [C, H, W]
        box (torch.Tensor): bounding box of shape [4], [x_min, y_min, x_max, y_max]

    Returns:
        list: hottest location per region, e.g. [(x1, y1), ..., (x4, y4)]
    """
    device = feature_map.device
    C, H, W = feature_map.shape

    # Box center (cx, cy)
    cx = (box[0] + box[2]) // 2
    cy = (box[1] + box[3]) // 2

    # Shift the center
    new_cx = min(max(cx + 3, 0), W - 1)  # 3 to the right
    new_cy = min(max(cy - 3, 0), H - 1)  # 3 up

    # Coordinate grid
    y_coords, x_coords = torch.meshgrid(
        torch.arange(H, device=device),
        torch.arange(W, device=device),
        indexing='ij'
    )

    # Four quadrants
    mask_q1 = (y_coords < new_cy) & (x_coords < new_cx)    # top-left
    mask_q2 = (y_coords < new_cy) & (x_coords >= new_cx)   # top-right
    mask_q3 = (y_coords >= new_cy) & (x_coords < new_cx)   # bottom-left
    mask_q4 = (y_coords >= new_cy) & (x_coords >= new_cx)  # bottom-right

    # def process_region(mask):
    #     region = feature_map[:, :, mask].squeeze()
    #     if len(region.shape) == 0:  # skip empty regions
    #         return None, None
    #     # find the hottest point and its location
    #     (y, x), heat_val = non_maximum_suppression(region[0])
    #     # convert relative coordinates back to global coordinates
    #     y_global = y + torch.where(mask)[0].min().item()
    #     x_global = x + torch.where(mask)[1].min().item()
    #     return (y_global, x_global), heat_val
    #
    # results = []
    # for mask in [mask_q1, mask_q2, mask_q3, mask_q4]:
    #     point, heat_val = process_region(mask)
    #     if point is not None:
    #         # results.append((point[0], point[1], heat_val))
    #         results.append((point[0], point[1]))
    #     else:
    #         results.append(None)

    masks = [mask_q1, mask_q2, mask_q3, mask_q4]
    results = []
    # Use the first channel as the heatmap
    heatmap = feature_map[0]  # [H, W]

    def process_region(mask):
        # Apply the mask and keep only this region
        masked_heatmap = heatmap.clone()  # copy to avoid modifying the original
        masked_heatmap[~mask] = 0         # zero out everything outside the region

        def non_maximum_suppression_2d(heatmap, kernel_size=3):
            """
            Non-maximum suppression on a 2D heatmap, keeping local maxima only.

            Args:
                heatmap (torch.Tensor): [H, W] input heatmap
                kernel_size (int): pooling window used to test for local maxima

            Returns:
                torch.Tensor: boolean mask of the same shape, True at local maxima
            """
            pad = (kernel_size - 1) // 2
            max_pool = torch.nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)
            maxima = max_pool(heatmap.unsqueeze(0)).squeeze(0)
            # local maximum and strictly positive
            peaks = (heatmap == maxima) & (heatmap > 0)
            return peaks

        # 1. NMS to get candidate local maxima
        nms_mask = non_maximum_suppression_2d(masked_heatmap, kernel_size=3)  # [H, W] bool
        candidate_peaks = masked_heatmap * nms_mask.float()  # keep only the peaks left by NMS

        # 2. Pick the strongest candidate
        if candidate_peaks.max() <= 0:
            return None

        # Location of the maximum
        max_val, max_idx = torch.max(candidate_peaks.view(-1), dim=0)
        y, x = divmod(max_idx.item(), W)
        return (x, y)

    for mask in masks:
        point = process_region(mask)
        results.append(point)

    return results


def non_maximum_suppression_2d(heatmap, kernel_size=3):
    pad = (kernel_size - 1) // 2
    max_pool = torch.nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)
    maxima = max_pool(heatmap.unsqueeze(0)).squeeze(0)
    peaks = (heatmap == maxima) & (heatmap > 0)
    return peaks


def find_max_heat_point_in_edge_centers(feature_map, box):
    device = feature_map.device
    C, H, W = feature_map.shape

    # Box center
    cx = (box[0] + box[2]) / 2
    cy = (box[1] + box[3]) / 2

    # 3x3-grid dividing lines derived from the box width and height
    box_width = box[2] - box[0]
    box_height = box[3] - box[1]
    x_left = cx - box_width / 6
    x_right = cx + box_width / 6
    y_top = cy - box_height / 6
    y_bottom = cy + box_height / 6

    # Coordinate grid
    y_coords, x_coords = torch.meshgrid(
        torch.arange(H, device=device),
        torch.arange(W, device=device),
        indexing='ij'
    )

    # Masks for the four "edge-center" regions
    mask1 = (x_coords < x_left) & (y_coords < y_top)
    mask_top_middle = (x_coords >= x_left) & (x_coords < x_right) & (y_coords < y_top)
    mask3 = (x_coords >= x_right) & (y_coords < y_top)
    mask_left_middle = (x_coords < x_left) & (y_coords >= y_top) & (y_coords < y_bottom)
    mask_right_middle = (x_coords >= x_right) & (y_coords >= y_top) & (y_coords < y_bottom)
    mask4 = (x_coords < x_left) & (y_coords >= y_bottom)
    mask_bottom_middle = (x_coords >= x_left) & (x_coords < x_right) & (y_coords >= y_bottom)
    mask_right_bottom = (x_coords >= x_right) & (y_coords >= y_bottom)

    # masks = [
    #     # mask1,
    #     mask_top_middle,
    #     # mask3,
    #     mask_left_middle,
    #     mask_right_middle,
    #     # mask4,
    #     mask_bottom_middle,
    #     mask_right_bottom
    # ]
    masks = [
        mask_top_middle,
        mask_right_middle,
        mask_bottom_middle,
        mask_left_middle
    ]

    # Use the first channel as the heatmap
    heatmap = feature_map[0]  # [H, W]
    results = []
    for mask in masks:
        masked_heatmap = heatmap.clone()
        masked_heatmap[~mask] = 0  # zero out everything outside the target region

        # # NMS
        # nms_mask = non_maximum_suppression_2d(masked_heatmap, kernel_size=3)
        # candidate_peaks = masked_heatmap * nms_mask.float()
        #
        # if candidate_peaks.max() <= 0:
        #     results.append(None)
        #     continue
        #
        # # location of the maximum
        # max_val, max_idx = torch.max(candidate_peaks.view(-1), dim=0)
        # y, x = divmod(max_idx.item(), W)

        flatten_point_roi_map = masked_heatmap.reshape(1, -1)
        point_score, point_index = torch.topk(flatten_point_roi_map, k=1)
        point_x = point_index % W
        point_y = torch.div(point_index - point_x, W, rounding_mode="floor")
        results.append((point_x, point_y))

    return results  # [(x_top, y_top), (x_right, y_right), (x_bottom, y_bottom), (x_left, y_left)]
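
# Illustrative sketch of `find_max_heat_point_in_edge_centers` above: plant four
# peaks near the edge centers of a box and recover their (x, y) locations. The
# map size and box values are assumptions for the example only.
def _example_edge_center_peaks():
    fmap = torch.zeros(1, 64, 64)
    fmap[0, 5, 32] = 1.0    # top-middle peak
    fmap[0, 32, 58] = 1.0   # right-middle peak
    fmap[0, 58, 32] = 1.0   # bottom-middle peak
    fmap[0, 32, 5] = 1.0    # left-middle peak
    box = torch.tensor([0.0, 0.0, 63.0, 63.0])
    # Returns a list of 4 (x, y) tensor pairs in the order top, right, bottom, left.
    return find_max_heat_point_in_edge_centers(fmap, box)
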
def heatmaps_to_circle_points(maps, rois, num_points=2):
    point_preds = torch.zeros((len(rois), 4, 2), dtype=torch.float32, device=maps.device)
    point_end_scores = torch.zeros((len(rois), 4, 1), dtype=torch.float32, device=maps.device)
    print(f'rois in heatmaps_to_circle_points:{type(rois), rois.shape}')
    # print(f'heatmaps_to_lines:{maps.shape}')
    point_maps = maps[:, 0]
    print(f'point_map:{point_maps.shape}')
    for i in range(len(rois)):
        point_roi_map = point_maps[i].unsqueeze(0)
        print(f'point_roi_map:{point_roi_map.shape}')
        # roi_map_probs = scores_to_probs(roi_map.copy())
        # w = point_roi_map.shape[2]
        # flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
        # print(f'non_maximum_suppression :{non_maximum_suppression(point_roi_map).shape}')
        # point_score, point_index = torch.topk(flatten_point_roi_map, k=num_points)
        # print(f'point index:{point_index}')
        # point_x = point_index % w
        # point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
        # print(f'point_x:{point_x}, point_y:{point_y}')
        # point_preds[i, :, 0] = point_x
        # point_preds[i, :, 1] = point_y
        roi1 = rois[i]
        result_points = find_max_heat_point_in_edge_centers(non_maximum_suppression(point_roi_map), roi1)
        point_preds[i, :] = torch.tensor(result_points)
        point_x = [point[0] for point in result_points]
        point_y = [point[1] for point in result_points]
        point_end_scores[i, :, 0] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
    return point_preds, point_end_scores


def heatmaps_to_lines(maps, rois):
    line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
    line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
    line_maps = maps[:, 1]
    # line_maps = maps.squeeze(1)
    for i in range(len(rois)):
        line_roi_map = line_maps[i].unsqueeze(0)
        print(f'line_roi_map:{line_roi_map.shape}')
        # roi_map_probs = scores_to_probs(roi_map.copy())
        w = line_roi_map.shape[1]
        flatten_line_roi_map = non_maximum_suppression(line_roi_map).reshape(1, -1)
        line_score, line_index = torch.topk(flatten_line_roi_map, k=2)
        print(f'line index:{line_index}')
        # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
        pos = line_index
        line_x = pos % w
        line_y = torch.div(pos - line_x, w, rounding_mode="floor")
        line_preds[i, 0, :] = line_x
        line_preds[i, 1, :] = line_y
        line_preds[i, 2, :] = 1
        line_end_scores[i, :] = line_roi_map[torch.arange(1, device=line_roi_map.device), line_y, line_x]
    return line_preds.permute(0, 2, 1), line_end_scores


# Helper for displaying heatmaps
def show_heatmap(heatmap, title="Heatmap"):
    """
    Display a heatmap with matplotlib.

    Args:
        heatmap (Tensor): heatmap tensor to display
        title (str): figure title
    """
    # If the tensor lives on the GPU, move it to the CPU first and convert to numpy
    if heatmap.is_cuda:
        heatmap = heatmap.cpu().numpy()
    else:
        heatmap = heatmap.numpy()
    plt.imshow(heatmap, cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.title(title)
    plt.show()

def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    N, K, H, W = line_logits.shape
    len_proposals = len(proposals)
    print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals},line_matched_idxs:{line_matched_idxs}')
    if H != W:
        raise ValueError(
            f"line_logits height and width (last two elements of shape) should be equal. "
            f"Instead got H = {H} and W = {W}"
        )
    discretization_size = H
    heatmaps = []
    gs_heatmaps = []
    valid = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
        print(f'line_proposals_per_image:{proposals_per_image.shape}')
        print(f'gt_lines:{gt_lines}')
        if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
            kp = gt_kp_in_image[midx]
            gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
            gs_heatmaps.append(gs_heatmaps_per_img)
            # print(f'heatmaps_per_image:{heatmaps_per_image.shape}')
            # heatmaps.append(heatmaps_per_image.view(-1))
            # valid.append(valid_per_image.view(-1))

    # line_targets = torch.cat(heatmaps, dim=0)
    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
    # print(f'line_targets:{line_targets.shape},{line_targets}')
    # valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
    # valid = torch.where(valid)[0]
    # print(f' line_targets[valid]:{line_targets[valid]}')

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    # if line_targets.numel() == 0 or len(valid) == 0:
    #     return line_logits.sum() * 0

    # line_logits = line_logits.view(N * K, H * W)
    # print(f'line_logits[valid]:{line_logits[valid].shape}')
    print(f'loss1 line_logits:{line_logits.shape}')
    line_logits = line_logits[:, 1, :, :]
    # line_logits = line_logits.squeeze(1)
    print(f'loss2 line_logits:{line_logits.shape}')
    # line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
    line_loss = F.cross_entropy(line_logits, gs_heatmaps)
    return line_loss

def compute_ins_loss(feature_logits, proposals, gt_, pos_matched_idxs):
    print(f'compute_arc_loss:{feature_logits.shape}')
    N, K, H, W = feature_logits.shape
    len_proposals = len(proposals)

    empty_count = 0
    non_empty_count = 0
    for prop in proposals:
        if prop.shape[0] == 0:
            empty_count += 1
        else:
            non_empty_count += 1
    print(f"Empty proposals count: {empty_count}")
    print(f"Non-empty proposals count: {non_empty_count}")

    print(f'start to compute_point_loss')
    print(f'compute_point_loss line_logits.shape:{feature_logits.shape},len_proposals:{len_proposals}')
    if H != W:
        raise ValueError(
            f"line_logits height and width (last two elements of shape) should be equal. "
            f"Instead got H = {H} and W = {W}"
        )
    discretization_size = H
    gs_heatmaps = []
    # print(f'point_matched_idxs:{point_matched_idxs}')
    print(f'gt_masks:{gt_[0].shape}')
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
        # [
        #     (Tensor(38, 4), Tensor(1, 57, 2), Tensor(38, 1)),
        #     (Tensor(65, 4), Tensor(1, 74, 2), Tensor(65, 1))
        # ]
        print(f'proposals_per_image:{proposals_per_image.shape}')
        kp = gt_kp_in_image[midx]
        t_h, t_w = kp.shape[-2:]
        print(f't_h:{t_h}, t_w:{t_w}')
        print(f'gt_kp_in_image:{gt_kp_in_image.shape}')
        if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
            gs_heatmaps_per_img = align_masks(kp, proposals_per_image, discretization_size)
            gs_heatmaps.append(gs_heatmaps_per_img)

    if len(gs_heatmaps) > 0:
        gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
        print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.shape}')
        line_logits = feature_logits.squeeze(1)
        print(f'ins shape:{line_logits.shape}')
        # line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
        # line_loss = F.cross_entropy(line_logits, gs_heatmaps)
        # save_full_mask(line_logits, "line_logits", out_dir=r"/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_loss")
        # save_full_mask(gs_heatmaps, "gs_heatmaps", out_dir=r"/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_loss")
        line_loss = combined_loss(line_logits, gs_heatmaps)
    else:
        # Fallback when no image contributed positive proposals
        line_loss = 100
        print("d")
    return line_loss


def align_masks(keypoints, rois, heatmap_size):
    print(f'rois:{rois.shape}')
    print(f'heatmap_size:{heatmap_size}')
    print(f'keypoints.shape:{keypoints.shape}')
    # t_h, t_w = keypoints.shape[-2:]
    # scale = heatmap_size / t_w
    # print(f'scale:{scale}')
    # x = keypoints[..., 0] * scale
    # y = keypoints[..., 1] * scale
    #
    # x = x.unsqueeze(1)
    # y = y.unsqueeze(1)
    #
    # num_points = x.shape[2]
    # print(f'num_points:{num_points}')

    # plt.imshow(keypoints[0].cpu())
    # plt.show()
    mask_4d = keypoints.unsqueeze(1).float()
    resized_mask = F.interpolate(
        mask_4d,
        size=(heatmap_size, heatmap_size),
        mode='bilinear',
        align_corners=False
    ).squeeze(1)  # [B, heatmap_size, heatmap_size]
    # plt.imshow(resized_mask[0].cpu())
    # plt.show()
    print(f'resized_mask:{resized_mask.shape}')
    return resized_mask
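
# Small CPU sketch of `align_masks` above: ground-truth masks are resized to the
# head's square heatmap resolution with bilinear interpolation. The mask shapes
# and dummy ROIs are assumptions for the example only.
def _example_align_masks():
    gt_masks = (torch.rand(2, 120, 160) > 0.5).float()   # two binary masks of arbitrary size
    resized = align_masks(gt_masks, rois=torch.zeros(2, 4), heatmap_size=128)
    return resized.shape                                  # [2, 128, 128]
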
def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    N, K, H, W = line_logits.shape
    len_proposals = len(proposals)

    empty_count = 0
    non_empty_count = 0
    for prop in proposals:
        if prop.shape[0] == 0:
            empty_count += 1
        else:
            non_empty_count += 1
    print(f"Empty proposals count: {empty_count}")
    print(f"Non-empty proposals count: {non_empty_count}")

    print(f'start to compute_point_loss')
    print(f'compute_point_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
    if H != W:
        raise ValueError(
            f"line_logits height and width (last two elements of shape) should be equal. "
            f"Instead got H = {H} and W = {W}"
        )
    discretization_size = H
    gs_heatmaps = []
    # print(f'point_matched_idxs:{point_matched_idxs}')
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_points, point_matched_idxs):
        print(f'proposals_per_image:{proposals_per_image.shape}')
        kp = gt_kp_in_image[midx]
        # print(f'gt_kp_in_image:{gt_kp_in_image}')
        gs_heatmaps_per_img = single_point_to_heatmap(kp, proposals_per_image, discretization_size)
        gs_heatmaps.append(gs_heatmaps_per_img)

    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
    line_logits = line_logits[:, 0]
    print(f'single_point_logits:{line_logits.shape}')
    line_loss = F.cross_entropy(line_logits, gs_heatmaps)
    return line_loss


def compute_circle_loss(circle_logits, proposals, gt_circles, circle_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    N, K, H, W = circle_logits.shape
    len_proposals = len(proposals)

    empty_count = 0
    non_empty_count = 0
    for prop in proposals:
        if prop.shape[0] == 0:
            empty_count += 1
        else:
            non_empty_count += 1
    print(f"Empty proposals count: {empty_count}")
    print(f"Non-empty proposals count: {non_empty_count}")

    print(f'start to compute_circle_loss')
    print(f'compute_circle_loss circle_logits.shape:{circle_logits.shape},len_proposals:{len_proposals}')
    if H != W:
        raise ValueError(
            f"line_logits height and width (last two elements of shape) should be equal. "
            f"Instead got H = {H} and W = {W}"
        )
    discretization_size = H
    gs_heatmaps = []
    # print(f'point_matched_idxs:{point_matched_idxs}')
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_circles, circle_matched_idxs):
        print(f'proposals_per_image:{proposals_per_image.shape}')
        kp = gt_kp_in_image[midx]
        # print(f'gt_kp_in_image:{gt_kp_in_image}')
        gs_heatmaps_per_img = points_to_heatmap(kp, proposals_per_image, num_points=4, heatmap_size=discretization_size)
        gs_heatmaps.append(gs_heatmaps_per_img)

    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{circle_logits.squeeze(1).shape}')
    circle_logits = circle_logits[:, 0]
    print(f'circle_logits:{circle_logits.shape}')
    circle_loss = F.cross_entropy(circle_logits, gs_heatmaps)
    return circle_loss


def lines_to_boxes(lines, img_size=511):
    """
    Input:
        lines: Tensor of shape (N, 2, 2), N line segments with two (x, y) endpoints each
        img_size: int, image size used to clamp the box boundaries
    Output:
        boxes: Tensor of shape (N, 4), N bounding boxes [x_min, y_min, x_max, y_max]
    """
    # The two endpoints of every segment
    p1 = lines[:, 0]  # (N, 2)
    p2 = lines[:, 1]  # (N, 2)

    # x and y coordinates of each segment
    x_coords = torch.stack([p1[:, 0], p2[:, 0]], dim=1)  # (N, 2)
    y_coords = torch.stack([p1[:, 1], p2[:, 1]], dim=1)  # (N, 2)

    # Bounding-box extents
    x_min = x_coords.min(dim=1).values
    y_min = y_coords.min(dim=1).values
    x_max = x_coords.max(dim=1).values
    y_max = y_coords.max(dim=1).values

    # Pad by one pixel and clamp to the image
    x_min = (x_min - 1).clamp(min=0, max=img_size)
    y_min = (y_min - 1).clamp(min=0, max=img_size)
    x_max = (x_max + 1).clamp(min=0, max=img_size)
    y_max = (y_max + 1).clamp(min=0, max=img_size)

    # Assemble the boxes
    boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1)  # (N, 4)
    return boxes
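
# Illustrative sketch of `lines_to_boxes` above: two line segments converted to
# slightly padded, clamped bounding boxes. The coordinates are assumptions for
# the example only.
def _example_lines_to_boxes():
    lines = torch.tensor([
        [[10.0, 20.0], [40.0, 60.0]],
        [[0.0, 0.0], [5.0, 5.0]],
    ])                                  # [2, 2, 2]: two segments, two (x, y) endpoints each
    boxes = lines_to_boxes(lines, img_size=511)
    return boxes                        # [[9, 19, 41, 61], [0, 0, 6, 6]]
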
def box_iou_pairwise(box1, box2):
    """
    Input:
        box1: shape (N, 4)
        box2: shape (M, 4)
    Output:
        ious: shape (min(N, M),), computed only for pairs with i = j
    """
    N = min(len(box1), len(box2))
    lt = torch.max(box1[:N, :2], box2[:N, :2])  # top-left corner
    rb = torch.min(box1[:N, 2:], box2[:N, 2:])  # bottom-right corner
    wh = (rb - lt).clamp(min=0)                 # width and height
    inter_area = wh[:, 0] * wh[:, 1]            # intersection area
    area1 = (box1[:N, 2] - box1[:N, 0]) * (box1[:N, 3] - box1[:N, 1])
    area2 = (box2[:N, 2] - box2[:N, 0]) * (box2[:N, 3] - box2[:N, 1])
    union_area = area1 + area2 - inter_area
    ious = inter_area / (union_area + 1e-6)
    return ious


def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta=1.0, gamma=1.0):
    """
    Args:
        x: [N,1,H,W] heatmaps
        boxes: [N,4] box coordinates
        gt_lines: [N,2,3] ground-truth segments (with visibility)
        matched_idx: matching indices
        img_size: image size
        alpha: weight of the IoU loss
        beta: weight of the length loss
        gamma: weight of the direction (angle) loss
    """
    losses = []
    boxes_per_image = [box.size(0) for box in boxes]
    x2 = x.split(boxes_per_image, dim=0)

    for xx, bb, gt_line, mid in zip(x2, boxes, gt_lines, matched_idx):
        p_prob, _ = heatmaps_to_lines(xx, bb)
        pred_lines = p_prob
        gt_line_points = gt_line[mid]

        if len(pred_lines) == 0 or len(gt_line_points) == 0:
            continue

        # IoU loss
        pred_boxes = lines_to_boxes(pred_lines, img_size)
        gt_boxes = lines_to_boxes(gt_line_points, img_size)
        ious = box_iou_pairwise(pred_boxes, gt_boxes)
        iou_loss = 1.0 - ious  # [N]

        # Length loss
        pred_len = line_length(pred_lines)
        gt_len = line_length(gt_line_points)
        length_diff = F.l1_loss(pred_len, gt_len, reduction='none')  # [N]

        # Direction (angle) loss
        pred_dir = line_direction(pred_lines)
        gt_dir = line_direction(gt_line_points)
        ang_loss = angle_loss_cosine(pred_dir, gt_dir)  # [N]

        # Normalize each term
        norm_iou = normalize_tensor(iou_loss)
        norm_len = normalize_tensor(length_diff)
        norm_ang = normalize_tensor(ang_loss)

        total = alpha * norm_iou + beta * norm_len + gamma * norm_ang
        losses.append(total)

    if not losses:
        return None
    return torch.mean(torch.cat(losses))


def point_inference(x, point_boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    points_probs = []
    points_scores = []
    boxes_per_image = [box.size(0) for box in point_boxes]
    x2 = x.split(boxes_per_image, dim=0)
    for xx, bb in zip(x2, point_boxes):
        point_prob, point_scores = heatmaps_to_points(xx, bb, num_points=1)
        points_probs.append(point_prob.unsqueeze(1))
        points_scores.append(point_scores)
    return points_probs, points_scores


def circle_inference(x, point_boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    points_probs = []
    points_scores = []
    boxes_per_image = [box.size(0) for box in point_boxes]
    x2 = x.split(boxes_per_image, dim=0)
    for xx, bb in zip(x2, point_boxes):
        point_prob, point_scores = heatmaps_to_circle_points(xx, bb, num_points=4)
        points_probs.append(point_prob.unsqueeze(1))
        points_scores.append(point_scores)
    return points_probs, points_scores


def line_inference(x, line_boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    lines_probs = []
    lines_scores = []
    boxes_per_image = [box.size(0) for box in line_boxes]
    x2 = x.split(boxes_per_image, dim=0)
    # x2: tuple of 2, x2[0]: [1, 3, 1024, 1024]
    # line_boxes: list of 2, each a [1, 4] Faster R-CNN box tensor
    for xx, bb in zip(x2, line_boxes):
        line_prob, line_scores = heatmaps_to_lines(xx, bb)
        lines_probs.append(line_prob)
        lines_scores.append(line_scores)
    return lines_probs, lines_scores
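
# Small sketch of the per-image splitting pattern shared by the *_inference
# helpers above: a stacked [total_rois, C, H, W] tensor is split back into
# per-image chunks using the number of boxes per image. Shapes are assumptions
# for the example only.
def _example_split_per_image():
    x = torch.randn(5, 3, 64, 64)                    # heatmaps for 5 ROIs across 2 images
    boxes = [torch.zeros(2, 4), torch.zeros(3, 4)]   # 2 ROIs in image 0, 3 in image 1
    boxes_per_image = [b.size(0) for b in boxes]
    chunks = x.split(boxes_per_image, dim=0)         # shapes [2, 3, 64, 64] and [3, 3, 64, 64]
    return [c.shape for c in chunks]
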
def ins_inference(x, arc_boxes, th):
    # type: (Tensor, List[Tensor], float) -> Tuple[List[Tensor], List[Tensor], List[List[Tensor]]]
    points_probs = []
    points_scores = []
    print(f'arc_boxes:{len(arc_boxes)}')
    boxes_per_image = [box.size(0) for box in arc_boxes]
    print(f'arc boxes_per_image:{boxes_per_image}')
    x2 = x.split(boxes_per_image, dim=0)
    for xx, bb in zip(x2, arc_boxes):
        point_prob, point_scores = heatmaps_to_arc(xx, bb)
        points_probs.append(point_prob.unsqueeze(1))
        points_scores.append(point_scores)
    points_probs_tensor = torch.cat(points_probs)
    print(f'points_probs shape:{points_probs_tensor.shape}')

    feature_logits = x
    batch_size = feature_logits.shape[0]
    num_proposals = len(arc_boxes[0])
    results = [[torch.empty(0, 2) for _ in range(num_proposals)] for _ in range(batch_size)]

    proposals_list = arc_boxes[0]  # [[tensor(...)]]
    for proposal_idx, proposal in enumerate(proposals_list):
        coords = proposal.tolist()
        x1, y1, x2, y2 = map(int, coords)
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(feature_logits.shape[3], x2)
        y2 = min(feature_logits.shape[2], y2)
        for batch_idx in range(batch_size):
            region = feature_logits[batch_idx, :, y1:y2, x1:x2]
            mask = region > th
            coords = torch.nonzero(mask)
            if coords.numel() > 0:
                # Take (y, x) and convert to global (x, y) coordinates
                local_coords = coords[:, [2, 1]]  # (x, y)
                local_coords[:, 0] += x1
                local_coords[:, 1] += y1
                results[batch_idx][proposal_idx] = local_coords
    print(f're:{results}')
    return points_probs, points_scores, results


def heatmaps_to_arc(maps, rois, threshold=0.5, output_size=(128, 128)):
    """
    Args:
        maps: [N, 3, H, W] - full heatmaps
        rois: [N, 4] - bounding boxes
        threshold: float - binarization threshold
        output_size: resized size for uniform NMS
    Returns:
        masks: [N, 1, H, W] - binary mask aligned with the input map
        scores: [N, 1] - count of non-zero pixels in each mask
    """
    N, _, H, W = maps.shape
    masks = torch.zeros((N, 1, H, W), dtype=torch.float32, device=maps.device)
    scores = torch.zeros((N, 1), dtype=torch.float32, device=maps.device)
    point_maps = maps[:, 0]  # take the first channel, [N, H, W]

    print(f"==> heatmaps_to_arc: maps.shape = {maps.shape}, rois.shape = {rois.shape}")

    for i in range(N):
        x1, y1, x2, y2 = rois[i].long()
        x1 = x1.clamp(0, W - 1)
        x2 = x2.clamp(0, W - 1)
        y1 = y1.clamp(0, H - 1)
        y2 = y2.clamp(0, H - 1)
        print(f"[{i}] roi: ({x1.item()}, {y1.item()}, {x2.item()}, {y2.item()})")

        if x2 <= x1 or y2 <= y1:
            print(f"  Skipped invalid ROI at index {i}")
            continue

        roi_map = point_maps[i, y1:y2, x1:x2]  # [h, w]
        print(f"  roi_map.shape: {roi_map.shape}")
        if roi_map.numel() == 0:
            print(f"  Skipped empty ROI at index {i}")
            continue

        # resize to a uniform size
        roi_map_resized = F.interpolate(
            roi_map.unsqueeze(0).unsqueeze(0),
            size=output_size,
            mode='bilinear',
            align_corners=False
        )  # [1, 1, H, W]
        print(f"  roi_map_resized.shape: {roi_map_resized.shape}")

        # NMS + threshold
        # nms_roi = non_maximum_suppression(roi_map_resized)  # shape: [1, H, W]
        nms_roi = torch.sigmoid(roi_map_resized)
        bin_mask = (nms_roi >= threshold).float()  # shape: [1, 1, H, W]
        print(f"  bin_mask.sum(): {bin_mask.sum().item()}")

        # resize back to the original roi size
        h = int((y2 - y1).item())
        w = int((x2 - x1).item())

        # make sure bin_mask is [1, 1, 128, 128]
        assert bin_mask.dim() == 4, f"Expected 4D tensor [1, 1, H, W], got {bin_mask.shape}"

        bin_mask_original_size = F.interpolate(
            bin_mask,  # [1, 1, 128, 128]
            size=(h, w),
            mode='bilinear',
            align_corners=False
        )[0]  # [1, h, w]
        bin_mask_original_size = bin_mask_original_size[0]  # [h, w]

        masks[i, 0, y1:y2, x1:x2] = bin_mask_original_size
        scores[i] = bin_mask_original_size.sum()

        # bin_mask_original_size = F.interpolate(
        #     # bin_mask.unsqueeze(0),  # [1, 1, 128, 128]
        #     bin_mask,  # [1, 1, 128, 128]
        #     size=(h, w),
        #     mode='bilinear',
        #     align_corners=False
        # )[0]  # [1, h, w]
        #
        # masks[i, 0, y1:y2, x1:x2] = bin_mask_original_size.squeeze()
        # scores[i] = bin_mask_original_size.sum()

        # plt.figure(figsize=(6, 6))
        # plt.imshow(masks[i, 0].cpu().numpy(), cmap='gray')
        # plt.title(f"Mask {i}, score={scores[i].item():.1f}")
        # plt.axis('off')
        # plt.show()

        print(f"  bin_mask_original_size.shape: {bin_mask_original_size.shape}, sum: {scores[i].item()}")

    print(f"==> Done. Total valid masks: {(scores > 0).sum().item()} / {N}")
    return masks, scores
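
# End-of-file usage sketch for `heatmaps_to_arc` above: a single synthetic logit
# map and one ROI, thresholded into a binary mask. The sizes and threshold are
# assumptions chosen only for the example.
def _example_heatmaps_to_arc():
    maps = torch.full((1, 3, 64, 64), -5.0)   # low logits everywhere ...
    maps[0, 0, 20:30, 20:30] = 5.0            # ... except a bright patch in channel 0
    rois = torch.tensor([[10.0, 10.0, 50.0, 50.0]])
    masks, scores = heatmaps_to_arc(maps, rois, threshold=0.5)
    return masks.shape, scores                # ([1, 1, 64, 64], non-zero score for ROI 0)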