| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322 |
- import torch
- from matplotlib import pyplot as plt
- import torch.nn.functional as F
- from torch import nn
- from torch.cuda import device
- from utils.data_process.mask.show_mask import save_full_mask
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid probabilities.

    Predictions and targets are flattened into single vectors, so one
    global Dice coefficient is computed over every element passed in.

    Args:
        smooth: Constant added to numerator and denominator so the ratio
            stays defined when prediction and target sums are both zero.
    """

    def __init__(self, smooth=1.):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return ``1 - dice`` for raw ``logits`` against ``targets``.

        Args:
            logits: Raw scores; a sigmoid is applied here.
            targets: Tensor with the same number of elements as ``logits``;
                it is cast to float.

        Returns:
            Scalar tensor ``1 - dice``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as transposed or sliced tensors, which view(-1) rejects.
        probs = torch.sigmoid(logits).reshape(-1)
        targets = targets.reshape(-1).float()
        intersection = (probs * targets).sum()
        dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
        return 1. - dice
# Module-level loss instances reused by every ``combined_loss`` call.
bce_loss = nn.BCEWithLogitsLoss()
dice_loss = DiceLoss()


def combined_loss(preds, targets, alpha=0.5):
    """Blend BCE-with-logits and Dice loss.

    Args:
        preds: Raw logits, passed unchanged to both losses.
        targets: Ground truth, passed unchanged to both losses.
        alpha: Weight of the BCE term; the Dice term is weighted ``1 - alpha``.

    Returns:
        ``alpha * bce + (1 - alpha) * dice``.
    """
    bce_term = alpha * bce_loss(preds, targets)
    dice_term = (1 - alpha) * dice_loss(preds, targets)
    return bce_term + dice_term
def features_align(features, proposals, img_size):
    """Build one masked copy of the feature map per proposal box.

    For every proposal a zero tensor shaped like that image's feature map
    (with a leading batch axis of 1) is created and only the window covered
    by the box, inclusive of both corners, is copied from the features.

    Args:
        features: Iterable over per-image feature maps; each item is indexed
            on its last two axes as ``[..., y, x]``.
        proposals: Per-image tensors whose rows are ``(x1, y1, x2, y2)``.
        img_size: Unused; kept for interface compatibility.

    Returns:
        All masked maps concatenated along axis 0, or ``None`` when no image
        has any proposal.
    """
    print(f'features_align features:{features.shape},proposals:{len(proposals)}')
    align_feat_list = []
    for feat, proposals_per_img in zip(features, proposals):
        print(f'feature_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
        if proposals_per_img.shape[0] == 0:
            continue
        feat = feat.unsqueeze(0)
        for proposal in proposals_per_img:
            x1, y1, x2, y2 = (int(v) for v in proposal.tolist())
            # Slicing treats negative bounds as offsets from the end, which
            # would copy an empty or wrong window for boxes that poke outside
            # the map. Clamp so such boxes are cropped to the map instead.
            x_start, y_start = max(x1, 0), max(y1, 0)
            x_stop, y_stop = max(x2 + 1, 0), max(y2 + 1, 0)
            align_feat = torch.zeros_like(feat)
            align_feat[:, :, y_start:y_stop, x_start:x_stop] = feat[:, :, y_start:y_stop, x_start:x_stop]
            align_feat_list.append(align_feat)

    if not align_feat_list:
        return None
    feats_tensor = torch.cat(align_feat_list)
    print(f'align features :{feats_tensor.shape}')
    return feats_tensor
def normalize_tensor(t):
    """Min-max scale ``t`` using its global minimum and maximum.

    A small epsilon (1e-6) is added to the range so a constant tensor does
    not divide by zero.
    """
    lowest = t.min()
    value_range = t.max() - lowest + 1e-6
    return (t - lowest) / value_range
def line_length(lines):
    """Compute the length of every line segment.

    Args:
        lines: ``[N, 2, 2]`` tensor of N segments, each given by two points.

    Returns:
        ``[N]`` tensor of Euclidean lengths.
    """
    start, end = lines[:, 0], lines[:, 1]
    return torch.norm(end - start, dim=-1)
def line_direction(lines):
    """Compute the unit direction vector of every line segment.

    Args:
        lines: ``[N, 2, 2]`` tensor of segments (start point, end point).

    Returns:
        ``[N, 2]`` tensor of direction vectors normalised along the last axis.
    """
    start, end = lines[:, 0], lines[:, 1]
    delta = end - start
    return F.normalize(delta, dim=-1)
def angle_loss_cosine(pred_dir, gt_dir):
    """Measure direction disagreement as ``1 - cosine similarity``.

    Args:
        pred_dir: ``[N, 2]`` predicted direction vectors.
        gt_dir: ``[N, 2]`` ground-truth direction vectors.

    Returns:
        ``[N]`` tensor; the dot product is clamped to ``[-1, 1]`` first.
    """
    dot = (pred_dir * gt_dir).sum(dim=-1)
    cos_sim = torch.clamp(dot, -1.0, 1.0)
    # An alternative would be torch.acos(cos_sim) / pi.
    return 1.0 - cos_sim
def line_length(lines):
    """Return the Euclidean length of each segment in ``lines``.

    NOTE(review): this redefines ``line_length`` declared earlier in this
    file with the same body; the later definition is the one in effect.

    Args:
        lines: ``[N, 2, 2]`` tensor, N segments made of two points each.

    Returns:
        ``[N]`` tensor of lengths.
    """
    offsets = lines[:, 1] - lines[:, 0]
    return torch.norm(offsets, dim=-1)
def line_direction(lines):
    """Return the normalised direction of each segment in ``lines``.

    NOTE(review): this redefines ``line_direction`` declared earlier in this
    file with the same body; the later definition is the one in effect.

    Args:
        lines: ``[N, 2, 2]`` tensor of segments.

    Returns:
        ``[N, 2]`` tensor of unit direction vectors.
    """
    offsets = lines[:, 1] - lines[:, 0]
    return F.normalize(offsets, dim=-1)
def angle_loss_cosine(pred_dir, gt_dir):
    """Return ``1 - cos`` between predicted and ground-truth directions.

    NOTE(review): this redefines ``angle_loss_cosine`` declared earlier in
    this file with the same body; the later definition is the one in effect.

    Args:
        pred_dir: ``[N, 2]`` tensor.
        gt_dir: ``[N, 2]`` tensor.

    Returns:
        ``[N]`` tensor of per-row losses.
    """
    similarity = torch.sum(pred_dir * gt_dir, dim=-1)
    similarity = similarity.clamp(-1.0, 1.0)
    # torch.acos(similarity) / pi would be an alternative formulation.
    return 1.0 - similarity
def single_point_to_heatmap(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tensor
    """Render one Gaussian per keypoint and keep only the part inside its ROI.

    Args:
        keypoints: Tensor whose last axis holds ``x`` at index 0 and ``y`` at
            index 1.
        rois: Iterable of ``(x1, y1, x2, y2)`` boxes, paired one-to-one with
            the generated heatmaps.
        heatmap_size: Forwarded to ``generate_gaussian_heatmaps``.

    Returns:
        The ROI-masked heatmaps concatenated along axis 0.
    """
    print(f'rois:{rois.shape}')
    print(f'heatmap_size:{heatmap_size}')
    print(f'keypoints.shape:{keypoints.shape}')

    xs = keypoints[..., 0].unsqueeze(1)
    ys = keypoints[..., 1].unsqueeze(1)
    gaussians = generate_gaussian_heatmaps(xs, ys, num_points=1, heatmap_size=heatmap_size, sigma=1.0)

    masked = []
    for box, gaussian in zip(rois, gaussians):
        left, top, right, bottom = map(int, box)
        gaussian = gaussian.unsqueeze(0)
        # Start from zeros and copy only the box window (corners inclusive).
        window = torch.zeros_like(gaussian)
        window[..., top:bottom + 1, left:right + 1] = gaussian[..., top:bottom + 1, left:right + 1]
        masked.append(window)

    all_roi_heatmap = torch.cat(masked)
    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
    return all_roi_heatmap
def arc_inference1(arc_equation, x, arc_boxes, th):
    """Split batched arc outputs per image and score each image's heatmaps.

    Args:
        arc_equation: Tensor split along axis 0 by the per-image box counts.
        x: Tensor split the same way and fed to ``heatmaps_to_arc``.
        arc_boxes: List of per-image box tensors; ``size(0)`` gives the count.
        th: Unused.

    Returns:
        ``(arc7, points_scores)``: the per-image chunks of ``arc_equation``
        and the list of scores returned by ``heatmaps_to_arc``.
    """
    print(f'arc_boxes:{len(arc_boxes)}')
    boxes_per_image = [boxes.size(0) for boxes in arc_boxes]
    print(f'arc boxes_per_image:{boxes_per_image}')

    per_image_maps = x.split(boxes_per_image, dim=0)
    arc7 = arc_equation.split(boxes_per_image, dim=0)

    points_probs = []  # collected but not returned
    points_scores = []
    for image_maps, image_boxes in zip(per_image_maps, arc_boxes):
        prob, scores = heatmaps_to_arc(image_maps, image_boxes)
        points_probs.append(prob.unsqueeze(1))
        points_scores.append(scores)
    return arc7, points_scores
def points_to_heatmap(keypoints, rois,num_points=2, heatmap_size=(512,512)):
    """Render ``num_points`` Gaussians per item and mask each map to its ROI.

    Args:
        keypoints: Tensor whose last axis holds ``x`` at index 0 and ``y`` at
            index 1.
        rois: Iterable of ``(x1, y1, x2, y2)`` boxes, paired one-to-one with
            the generated heatmaps.
        num_points: Number of points rendered into each heatmap.
        heatmap_size: Side length of the square heatmap, either an int or a
            square ``(H, W)`` pair.

    Returns:
        The ROI-masked heatmaps concatenated along axis 0.

    Raises:
        ValueError: If ``heatmap_size`` is a pair that is not square.
    """
    print(f'rois:{rois.shape}')
    print(f'heatmap_size:{heatmap_size}')
    print(f'keypoints.shape:{keypoints.shape}')

    # generate_gaussian_heatmaps only understands an int side length, so the
    # tuple default of this function used to fail inside it. Reduce a square
    # pair to its side length before forwarding.
    if isinstance(heatmap_size, (tuple, list)):
        if len(heatmap_size) != 2 or heatmap_size[0] != heatmap_size[1]:
            raise ValueError(
                f"heatmap_size must be an int or a square (H, W) pair, got {heatmap_size}"
            )
        heatmap_size = heatmap_size[0]

    x = keypoints[..., 0].unsqueeze(1)
    y = keypoints[..., 1].unsqueeze(1)
    gs = generate_gaussian_heatmaps(x, y, num_points=num_points, heatmap_size=heatmap_size, sigma=2.0)

    all_roi_heatmap = []
    for roi, heatmap in zip(rois, gs):
        heatmap = heatmap.unsqueeze(0)
        x1, y1, x2, y2 = map(int, roi)
        # Keep only the box window (corners inclusive); everything else is 0.
        roi_heatmap = torch.zeros_like(heatmap)
        roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
        all_roi_heatmap.append(roi_heatmap)

    all_roi_heatmap = torch.cat(all_roi_heatmap)
    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
    return all_roi_heatmap
def line_points_to_heatmap(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tensor
    """Render two Gaussians per line and mask each heatmap to its ROI.

    Args:
        keypoints: Tensor whose last axis holds ``x`` at index 0 and ``y`` at
            index 1.
        rois: Iterable of ``(x1, y1, x2, y2)`` boxes, paired one-to-one with
            the generated heatmaps.
        heatmap_size: Forwarded to ``generate_gaussian_heatmaps``.

    Returns:
        The ROI-masked heatmaps concatenated along axis 0, or ``None`` when
        there is nothing to concatenate.
    """
    print(f'rois:{rois.shape}')
    print(f'heatmap_size:{heatmap_size}')
    print(f'keypoints.shape:{keypoints.shape}')

    point_x = keypoints[..., 0]
    point_y = keypoints[..., 1]
    gaussians = generate_gaussian_heatmaps(point_x, point_y, heatmap_size, num_points=2, sigma=1.0)

    per_roi = []
    for box, gaussian in zip(rois, gaussians):
        left, top, right, bottom = map(int, box)
        gaussian = gaussian.unsqueeze(0)
        # Zero everything outside the box (corners inclusive).
        clipped = torch.zeros_like(gaussian)
        clipped[..., top:bottom + 1, left:right + 1] = gaussian[..., top:bottom + 1, left:right + 1]
        per_roi.append(clipped)

    if len(per_roi) == 0:
        return None
    all_roi_heatmap = torch.cat(per_roi)
    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
    return all_roi_heatmap
"""
Point-to-heatmap conversion adapted from the original structure;
intended for the variant that uses roi_pool.
"""
def line_points_to_heatmap_(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tensor
    """Map keypoints into ROI-relative heatmap coordinates and render them.

    Each keypoint is shifted by its ROI's top-left corner and scaled so the
    ROI spans ``heatmap_size`` cells, then floored to an integer cell. Points
    lying exactly on the ROI's right or bottom edge are assigned to the last
    cell.

    Args:
        keypoints: Tensor whose last axis holds ``x`` at index 0 and ``y`` at
            index 1. Any further entries on that axis are not consulted.
        rois: ``[R, 4]`` tensor of ``(x1, y1, x2, y2)`` boxes.
        heatmap_size: Side length of the square heatmap.

    Returns:
        The tensor produced by ``generate_gaussian_heatmaps`` for the
        rescaled coordinates.
    """
    print(f'rois:{rois.shape}')
    print(f'heatmap_size:{heatmap_size}')

    offset_x = rois[:, 0][:, None]
    offset_y = rois[:, 1][:, None]
    scale_x = (heatmap_size / (rois[:, 2] - rois[:, 0]))[:, None]
    scale_y = (heatmap_size / (rois[:, 3] - rois[:, 1]))[:, None]

    print(f'keypoints.shape:{keypoints.shape}')
    x = keypoints[..., 0]
    y = keypoints[..., 1]

    # Record edge points before scaling; after scaling they would land on
    # index heatmap_size, one past the last valid cell.
    x_boundary_inds = x == rois[:, 2][:, None]
    y_boundary_inds = y == rois[:, 3][:, None]

    x = ((x - offset_x) * scale_x).floor().long()
    y = ((y - offset_y) * scale_y).floor().long()
    x[x_boundary_inds] = heatmap_size - 1
    y[y_boundary_inds] = heatmap_size - 1

    # sigma must be passed by keyword: the fourth positional parameter of
    # generate_gaussian_heatmaps is num_points, so a positional 1.0 ended up
    # in range(1.0) and raised TypeError.
    gs_heatmap = generate_gaussian_heatmaps(x, y, heatmap_size, sigma=1.0)
    return gs_heatmap
def generate_gaussian_heatmaps(xs, ys, heatmap_size, num_points=2, sigma=2.0, device='cuda'):
    """Render, for every item, the sum of ``num_points`` Gaussian blobs.

    Args:
        xs (Tensor): x coordinates. After ``squeeze(1)`` it is indexed as
            ``xs[i, j]`` for item ``i`` and point ``j``.
        ys (Tensor): y coordinates, same shape as ``xs``.
        heatmap_size (int): Side length of the square heatmap (H = W).
        num_points (int): Number of points read per item.
        sigma (float): Standard deviation of each Gaussian.
        device (str): Device on which the grid and output are allocated.

    Returns:
        Tensor: ``(N, heatmap_size, heatmap_size)`` with one summed heatmap
        per item, where ``N`` is the first dimension of ``xs``.
    """
    assert xs.shape == ys.shape, "x and y must have the same shape"
    print(f'xs:{xs.shape}')
    xs = xs.squeeze(1)
    ys = ys.squeeze(1)
    print(f'xs1:{xs.shape}')
    N = xs.shape[0]
    print(f'N:{N},num_points:{num_points}')

    # Pixel coordinate grid shared by every blob.
    rows = torch.arange(heatmap_size, device=device)
    cols = torch.arange(heatmap_size, device=device)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')

    last_cell = heatmap_size - 1
    two_sigma_sq = 2 * sigma ** 2
    combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
    for item in range(N):
        accumulated = torch.zeros((heatmap_size, heatmap_size), device=device)
        for point in range(num_points):
            # Centres are clamped into the grid before rendering.
            centre_x = xs[item, point].clamp(0, last_cell).item()
            centre_y = ys[item, point].clamp(0, last_cell).item()
            sq_dist = (grid_x - centre_x) ** 2 + (grid_y - centre_y) ** 2
            accumulated += torch.exp(-sq_dist / two_sigma_sq)
        combined_heatmap[item] = accumulated
    return combined_heatmap
- def generate_mask_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
- """
- 为一组点生成并合并高斯热图。
- Args:
- xs (Tensor): 形状为 (N, 2) 的所有点的 x 坐标
- ys (Tensor): 形状为 (N, 2) 的所有点的 y 坐标
- heatmap_size (int): 热图大小 H=W
- sigma (float): 高斯核标准差
- device (str): 设备类型 ('cpu' or 'cuda')
- Returns:
- Tensor: 形状为 (H, W) 的合并后的热图
- """
- assert xs.shape == ys.shape, "x and y must have the same shape"
- print(f'xs:{xs.shape}')
- xs=xs.squeeze(1)
- ys = ys.squeeze(1)
- print(f'xs1:{xs.shape}')
- N = xs.shape[0]
- print(f'N:{N},num_points:{num_points}')
- # 创建网格
- grid_y, grid_x = torch.meshgrid(
- torch.arange(heatmap_size, device=device),
- torch.arange(heatmap_size, device=device),
- indexing='ij'
- )
- # print(f'heatmap_size:{heatmap_size}')
- # 初始化输出热图
- combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
- for i in range(N):
- heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
- for j in range(num_points):
- mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
- mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
- # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
- # 计算距离平方
- dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
- # 计算高斯分布
- heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
- heatmap+=heatmap1
- # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
- # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
- #
- # # 计算距离平方
- # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
- #
- # # 计算高斯分布
- # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
- #
- # heatmap = heatmap1 + heatmap2
- # 将当前热图累加到结果中
- combined_heatmap[i] = heatmap
- return combined_heatmap
def non_maximum_suppression(a):
    """Zero every element that is not the maximum of its 3x3 neighbourhood.

    The neighbourhood maximum comes from a stride-1, padding-1 max pool, so
    the output has the same shape as ``a``.
    """
    neighbourhood_max = F.max_pool2d(a, kernel_size=3, stride=1, padding=1)
    keep = torch.eq(a, neighbourhood_max).float().clamp(min=0.0)
    return a * keep
def heatmaps_to_points(maps, rois, num_points=2):
    """Pick the strongest NMS peak of channel 0 for every ROI.

    NOTE(review): each row of the output holds a single ``(x, y)`` pair and
    the top-k indices are written into scalar slots, so this looks usable
    only with ``num_points=1`` -- confirm against callers.

    Args:
        maps: Heatmaps indexed as ``maps[:, 0]`` to select the point channel.
        rois: Sequence whose length sets the number of output rows.
        num_points: ``k`` passed to ``torch.topk``.

    Returns:
        ``(point_preds, point_end_scores)`` with shapes ``(len(rois), 2)``
        and ``(len(rois), 1)``.
    """
    roi_count = len(rois)
    point_preds = torch.zeros((roi_count, 2), dtype=torch.float32, device=maps.device)
    point_end_scores = torch.zeros((roi_count, 1), dtype=torch.float32, device=maps.device)
    print(f'heatmaps_to_lines:{maps.shape}')
    point_maps = maps[:, 0]
    print(f'point_map:{point_maps.shape}')

    for i in range(roi_count):
        point_roi_map = point_maps[i].unsqueeze(0)
        print(f'point_roi_map:{point_roi_map.shape}')
        w = point_roi_map.shape[2]

        suppressed = non_maximum_suppression(point_roi_map)
        point_score, point_index = torch.topk(suppressed.reshape(1, -1), k=num_points)
        print(f'point index:{point_index}')

        # Convert the flat index back to column / row.
        point_x = point_index % w
        point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
        point_preds[i, 0,] = point_x
        point_preds[i, 1,] = point_y
        point_end_scores[i, :] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
    return point_preds, point_end_scores
# Split into 4 parts
def find_max_heat_point_in_each_part(feature_map, box):
    """Find the hottest local peak in each of four quadrants of a heatmap.

    The quadrants meet at the box centre shifted 3 cells right and 3 cells up
    (clamped to the map). Only channel 0 of ``feature_map`` is used.

    Args:
        feature_map (torch.Tensor): ``[C, H, W]`` tensor.
        box: ``[x_min, y_min, x_max, y_max]``.

    Returns:
        list: Four entries ordered top-left, top-right, bottom-left,
        bottom-right. Each is an ``(x, y)`` tuple of ints, or ``None`` when
        the quadrant has no positive local maximum.
    """
    device = feature_map.device
    _, H, W = feature_map.shape

    # Box centre, then the shifted split point clamped into the map.
    cx = (box[0] + box[2]) // 2
    cy = (box[1] + box[3]) // 2
    new_cx = min(max(cx + 3, 0), W - 1)  # 3 to the right
    new_cy = min(max(cy - 3, 0), H - 1)  # 3 upwards

    y_coords, x_coords = torch.meshgrid(
        torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij'
    )
    above, left_of = y_coords < new_cy, x_coords < new_cx
    quadrants = [
        above & left_of,                                      # top-left
        above & (x_coords >= new_cx),                         # top-right
        (y_coords >= new_cy) & left_of,                       # bottom-left
        (y_coords >= new_cy) & (x_coords >= new_cx),          # bottom-right
    ]

    # Channel 0 is treated as the heatmap.
    heatmap = feature_map[0]

    def _strongest_peak(region):
        # Work on a copy so the caller's tensor is untouched.
        masked = heatmap.clone()
        masked[~region] = 0
        # 3x3 NMS: a peak equals its neighbourhood maximum and is positive.
        pooled = F.max_pool2d(masked.unsqueeze(0), kernel_size=3, stride=1, padding=1).squeeze(0)
        is_peak = (masked == pooled) & (masked > 0)
        candidates = masked * is_peak.float()
        if candidates.max() <= 0:
            return None
        _, flat_index = torch.max(candidates.view(-1), dim=0)
        row, col = divmod(flat_index.item(), W)
        return (col, row)

    return [_strongest_peak(region) for region in quadrants]
def non_maximum_suppression_2d(heatmap, kernel_size=3):
    """Return a boolean mask of positive local maxima of a 2-D heatmap.

    Args:
        heatmap (torch.Tensor): ``[H, W]`` tensor.
        kernel_size (int): Side of the max-pool window used to find each
            neighbourhood maximum.

    Returns:
        torch.Tensor: Boolean tensor shaped like ``heatmap``; ``True`` where
        the value equals its neighbourhood maximum and is greater than 0.
    """
    half = (kernel_size - 1) // 2
    pooled = F.max_pool2d(heatmap.unsqueeze(0), kernel_size=kernel_size, stride=1, padding=half)
    neighbourhood_max = pooled.squeeze(0)
    return (heatmap == neighbourhood_max) & (heatmap > 0)
def find_max_heat_point_in_edge_centers(feature_map, box):
    """Find the hottest cell in the four edge-middle regions around a box.

    The plane is cut into a 3x3 layout by lines one sixth of the box width /
    height either side of the box centre. The top-middle, right-middle,
    bottom-middle and left-middle cells of that layout are searched, in that
    order, on channel 0 of ``feature_map``.

    Args:
        feature_map (torch.Tensor): ``[C, H, W]`` tensor.
        box: ``[x_min, y_min, x_max, y_max]``.

    Returns:
        list: Four ``(point_x, point_y)`` tuples ordered top, right, bottom,
        left. Each coordinate is the index tensor derived from
        ``torch.topk(..., k=1)`` on the flattened masked map.
    """
    device = feature_map.device
    _, H, W = feature_map.shape

    # Box centre.
    cx = (box[0] + box[2]) / 2
    cy = (box[1] + box[3]) / 2

    # Dividing lines of the 3x3 layout, derived from the box width/height.
    sixth_w = (box[2] - box[0]) / 6
    sixth_h = (box[3] - box[1]) / 6
    x_left, x_right = cx - sixth_w, cx + sixth_w
    y_top, y_bottom = cy - sixth_h, cy + sixth_h

    # Coordinate grid.
    y_coords, x_coords = torch.meshgrid(
        torch.arange(H, device=device),
        torch.arange(W, device=device),
        indexing='ij'
    )

    # Only the four "edge middle" cells are searched; corners are ignored.
    in_middle_cols = (x_coords >= x_left) & (x_coords < x_right)
    in_middle_rows = (y_coords >= y_top) & (y_coords < y_bottom)
    masks = [
        in_middle_cols & (y_coords < y_top),        # top middle
        (x_coords >= x_right) & in_middle_rows,     # right middle
        in_middle_cols & (y_coords >= y_bottom),    # bottom middle
        (x_coords < x_left) & in_middle_rows,       # left middle
    ]

    # Channel 0 is used as the heatmap.
    heatmap = feature_map[0]

    results = []
    for region in masks:
        # Zero everything outside the target region.
        masked_heatmap = heatmap.masked_fill(~region, 0)
        _, point_index = torch.topk(masked_heatmap.reshape(1, -1), k=1)
        point_x = point_index % W
        point_y = torch.div(point_index - point_x, W, rounding_mode="floor")
        results.append((point_x, point_y))
    return results
def heatmaps_to_circle_points(maps, rois, num_points=2):
    """Extract four edge points per ROI from channel 0 of ``maps``.

    For every ROI the channel-0 map goes through ``non_maximum_suppression``
    and ``find_max_heat_point_in_edge_centers``; the four resulting points
    and the map values at those points are stored.

    Args:
        maps: Heatmaps indexed as ``maps[:, 0]``.
        rois: Tensor of boxes; ``rois[i]`` is handed to the edge-centre search.
        num_points: Unused.

    Returns:
        ``(point_preds, point_end_scores)`` with shapes ``(len(rois), 4, 2)``
        and ``(len(rois), 4, 1)``.
    """
    roi_count = len(rois)
    point_preds = torch.zeros((roi_count, 4, 2), dtype=torch.float32, device=maps.device)
    point_end_scores = torch.zeros((roi_count, 4, 1), dtype=torch.float32, device=maps.device)
    print(f'rois in heatmaps_to_circle_points:{type(rois), rois.shape}')
    print(f'heatmaps_to_lines:{maps.shape}')
    point_maps = maps[:, 0]
    print(f'point_map:{point_maps.shape}')

    for i in range(roi_count):
        point_roi_map = point_maps[i].unsqueeze(0)
        print(f'point_roi_map:{point_roi_map.shape}')

        suppressed = non_maximum_suppression(point_roi_map)
        result_points = find_max_heat_point_in_edge_centers(suppressed, rois[i])
        point_preds[i, :] = torch.tensor(result_points)

        # Scores are read from the raw (pre-NMS) map at the chosen points.
        point_x = [point[0] for point in result_points]
        point_y = [point[1] for point in result_points]
        point_end_scores[i, :, 0] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
    return point_preds, point_end_scores
def heatmaps_to_lines(maps, rois):
    """Extract two endpoints per ROI from channel 1 of ``maps``.

    The two strongest NMS peaks of each ROI's map become the endpoints.

    Args:
        maps: Heatmaps indexed as ``maps[:, 1]`` to select the line channel.
        rois: Sequence whose length sets the number of output rows.

    Returns:
        ``(lines, scores)``: ``lines`` has shape ``(len(rois), 2, 3)`` with
        each endpoint stored as ``(x, y, 1)``; ``scores`` has shape
        ``(len(rois), 2)`` and holds the raw map value at each endpoint.
    """
    line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
    line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
    line_maps = maps[:, 1]

    for i in range(len(rois)):
        line_roi_map = line_maps[i].unsqueeze(0)
        print(f'line_roi_map:{line_roi_map.shape}')
        # The row stride of the flattened map is its width (last axis).
        # shape[1] is the height and only matched for square maps.
        w = line_roi_map.shape[-1]
        flatten_line_roi_map = non_maximum_suppression(line_roi_map).reshape(1, -1)
        line_score, line_index = torch.topk(flatten_line_roi_map, k=2)
        print(f'line index:{line_index}')

        # Convert flat indices back to column / row.
        line_x = line_index % w
        line_y = torch.div(line_index - line_x, w, rounding_mode="floor")
        line_preds[i, 0, :] = line_x
        line_preds[i, 1, :] = line_y
        line_preds[i, 2, :] = 1
        line_end_scores[i, :] = line_roi_map[torch.arange(1, device=line_roi_map.device), line_y, line_x]
    return line_preds.permute(0, 2, 1), line_end_scores
# Helper for displaying a heatmap
def show_heatmap(heatmap, title="Heatmap"):
    """Display a heatmap tensor with matplotlib.

    Args:
        heatmap (Tensor): Heatmap tensor to display.
        title (str): Figure title.
    """
    # detach() first: Tensor.numpy() refuses tensors that require grad.
    # cpu() is a no-op for tensors that already live on the host.
    image = heatmap.detach().cpu().numpy()
    plt.imshow(image, cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.title(title)
    plt.show()
def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    """Cross-entropy between channel 1 of the logits and Gaussian line targets.

    Targets are built per image with ``line_points_to_heatmap`` from the
    ground-truth lines selected by the matched indices. Images without
    proposals or without ground truth contribute no target.

    Args:
        line_logits: ``(N, K, H, W)`` tensor with ``H == W``.
        proposals: Per-image box tensors.
        gt_lines: Per-image ground-truth tensors, indexed by ``midx``.
        line_matched_idxs: Per-image indices into ``gt_lines``.

    Returns:
        Scalar loss tensor. When no image yields a target the result is
        ``line_logits.sum() * 0``, a zero that is still attached to the graph.

    Raises:
        ValueError: If ``H != W``.
    """
    N, K, H, W = line_logits.shape
    len_proposals = len(proposals)
    print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals},line_matched_idxs:{line_matched_idxs}')
    if H != W:
        raise ValueError(
            f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
        )
    discretization_size = H

    gs_heatmaps = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
        print(f'line_proposals_per_image:{proposals_per_image.shape}')
        print(f'gt_lines:{gt_lines}')
        if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
            kp = gt_kp_in_image[midx]
            gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
            gs_heatmaps.append(gs_heatmaps_per_img)

    if not gs_heatmaps:
        # torch.cat rejects an empty list. Follow the usual detection-head
        # convention and return a zero loss that keeps the autograd graph.
        return line_logits.sum() * 0

    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
    print(f'loss1 line_logits:{line_logits.shape}')
    # Channel 1 carries the line-endpoint heatmap.
    line_logits = line_logits[:, 1, :, :]
    print(f'loss2 line_logits:{line_logits.shape}')
    line_loss = F.cross_entropy(line_logits, gs_heatmaps)
    return line_loss
def compute_ins_loss(feature_logits, proposals, gt_, pos_matched_idxs):
    """Combined BCE + Dice loss between predicted and resized ground-truth masks.

    Per image, the ground truth selected by the matched indices is resized
    with ``align_masks``; the concatenated result is compared with
    ``feature_logits.squeeze(1)`` through ``combined_loss``.

    Args:
        feature_logits: ``(N, K, H, W)`` tensor with ``H == W``.
        proposals: Per-image box tensors.
        gt_: Per-image ground-truth tensors, indexed by ``midx``.
        pos_matched_idxs: Per-image indices into ``gt_``.

    Returns:
        Scalar loss tensor. When no image yields a target, a tensor with the
        fallback value 100 is returned.

    Raises:
        ValueError: If ``H != W``.
    """
    print(f'compute_arc_loss:{feature_logits.shape}')
    N, K, H, W = feature_logits.shape
    len_proposals = len(proposals)

    empty_count = sum(1 for prop in proposals if prop.shape[0] == 0)
    non_empty_count = len_proposals - empty_count
    print(f"Empty proposals count: {empty_count}")
    print(f"Non-empty proposals count: {non_empty_count}")
    print(f'starte to compute_point_loss')
    print(f'compute_point_loss line_logits.shape:{feature_logits.shape},len_proposals:{len_proposals}')
    if H != W:
        raise ValueError(
            f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
        )
    discretization_size = H

    gs_heatmaps = []
    print(f'gt_masks:{gt_[0].shape}')
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
        print(f'proposals_per_image:{proposals_per_image.shape}')
        kp = gt_kp_in_image[midx]
        t_h, t_w = kp.shape[-2:]
        print(f't_h:{t_h}, t_w:{t_w}')
        print(f'gt_kp_in_image:{gt_kp_in_image.shape}')
        if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
            gs_heatmaps_per_img = align_masks(kp, proposals_per_image, discretization_size)
            gs_heatmaps.append(gs_heatmaps_per_img)

    if gs_heatmaps:
        gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
        print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.shape}')
        line_logits = feature_logits.squeeze(1)
        print(f'ins shape:{line_logits.shape}')
        line_loss = combined_loss(line_logits, gs_heatmaps)
    else:
        # Keep the fallback value of 100 but return it as a tensor on the
        # logits' device and graph, so callers can treat every return value
        # the same way (backward(), .item(), reductions with other losses).
        line_loss = feature_logits.sum() * 0 + 100
        print("d")
    return line_loss
def align_masks(keypoints, rois, heatmap_size):
    """Resize a batch of masks to ``heatmap_size x heatmap_size``.

    Args:
        keypoints: Mask tensor; a channel axis is inserted at position 1
            before bilinear resizing and removed afterwards.
        rois: Only its shape is printed; otherwise unused.
        heatmap_size: Target side length.

    Returns:
        The resized masks with the temporary channel axis squeezed out.
    """
    print(f'rois:{rois.shape}')
    print(f'heatmap_size:{heatmap_size}')
    print(f'keypoints.shape:{keypoints.shape}')

    target_hw = (heatmap_size, heatmap_size)
    # F.interpolate expects a channel axis, so add one temporarily.
    with_channel = keypoints.unsqueeze(1).float()
    resized = F.interpolate(with_channel, size=target_hw, mode='bilinear', align_corners=False)
    resized_mask = resized.squeeze(1)

    print(f'resized_mask:{resized_mask.shape}')
    return resized_mask
def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    """Cross-entropy between channel 0 of the logits and single-point targets.

    Targets are built per image with ``single_point_to_heatmap`` from the
    ground-truth points selected by the matched indices.

    Args:
        line_logits: ``(N, K, H, W)`` tensor with ``H == W``.
        proposals: Per-image box tensors.
        gt_points: Per-image ground-truth tensors, indexed by ``midx``.
        point_matched_idxs: Per-image indices into ``gt_points``.

    Returns:
        Scalar loss tensor. When no image has proposals the result is
        ``line_logits.sum() * 0``, a zero that is still attached to the graph.

    Raises:
        ValueError: If ``H != W``.
    """
    N, K, H, W = line_logits.shape
    len_proposals = len(proposals)

    empty_count = sum(1 for prop in proposals if prop.shape[0] == 0)
    non_empty_count = len_proposals - empty_count
    print(f"Empty proposals count: {empty_count}")
    print(f"Non-empty proposals count: {non_empty_count}")
    print(f'starte to compute_point_loss')
    print(f'compute_point_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
    if H != W:
        raise ValueError(
            f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
        )
    discretization_size = H

    gs_heatmaps = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_points, point_matched_idxs):
        print(f'proposals_per_image:{proposals_per_image.shape}')
        if proposals_per_image.shape[0] == 0:
            # The heatmap builder concatenates one map per proposal and
            # fails on an empty list; such images contribute no logits.
            continue
        kp = gt_kp_in_image[midx]
        gs_heatmaps_per_img = single_point_to_heatmap(kp, proposals_per_image, discretization_size)
        gs_heatmaps.append(gs_heatmaps_per_img)

    if not gs_heatmaps:
        # Nothing to supervise: return a zero loss that keeps the graph.
        return line_logits.sum() * 0

    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
    # Channel 0 carries the point heatmap.
    line_logits = line_logits[:, 0]
    print(f'single_point_logits:{line_logits.shape}')
    line_loss = F.cross_entropy(line_logits, gs_heatmaps)
    return line_loss
def compute_circle_loss(circle_logits, proposals, gt_circles, circle_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    """Cross-entropy loss between predicted circle heatmaps and their targets.

    For every image the ground-truth entries selected by the matching indices
    are converted into target heatmaps with ``points_to_heatmap`` (called with
    ``num_points=4``); the per-image targets are concatenated and compared
    against channel 0 of the logits.

    Args:
        circle_logits: 4-D logits ``[N, K, H, W]``; ``H`` must equal ``W`` and
            only channel 0 is used.
        proposals: One tensor of proposals per image.
        gt_circles: One ground-truth tensor per image.
        circle_matched_idxs: Per image, the index used to pick the
            ground-truth entry for each proposal.

    Returns:
        The value of ``F.cross_entropy`` on the selected logits and targets.

    Raises:
        ValueError: If the logits are not square in their last two dims.
    """
    # Unpacking four values also enforces that the logits are 4-D.
    _, _, H, W = circle_logits.shape
    len_proposals = len(proposals)

    # Debug statistics: number of images with / without any proposal.
    empty_count = sum(1 for prop in proposals if prop.shape[0] == 0)
    non_empty_count = len_proposals - empty_count
    print(f"Empty proposals count: {empty_count}")
    print(f"Non-empty proposals count: {non_empty_count}")
    print(f'starte to compute_circle_loss')
    print(f'compute_circle_loss circle_logits.shape:{circle_logits.shape},len_proposals:{len_proposals}')

    if H != W:
        # The message names this function's own argument (it previously said
        # "line_logits", copied over from the point loss).
        raise ValueError(
            f"circle_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
        )

    # The (square) heatmap side length doubles as the target resolution.
    discretization_size = H
    gs_heatmaps = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_circles, circle_matched_idxs):
        print(f'proposals_per_image:{proposals_per_image.shape}')
        kp = gt_kp_in_image[midx]
        gs_heatmaps.append(
            points_to_heatmap(kp, proposals_per_image, num_points=4, heatmap_size=discretization_size)
        )
    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{circle_logits.squeeze(1).shape}')

    circle_logits = circle_logits[:, 0]
    print(f'circle_logits:{circle_logits.shape}')
    # NOTE(review): F.cross_entropy treats dim 1 of its input as the class
    # dimension; after the channel selection that is the height axis --
    # confirm this is the intended formulation.
    return F.cross_entropy(circle_logits, gs_heatmaps)
def lines_to_boxes(lines, img_size=511):
    """Convert line segments into padded axis-aligned bounding boxes.

    Args:
        lines: Tensor indexed as ``lines[n, endpoint, coord]`` holding N
            segments with two endpoints each; coordinate 0 is x and
            coordinate 1 is y (any further coordinates are ignored).
        img_size: Upper bound used when clamping box coordinates.

    Returns:
        Tensor of shape ``(N, 4)`` with rows ``[x_min, y_min, x_max, y_max]``.
        Each box is the segment's extent grown by one unit on every side and
        then clamped to ``[0, img_size]``.
    """
    start, end = lines[:, 0], lines[:, 1]

    # Pair up the two endpoints per axis: each tensor is (N, 2).
    xs = torch.stack((start[:, 0], end[:, 0]), dim=1)
    ys = torch.stack((start[:, 1], end[:, 1]), dim=1)

    # Tight extents, in the output column order.
    tight = (
        xs.min(dim=1).values,
        ys.min(dim=1).values,
        xs.max(dim=1).values,
        ys.max(dim=1).values,
    )
    # Lower corners move out by -1, upper corners by +1.
    margins = (-1, -1, 1, 1)

    padded = [
        (edge + margin).clamp(min=0, max=img_size)
        for edge, margin in zip(tight, margins)
    ]
    return torch.stack(padded, dim=1)
def box_iou_pairwise(box1, box2):
    """Compute IoU between boxes that share the same row index.

    Args:
        box1: Tensor of shape ``(N, 4)`` with rows ``[x1, y1, x2, y2]``.
        box2: Tensor of shape ``(M, 4)`` with rows ``[x1, y1, x2, y2]``.

    Returns:
        Tensor of shape ``(min(N, M),)``; entry ``i`` is the IoU of
        ``box1[i]`` and ``box2[i]``. Surplus rows of the longer input are
        ignored.
    """
    count = min(len(box1), len(box2))
    a = box1[:count]
    b = box2[:count]

    top_left = torch.max(a[:, :2], b[:, :2])
    bottom_right = torch.min(a[:, 2:], b[:, 2:])
    # Non-overlapping pairs yield a negative extent; clamp it to zero.
    overlap = (bottom_right - top_left).clamp(min=0)
    intersection = overlap[:, 0] * overlap[:, 1]

    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    union = area_a + area_b - intersection

    # The small epsilon keeps the division defined for degenerate boxes.
    return intersection / (union + 1e-6)
def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta=1.0, gamma=1.0):
    """Combined IoU / length / direction loss between predicted and GT lines.

    The heatmaps are split per image according to the number of boxes, decoded
    into lines with ``heatmaps_to_lines`` and compared with the matched
    ground-truth lines. Each of the three loss terms is passed through
    ``normalize_tensor`` before the weighted sum.

    Args:
        x: Heatmaps stacked over all boxes of all images (documented by the
            original author as ``[N, 1, H, W]``).
        boxes: Per-image box tensors; their first dimension drives the split
            of ``x``.
        gt_lines: Per-image ground-truth lines (documented by the original
            author as ``[N, 2, 3]``, including visibility).
        matched_idx: Per-image indices selecting the ground-truth line for
            each prediction.
        img_size: Image size used to clamp the line bounding boxes.
        alpha: Weight of the IoU term.
        beta: Weight of the length term.
        gamma: Weight of the direction (angle) term.

    Returns:
        The mean of all per-line losses, or ``None`` when no image
        contributed any line pair.
    """
    boxes_per_image = [box.size(0) for box in boxes]
    collected = []

    for heatmaps, image_boxes, image_gt, matched in zip(
        x.split(boxes_per_image, dim=0), boxes, gt_lines, matched_idx
    ):
        pred_lines, _ = heatmaps_to_lines(heatmaps, image_boxes)
        gt_matched = image_gt[matched]
        # Images without predictions or without matched GT contribute nothing.
        if len(pred_lines) == 0 or len(gt_matched) == 0:
            continue

        # IoU term, computed on the bounding boxes of the segments.
        iou_term = 1.0 - box_iou_pairwise(
            lines_to_boxes(pred_lines, img_size),
            lines_to_boxes(gt_matched, img_size),
        )

        # Length term: element-wise absolute difference.
        length_term = F.l1_loss(
            line_length(pred_lines), line_length(gt_matched), reduction='none'
        )

        # Direction term.
        angle_term = angle_loss_cosine(
            line_direction(pred_lines), line_direction(gt_matched)
        )

        # Normalize every term before weighting.
        weighted = (
            alpha * normalize_tensor(iou_term)
            + beta * normalize_tensor(length_term)
            + gamma * normalize_tensor(angle_term)
        )
        collected.append(weighted)

    if collected:
        return torch.mean(torch.cat(collected))
    return None
def point_inference(x, point_boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    """Decode point predictions from heatmaps, image by image.

    Args:
        x: Heatmaps stacked over all boxes of all images along dim 0.
        point_boxes: Per-image box tensors; their first dimension determines
            how ``x`` is split.

    Returns:
        Two lists with one entry per image: the first output of
        ``heatmaps_to_points`` (called with ``num_points=1``) with an extra
        axis inserted at dim 1, and its second output unchanged.
    """
    counts = [boxes.size(0) for boxes in point_boxes]
    per_image_maps = x.split(counts, dim=0)

    points_probs = []
    points_scores = []
    for heatmaps, boxes in zip(per_image_maps, point_boxes):
        decoded, scores = heatmaps_to_points(heatmaps, boxes, num_points=1)
        points_probs.append(decoded.unsqueeze(1))
        points_scores.append(scores)
    return points_probs, points_scores
def circle_inference(x, point_boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    """Decode circle predictions from heatmaps, image by image.

    Args:
        x: Heatmaps stacked over all boxes of all images along dim 0.
        point_boxes: Per-image box tensors; their first dimension determines
            how ``x`` is split.

    Returns:
        Two lists with one entry per image: the first output of
        ``heatmaps_to_circle_points`` (called with ``num_points=4``) with an
        extra axis inserted at dim 1, and its second output unchanged.
    """
    counts = [boxes.size(0) for boxes in point_boxes]
    per_image_maps = x.split(counts, dim=0)

    points_probs = []
    points_scores = []
    for heatmaps, boxes in zip(per_image_maps, point_boxes):
        decoded, scores = heatmaps_to_circle_points(heatmaps, boxes, num_points=4)
        points_probs.append(decoded.unsqueeze(1))
        points_scores.append(scores)
    return points_probs, points_scores
def line_inference(x, line_boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    """Decode line predictions from heatmaps, image by image.

    Args:
        x: Heatmaps stacked over all boxes of all images along dim 0.
        line_boxes: Per-image box tensors (the original notes describe them as
            detector boxes, e.g. two images with one ``[1, 4]`` box each);
            their first dimension determines how ``x`` is split.

    Returns:
        Two lists with one entry per image, holding the first and second
        outputs of ``heatmaps_to_lines`` respectively.
    """
    counts = [boxes.size(0) for boxes in line_boxes]
    # One chunk of heatmaps per image, e.g. a 2-tuple for a batch of two.
    per_image_maps = x.split(counts, dim=0)

    lines_probs = []
    lines_scores = []
    for heatmaps, boxes in zip(per_image_maps, line_boxes):
        decoded, scores = heatmaps_to_lines(heatmaps, boxes)
        lines_probs.append(decoded)
        lines_scores.append(scores)
    return lines_probs, lines_scores
def ins_inference(x, arc_boxes, th):
    # type: (Tensor, List[Tensor], float) -> Tuple[List[Tensor], List[Tensor], List[List[Tensor]]]
    """Decode arc masks and collect above-threshold pixel coordinates per box.

    Two things are produced:

    1. ``x`` is split per image according to the number of boxes and every
       chunk is decoded with ``heatmaps_to_arc``.
    2. For every box of the FIRST image and every heatmap in ``x``, the
       coordinates of all elements inside the box whose value exceeds ``th``
       are gathered.

    Args:
        x: 4-D heatmaps stacked over all boxes of all images along dim 0.
        arc_boxes: Per-image box tensors with rows ``(x1, y1, x2, y2)``.
        th: Threshold applied to the raw values of ``x``.

    Returns:
        points_probs: Per image, the mask output of ``heatmaps_to_arc`` with
            an extra axis inserted at dim 1.
        points_scores: Per image, the score output of ``heatmaps_to_arc``.
        results: ``results[b][p]`` holds the ``(x, y)`` coordinates (in full
            heatmap space) found in heatmap ``b`` inside box ``p`` of the
            first image, or an empty ``(0, 2)`` tensor if none were found.
    """
    print(f'arc_boxes:{len(arc_boxes)}')
    boxes_per_image = [box.size(0) for box in arc_boxes]
    print(f'arc boxes_per_image:{boxes_per_image}')

    points_probs = []
    points_scores = []
    # Distinct name for the split chunks: the original reused ``x2`` both for
    # this tuple and for a box coordinate further down.
    for maps_per_image, boxes in zip(x.split(boxes_per_image, dim=0), arc_boxes):
        point_prob, point_scores = heatmaps_to_arc(maps_per_image, boxes)
        points_probs.append(point_prob.unsqueeze(1))
        points_scores.append(point_scores)
    points_probs_tensor = torch.cat(points_probs)
    print(f'points_probs shape:{points_probs_tensor.shape}')

    batch_size = x.shape[0]
    map_height, map_width = x.shape[2], x.shape[3]
    # NOTE(review): only the boxes of the first image are used below, while
    # ``batch_size`` counts every heatmap in ``x`` -- confirm this is intended
    # for batches with more than one image.
    first_image_boxes = arc_boxes[0]
    num_proposals = len(first_image_boxes)
    # NOTE(review): the placeholders are default-dtype CPU tensors, whereas
    # found coordinates come from torch.nonzero (int64, on x's device).
    results = [[torch.empty(0, 2) for _ in range(num_proposals)] for _ in range(batch_size)]

    for proposal_idx, proposal in enumerate(first_image_boxes):
        left, top, right, bottom = map(int, proposal.tolist())
        # Keep the crop window inside the heatmap.
        left = max(0, left)
        top = max(0, top)
        right = min(map_width, right)
        bottom = min(map_height, bottom)

        for batch_idx in range(batch_size):
            region = x[batch_idx, :, top:bottom, left:right]
            hits = torch.nonzero(region > th)  # rows are (channel, y, x)
            if hits.numel() > 0:
                # Reorder to (x, y) and shift from crop to global coordinates.
                xy = hits[:, [2, 1]]
                xy[:, 0] += left
                xy[:, 1] += top
                results[batch_idx][proposal_idx] = xy

    print(f're:{results}')
    return points_probs, points_scores, results
- import torch.nn.functional as F
def heatmaps_to_arc(maps, rois, threshold=0.5, output_size=(128, 128)):
    """Turn per-ROI heatmaps into full-size masks by thresholding inside each box.

    For every ROI, channel 0 of the matching heatmap is cropped to the box,
    resized to ``output_size``, passed through a sigmoid, binarized at
    ``threshold`` and resized back to the crop size before being written into
    an all-zero canvas of the heatmap's size.

    Args:
        maps: 4-D tensor ``[N, C, H, W]``; only channel 0 is read.
        rois: Boxes indexed by ROI, each holding ``(x1, y1, x2, y2)``.
            Coordinates are truncated to integers and clamped to
            ``[0, W - 1]`` / ``[0, H - 1]``.
        threshold: Cut-off applied to the sigmoid of the resized crop.
        output_size: Intermediate ``(height, width)`` the crop is resized to
            before thresholding.

    Returns:
        masks: float32 tensor ``[N, 1, H, W]``, zero outside each ROI. Values
            inside the ROI come from bilinearly resizing the binary mask, so
            they lie in ``[0, 1]`` and are not necessarily exactly 0 or 1.
        scores: float32 tensor ``[N, 1]`` with the sum of the mask values
            written for each ROI (0 for skipped ROIs).
    """
    N, _, H, W = maps.shape
    masks = torch.zeros((N, 1, H, W), dtype=torch.float32, device=maps.device)
    scores = torch.zeros((N, 1), dtype=torch.float32, device=maps.device)
    point_maps = maps[:, 0]  # first channel only: [N, H, W]
    print(f"==> heatmaps_to_arc: maps.shape = {maps.shape}, rois.shape = {rois.shape}")

    for i in range(N):
        # Convert to Python ints once instead of carrying 0-dim tensors around.
        x1, y1, x2, y2 = rois[i].long().tolist()
        x1 = min(max(x1, 0), W - 1)
        x2 = min(max(x2, 0), W - 1)
        y1 = min(max(y1, 0), H - 1)
        y2 = min(max(y2, 0), H - 1)
        print(f"[{i}] roi: ({x1}, {y1}, {x2}, {y2})")

        # NOTE(review): x2 / y2 are clamped to W - 1 / H - 1 and then used as
        # exclusive slice ends, so the last row and column are never covered.
        if x2 <= x1 or y2 <= y1:
            print(f" Skipped invalid ROI at index {i}")
            continue

        # Non-empty by construction: the guard above ensures x2 > x1, y2 > y1.
        roi_map = point_maps[i, y1:y2, x1:x2]  # [h, w]
        print(f" roi_map.shape: {roi_map.shape}")

        # Resize to a uniform size before thresholding.
        roi_map_resized = F.interpolate(
            roi_map[None, None],
            size=output_size,
            mode='bilinear',
            align_corners=False
        )  # [1, 1, out_h, out_w]
        print(f" roi_map_resized.shape: {roi_map_resized.shape}")

        bin_mask = (torch.sigmoid(roi_map_resized) >= threshold).float()
        print(f" bin_mask.sum(): {bin_mask.sum().item()}")

        # Resize back to the crop size and paste into the full-size canvas.
        bin_mask_original_size = F.interpolate(
            bin_mask,
            size=(y2 - y1, x2 - x1),
            mode='bilinear',
            align_corners=False
        )[0, 0]  # [h, w]
        masks[i, 0, y1:y2, x1:x2] = bin_mask_original_size
        scores[i] = bin_mask_original_size.sum()
        print(f" bin_mask_original_size.shape: {bin_mask_original_size.shape}, sum: {scores[i].item()}")

    print(f"==> Done. Total valid masks: {(scores > 0).sum().item()} / {N}")
    return masks, scores
|