head_losses.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280
  1. import torch
  2. from matplotlib import pyplot as plt
  3. import torch.nn.functional as F
  4. from torch import nn
  5. from torch.cuda import device
  6. class DiceLoss(nn.Module):
  7. def __init__(self, smooth=1.):
  8. super(DiceLoss, self).__init__()
  9. self.smooth = smooth
  10. def forward(self, logits, targets):
  11. probs = torch.sigmoid(logits)
  12. probs = probs.view(-1)
  13. targets = targets.view(-1).float()
  14. intersection = (probs * targets).sum()
  15. dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
  16. return 1. - dice
  17. bce_loss = nn.BCEWithLogitsLoss()
  18. dice_loss = DiceLoss()
  19. def combined_loss(preds, targets, alpha=0.5):
  20. bce = bce_loss(preds, targets)
  21. d = dice_loss(preds, targets)
  22. return alpha * bce + (1 - alpha) * d
  23. def features_align(features, proposals, img_size):
  24. print(f'features_align features:{features.shape},proposals:{len(proposals)}')
  25. align_feat_list = []
  26. for feat, proposals_per_img in zip(features, proposals):
  27. print(f'feature_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
  28. if proposals_per_img.shape[0]>0:
  29. feat = feat.unsqueeze(0)
  30. for proposal in proposals_per_img:
  31. align_feat = torch.zeros_like(feat)
  32. # print(f'align_feat:{align_feat.shape}')
  33. x1, y1, x2, y2 = map(lambda v: int(v.item()), proposal)
  34. # 将每个proposal框内的部分赋值到align_feats对应位置
  35. align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]
  36. align_feat_list.append(align_feat)
  37. # print(f'align_feat_list:{align_feat_list}')
  38. if len(align_feat_list) > 0:
  39. feats_tensor = torch.cat(align_feat_list)
  40. print(f'align features :{feats_tensor.shape}')
  41. else:
  42. feats_tensor = None
  43. return feats_tensor
  44. def normalize_tensor(t):
  45. return (t - t.min()) / (t.max() - t.min() + 1e-6)
  46. def line_length(lines):
  47. """
  48. 计算每条线段的长度
  49. lines: [N, 2, 2] 表示 N 条线段,每条线段由两个点组成
  50. 返回: [N]
  51. """
  52. return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
  53. def line_direction(lines):
  54. """
  55. 计算每条线段的单位方向向量
  56. lines: [N, 2, 2]
  57. 返回: [N, 2] 单位方向向量
  58. """
  59. vec = lines[:, 1] - lines[:, 0]
  60. return F.normalize(vec, dim=-1)
  61. def angle_loss_cosine(pred_dir, gt_dir):
  62. """
  63. 使用 cosine similarity 计算方向差异
  64. pred_dir: [N, 2]
  65. gt_dir: [N, 2]
  66. 返回: [N]
  67. """
  68. cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
  69. return 1.0 - cos_sim # 或者 torch.acos(cos_sim) / pi 也可
  70. def line_length(lines):
  71. """
  72. 计算每条线段的长度
  73. lines: [N, 2, 2] 表示 N 条线段,每条线段由两个点组成
  74. 返回: [N]
  75. """
  76. return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
  77. def line_direction(lines):
  78. """
  79. 计算每条线段的单位方向向量
  80. lines: [N, 2, 2]
  81. 返回: [N, 2] 单位方向向量
  82. """
  83. vec = lines[:, 1] - lines[:, 0]
  84. return F.normalize(vec, dim=-1)
  85. def angle_loss_cosine(pred_dir, gt_dir):
  86. """
  87. 使用 cosine similarity 计算方向差异
  88. pred_dir: [N, 2]
  89. gt_dir: [N, 2]
  90. 返回: [N]
  91. """
  92. cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
  93. return 1.0 - cos_sim # 或者 torch.acos(cos_sim) / pi 也可
  94. def single_point_to_heatmap(keypoints, rois, heatmap_size):
  95. # type: (Tensor, Tensor, int) -> Tensor
  96. print(f'rois:{rois.shape}')
  97. print(f'heatmap_size:{heatmap_size}')
  98. print(f'keypoints.shape:{keypoints.shape}')
  99. # batch_size, num_keypoints, _ = keypoints.shape
  100. x = keypoints[..., 0].unsqueeze(1)
  101. y = keypoints[..., 1].unsqueeze(1)
  102. gs = generate_gaussian_heatmaps(x, y,num_points=1, heatmap_size=heatmap_size, sigma=1.0)
  103. # show_heatmap(gs[0],'target')
  104. all_roi_heatmap = []
  105. for roi, heatmap in zip(rois, gs):
  106. # show_heatmap(heatmap, 'target')
  107. # print(f'heatmap:{heatmap.shape}')
  108. heatmap = heatmap.unsqueeze(0)
  109. x1, y1, x2, y2 = map(int, roi)
  110. roi_heatmap = torch.zeros_like(heatmap)
  111. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  112. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  113. all_roi_heatmap.append(roi_heatmap)
  114. all_roi_heatmap = torch.cat(all_roi_heatmap)
  115. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  116. return all_roi_heatmap
  117. def points_to_heatmap(keypoints, rois,num_points=2, heatmap_size=(512,512)):
  118. # type: (Tensor, Tensor, int) -> Tensor
  119. print(f'rois:{rois.shape}')
  120. print(f'heatmap_size:{heatmap_size}')
  121. print(f'keypoints.shape:{keypoints.shape}')
  122. # batch_size, num_keypoints, _ = keypoints.shape
  123. x = keypoints[..., 0].unsqueeze(1)
  124. y = keypoints[..., 1].unsqueeze(1)
  125. gs = generate_gaussian_heatmaps(x, y,num_points=num_points, heatmap_size=heatmap_size, sigma=2.0)
  126. # show_heatmap(gs[0],'target')
  127. all_roi_heatmap = []
  128. for roi, heatmap in zip(rois, gs):
  129. # show_heatmap(heatmap, 'target')
  130. # print(f'heatmap:{heatmap.shape}')
  131. heatmap = heatmap.unsqueeze(0)
  132. x1, y1, x2, y2 = map(int, roi)
  133. roi_heatmap = torch.zeros_like(heatmap)
  134. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  135. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  136. all_roi_heatmap.append(roi_heatmap)
  137. all_roi_heatmap = torch.cat(all_roi_heatmap)
  138. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  139. return all_roi_heatmap
  140. def line_points_to_heatmap(keypoints, rois, heatmap_size):
  141. # type: (Tensor, Tensor, int) -> Tensor
  142. print(f'rois:{rois.shape}')
  143. print(f'heatmap_size:{heatmap_size}')
  144. print(f'keypoints.shape:{keypoints.shape}')
  145. # batch_size, num_keypoints, _ = keypoints.shape
  146. x = keypoints[..., 0]
  147. y = keypoints[..., 1]
  148. gs = generate_gaussian_heatmaps(x, y, heatmap_size,num_points=2, sigma=1.0)
  149. # show_heatmap(gs[0],'target')
  150. all_roi_heatmap = []
  151. for roi, heatmap in zip(rois, gs):
  152. # print(f'heatmap:{heatmap.shape}')
  153. # show_heatmap(heatmap,'target')
  154. heatmap = heatmap.unsqueeze(0)
  155. x1, y1, x2, y2 = map(int, roi)
  156. roi_heatmap = torch.zeros_like(heatmap)
  157. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  158. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  159. all_roi_heatmap.append(roi_heatmap)
  160. if len(all_roi_heatmap) > 0:
  161. all_roi_heatmap = torch.cat(all_roi_heatmap)
  162. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  163. else:
  164. all_roi_heatmap = None
  165. return all_roi_heatmap
  166. """
  167. 修改适配的原结构的点 转热图,适用于带roi_pool版本的
  168. """
  169. def line_points_to_heatmap_(keypoints, rois, heatmap_size):
  170. # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
  171. print(f'rois:{rois.shape}')
  172. print(f'heatmap_size:{heatmap_size}')
  173. offset_x = rois[:, 0]
  174. offset_y = rois[:, 1]
  175. scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
  176. scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
  177. offset_x = offset_x[:, None]
  178. offset_y = offset_y[:, None]
  179. scale_x = scale_x[:, None]
  180. scale_y = scale_y[:, None]
  181. print(f'keypoints.shape:{keypoints.shape}')
  182. # batch_size, num_keypoints, _ = keypoints.shape
  183. x = keypoints[..., 0]
  184. y = keypoints[..., 1]
  185. # gs=generate_gaussian_heatmaps(x,y,512,1.0)
  186. # print(f'gs_heatmap shape:{gs.shape}')
  187. #
  188. # show_heatmap(gs[0],'target')
  189. x_boundary_inds = x == rois[:, 2][:, None]
  190. y_boundary_inds = y == rois[:, 3][:, None]
  191. x = (x - offset_x) * scale_x
  192. x = x.floor().long()
  193. y = (y - offset_y) * scale_y
  194. y = y.floor().long()
  195. x[x_boundary_inds] = heatmap_size - 1
  196. y[y_boundary_inds] = heatmap_size - 1
  197. # print(f'heatmaps x:{x}')
  198. # print(f'heatmaps y:{y}')
  199. valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
  200. vis = keypoints[..., 2] > 0
  201. valid = (valid_loc & vis).long()
  202. gs_heatmap = generate_gaussian_heatmaps(x, y, heatmap_size, 1.0)
  203. # show_heatmap(gs_heatmap[0], 'feature')
  204. # print(f'gs_heatmap:{gs_heatmap.shape}')
  205. #
  206. # lin_ind = y * heatmap_size + x
  207. # print(f'lin_ind:{lin_ind.shape}')
  208. # heatmaps = lin_ind * valid
  209. return gs_heatmap
  210. def generate_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
  211. """
  212. 为一组点生成并合并高斯热图。
  213. Args:
  214. xs (Tensor): 形状为 (N, 2) 的所有点的 x 坐标
  215. ys (Tensor): 形状为 (N, 2) 的所有点的 y 坐标
  216. heatmap_size (int): 热图大小 H=W
  217. sigma (float): 高斯核标准差
  218. device (str): 设备类型 ('cpu' or 'cuda')
  219. Returns:
  220. Tensor: 形状为 (H, W) 的合并后的热图
  221. """
  222. assert xs.shape == ys.shape, "x and y must have the same shape"
  223. print(f'xs:{xs.shape}')
  224. xs=xs.squeeze(1)
  225. ys = ys.squeeze(1)
  226. print(f'xs1:{xs.shape}')
  227. N = xs.shape[0]
  228. print(f'N:{N},num_points:{num_points}')
  229. # 创建网格
  230. grid_y, grid_x = torch.meshgrid(
  231. torch.arange(heatmap_size, device=device),
  232. torch.arange(heatmap_size, device=device),
  233. indexing='ij'
  234. )
  235. # print(f'heatmap_size:{heatmap_size}')
  236. # 初始化输出热图
  237. combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
  238. for i in range(N):
  239. heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
  240. for j in range(num_points):
  241. mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
  242. mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
  243. # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
  244. # 计算距离平方
  245. dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
  246. # 计算高斯分布
  247. heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
  248. heatmap+=heatmap1
  249. # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
  250. # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
  251. #
  252. # # 计算距离平方
  253. # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
  254. #
  255. # # 计算高斯分布
  256. # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
  257. #
  258. # heatmap = heatmap1 + heatmap2
  259. # 将当前热图累加到结果中
  260. combined_heatmap[i] = heatmap
  261. return combined_heatmap
  262. def generate_mask_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
  263. """
  264. 为一组点生成并合并高斯热图。
  265. Args:
  266. xs (Tensor): 形状为 (N, 2) 的所有点的 x 坐标
  267. ys (Tensor): 形状为 (N, 2) 的所有点的 y 坐标
  268. heatmap_size (int): 热图大小 H=W
  269. sigma (float): 高斯核标准差
  270. device (str): 设备类型 ('cpu' or 'cuda')
  271. Returns:
  272. Tensor: 形状为 (H, W) 的合并后的热图
  273. """
  274. assert xs.shape == ys.shape, "x and y must have the same shape"
  275. print(f'xs:{xs.shape}')
  276. xs=xs.squeeze(1)
  277. ys = ys.squeeze(1)
  278. print(f'xs1:{xs.shape}')
  279. N = xs.shape[0]
  280. print(f'N:{N},num_points:{num_points}')
  281. # 创建网格
  282. grid_y, grid_x = torch.meshgrid(
  283. torch.arange(heatmap_size, device=device),
  284. torch.arange(heatmap_size, device=device),
  285. indexing='ij'
  286. )
  287. # print(f'heatmap_size:{heatmap_size}')
  288. # 初始化输出热图
  289. combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
  290. for i in range(N):
  291. heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
  292. for j in range(num_points):
  293. mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
  294. mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
  295. # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
  296. # 计算距离平方
  297. dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
  298. # 计算高斯分布
  299. heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
  300. heatmap+=heatmap1
  301. # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
  302. # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
  303. #
  304. # # 计算距离平方
  305. # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
  306. #
  307. # # 计算高斯分布
  308. # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
  309. #
  310. # heatmap = heatmap1 + heatmap2
  311. # 将当前热图累加到结果中
  312. combined_heatmap[i] = heatmap
  313. return combined_heatmap
  314. def non_maximum_suppression(a):
  315. ap = F.max_pool2d(a, 3, stride=1, padding=1)
  316. mask = (a == ap).float().clamp(min=0.0)
  317. return a * mask
  318. def heatmaps_to_points(maps, rois,num_points=2):
  319. point_preds = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
  320. point_end_scores = torch.zeros((len(rois), 1), dtype=torch.float32, device=maps.device)
  321. print(f'heatmaps_to_lines:{maps.shape}')
  322. point_maps=maps[:,0]
  323. print(f'point_map:{point_maps.shape}')
  324. for i in range(len(rois)):
  325. point_roi_map = point_maps[i].unsqueeze(0)
  326. print(f'point_roi_map:{point_roi_map.shape}')
  327. # roi_map_probs = scores_to_probs(roi_map.copy())
  328. w = point_roi_map.shape[2]
  329. flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
  330. point_score, point_index = torch.topk(flatten_point_roi_map, k=num_points)
  331. print(f'point index:{point_index}')
  332. # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  333. point_x =point_index % w
  334. point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
  335. point_preds[i, 0,] = point_x
  336. point_preds[i, 1,] = point_y
  337. point_end_scores[i, :] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
  338. return point_preds,point_end_scores
  339. # 分4块
  340. def find_max_heat_point_in_each_part(feature_map, box):
  341. """
  342. 在给定的特征图上,根据box中心点往上移3,往右移3作为新的中心点,
  343. 并将特征图划分为4个部分,之后在每个部分中找到热度值最大的点。
  344. Args:
  345. feature_map (torch.Tensor): 形状为 [C, H, W] 的特征图
  346. box (torch.Tensor): 形状为 [4] 的边界框 [x_min, y_min, x_max, y_max]
  347. Returns:
  348. list: 每个区域中热度最高的点的位置和其对应的热度值 [(y1, x1, heat1), ..., (y4, x4, heat4)]
  349. """
  350. device = feature_map.device
  351. C, H, W = feature_map.shape
  352. # 计算box的中心点(cx, cy)
  353. cx = (box[0] + box[2]) // 2
  354. cy = (box[1] + box[3]) // 2
  355. # 偏移中心点
  356. new_cx = min(max(cx + 3, 0), W - 1) # 向右移3
  357. new_cy = min(max(cy - 3, 0), H - 1) # 向上移3
  358. # 创建坐标网格
  359. y_coords, x_coords = torch.meshgrid(
  360. torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij'
  361. )
  362. # 划分四个区域
  363. mask_q1 = (y_coords < new_cy) & (x_coords < new_cx) # 左上
  364. mask_q2 = (y_coords < new_cy) & (x_coords >= new_cx) # 右上
  365. mask_q3 = (y_coords >= new_cy) & (x_coords < new_cx) # 左下
  366. mask_q4 = (y_coords >= new_cy) & (x_coords >= new_cx) # 右下
  367. # def process_region(mask):
  368. # region = feature_map[:, :, mask].squeeze()
  369. # if len(region.shape) == 0: # 如果区域为空,则跳过
  370. # return None, None
  371. # # 找到最大热度值的点及其位置
  372. # (y, x), heat_val = non_maximum_suppression(region[0])
  373. # # 将相对坐标转换回全局坐标
  374. # y_global = y + torch.where(mask)[0].min().item()
  375. # x_global = x + torch.where(mask)[1].min().item()
  376. # return (y_global, x_global), heat_val
  377. #
  378. # results = []
  379. # for mask in [mask_q1, mask_q2, mask_q3, mask_q4]:
  380. # point, heat_val = process_region(mask)
  381. # if point is not None:
  382. # # results.append((point[0], point[1], heat_val))
  383. # results.append((point[0], point[1]))
  384. # else:
  385. # results.append(None)
  386. masks = [mask_q1, mask_q2, mask_q3, mask_q4]
  387. results = []
  388. # 假设使用第一个通道作为热力图
  389. heatmap = feature_map[0] # [H, W]
  390. def process_region(mask):
  391. # 应用 mask,只保留该区域
  392. masked_heatmap = heatmap.clone() # 复制以避免修改原数据
  393. masked_heatmap[~mask] = 0 # 非区域置0
  394. def non_maximum_suppression_2d(heatmap, kernel_size=3):
  395. """
  396. 对 2D 热力图做非极大值抑制,保留局部最大值点。
  397. Args:
  398. heatmap (torch.Tensor): [H, W],输入热力图
  399. kernel_size (int): 池化窗口大小,用于比较是否为局部最大值
  400. Returns:
  401. torch.Tensor: 与 heatmap 同形状的 mask,局部最大值位置为 True
  402. """
  403. pad = (kernel_size - 1) // 2
  404. max_pool = torch.nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)
  405. maxima = max_pool(heatmap.unsqueeze(0)).squeeze(0)
  406. # 局部最大值且值大于0
  407. peaks = (heatmap == maxima) & (heatmap > 0)
  408. return peaks
  409. # 1. 先做 NMS 得到候选局部极大值点
  410. nms_mask = non_maximum_suppression_2d(masked_heatmap, kernel_size=3) # [H, W] bool
  411. candidate_peaks = masked_heatmap * nms_mask.float() # 只保留 NMS 后的峰值
  412. # 2. 找出所有候选点中值最大的一个
  413. if candidate_peaks.max() <= 0:
  414. return None
  415. # 找到最大值的位置
  416. max_val, max_idx = torch.max(candidate_peaks.view(-1), dim=0)
  417. y, x = divmod(max_idx.item(), W)
  418. return (x, y) # 返回 (y, x)
  419. for mask in masks:
  420. point = process_region(mask)
  421. results.append(point)
  422. return results
  423. def non_maximum_suppression_2d(heatmap, kernel_size=3):
  424. pad = (kernel_size - 1) // 2
  425. max_pool = torch.nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)
  426. maxima = max_pool(heatmap.unsqueeze(0)).squeeze(0)
  427. peaks = (heatmap == maxima) & (heatmap > 0)
  428. return peaks
  429. def find_max_heat_point_in_edge_centers(feature_map, box):
  430. device = feature_map.device
  431. C, H, W = feature_map.shape
  432. # ¼ÆËã box ÖÐÐÄ
  433. cx = (box[0] + box[2]) / 2
  434. cy = (box[1] + box[3]) / 2
  435. # ¸ù¾Ý box ¿í¸ß¼ÆËã¾Å¹¬¸ñ·Ö½çÏß
  436. box_width = box[2] - box[0]
  437. box_height = box[3] - box[1]
  438. x_left = cx - box_width / 6
  439. x_right = cx + box_width / 6
  440. y_top = cy - box_height / 6
  441. y_bottom = cy + box_height / 6
  442. # ´´½¨Íø¸ñ
  443. y_coords, x_coords = torch.meshgrid(
  444. torch.arange(H, device=device),
  445. torch.arange(W, device=device),
  446. indexing='ij'
  447. )
  448. # ¶¨ÒåËĸö¡°±ßÖС±ÇøÓòµÄ mask
  449. mask1 = (x_coords < x_left) & (y_coords < y_top)
  450. mask_top_middle = (x_coords >= x_left) & (x_coords < x_right) & (y_coords < y_top)
  451. mask3 = (x_coords >= x_right) & (y_coords < y_top)
  452. mask_left_middle = (x_coords < x_left) & (y_coords >= y_top) & (y_coords < y_bottom)
  453. mask_right_middle = (x_coords >= x_right) & (y_coords >= y_top) & (y_coords < y_bottom)
  454. mask4 = (x_coords < x_left) & (y_coords >= y_bottom)
  455. mask_bottom_middle = (x_coords >= x_left) & (x_coords < x_right) & (y_coords >= y_bottom)
  456. mask_right_bottom = (x_coords >= x_right) & (y_coords >= y_bottom)
  457. # masks = [
  458. # # mask1,
  459. # mask_top_middle,
  460. # # mask3,
  461. # mask_left_middle,
  462. # mask_right_middle,
  463. # # mask4,
  464. # mask_bottom_middle,
  465. # mask_right_bottom
  466. # ]
  467. masks = [
  468. mask_top_middle,
  469. mask_right_middle,
  470. mask_bottom_middle,
  471. mask_left_middle
  472. ]
  473. # ʹÓõÚÒ»¸öͨµÀ×÷ΪÈÈÁ¦Í¼
  474. heatmap = feature_map[0] # [H, W]
  475. results = []
  476. for mask in masks:
  477. masked_heatmap = heatmap.clone()
  478. masked_heatmap[~mask] = 0 # ·ÇÄ¿±êÇøÓòÖà 0
  479. # # NMS ÒÖÖÆ
  480. # nms_mask = non_maximum_suppression_2d(masked_heatmap, kernel_size=3)
  481. # candidate_peaks = masked_heatmap * nms_mask.float()
  482. #
  483. # if candidate_peaks.max() <= 0:
  484. # results.append(None)
  485. # continue
  486. #
  487. # # ÕÒ×î´óֵλÖÃ
  488. # max_val, max_idx = torch.max(candidate_peaks.view(-1), dim=0)
  489. # y, x = divmod(max_idx.item(), W)
  490. flatten_point_roi_map = masked_heatmap.reshape(1, -1)
  491. point_score, point_index = torch.topk(flatten_point_roi_map, k=1)
  492. point_x =point_index % W
  493. point_y = torch.div(point_index - point_x, W, rounding_mode="floor")
  494. results.append((point_x, point_y))
  495. return results # [(y_top, x_top), (y_right, x_right), (y_bottom, x_bottom), (y_left, x_left)]
  496. def heatmaps_to_circle_points(maps, rois,num_points=2):
  497. point_preds = torch.zeros((len(rois), 4, 2), dtype=torch.float32, device=maps.device)
  498. point_end_scores = torch.zeros((len(rois),4, 1), dtype=torch.float32, device=maps.device)
  499. print(f'rois in heatmaps_to_circle_points:{type(rois), rois.shape}') # <class 'torch.Tensor'>
  500. print(f'heatmaps_to_lines:{maps.shape}')
  501. point_maps=maps[:,0]
  502. print(f'point_map:{point_maps.shape}')
  503. for i in range(len(rois)):
  504. point_roi_map = point_maps[i].unsqueeze(0)
  505. print(f'point_roi_map:{point_roi_map.shape}')
  506. # roi_map_probs = scores_to_probs(roi_map.copy())
  507. # w = point_roi_map.shape[2]
  508. # flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
  509. # print(f'non_maximum_suppression :{non_maximum_suppression(point_roi_map).shape}')
  510. # point_score, point_index = torch.topk(flatten_point_roi_map, k=num_points)
  511. # print(f'point index:{point_index}')
  512. # point_x =point_index % w
  513. # point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
  514. # print(f'point_x:{point_x}, point_y:{point_y}')
  515. # point_preds[i, :, 0] = point_x
  516. # point_preds[i, :, 1] = point_y
  517. roi1=rois[i]
  518. result_points = find_max_heat_point_in_edge_centers(non_maximum_suppression(point_roi_map), roi1)
  519. point_preds[i, :]=torch.tensor(result_points)
  520. point_x = [point[0] for point in result_points]
  521. point_y = [point[1] for point in result_points]
  522. point_end_scores[i, :,0] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
  523. return point_preds,point_end_scores
  524. def heatmaps_to_lines(maps, rois):
  525. line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
  526. line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
  527. line_maps=maps[:,1]
  528. # line_maps = maps.squeeze(1)
  529. for i in range(len(rois)):
  530. line_roi_map = line_maps[i].unsqueeze(0)
  531. print(f'line_roi_map:{line_roi_map.shape}')
  532. # roi_map_probs = scores_to_probs(roi_map.copy())
  533. w = line_roi_map.shape[1]
  534. flatten_line_roi_map = non_maximum_suppression(line_roi_map).reshape(1, -1)
  535. line_score, line_index = torch.topk(flatten_line_roi_map, k=2)
  536. print(f'line index:{line_index}')
  537. # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  538. pos = line_index
  539. line_x = pos % w
  540. line_y = torch.div(pos - line_x, w, rounding_mode="floor")
  541. line_preds[i, 0, :] = line_x
  542. line_preds[i, 1, :] = line_y
  543. line_preds[i, 2, :] = 1
  544. line_end_scores[i, :] = line_roi_map[torch.arange(1, device=line_roi_map.device), line_y, line_x]
  545. return line_preds.permute(0, 2, 1), line_end_scores
  546. # 显示热图的函数
  547. def show_heatmap(heatmap, title="Heatmap"):
  548. """
  549. 使用 matplotlib 显示热图。
  550. Args:
  551. heatmap (Tensor): 要显示的热图张量
  552. title (str): 图表标题
  553. """
  554. # 如果在 GPU 上,首先将其移动到 CPU 并转换为 numpy 数组
  555. if heatmap.is_cuda:
  556. heatmap = heatmap.cpu().numpy()
  557. else:
  558. heatmap = heatmap.numpy()
  559. plt.imshow(heatmap, cmap='hot', interpolation='nearest')
  560. plt.colorbar()
  561. plt.title(title)
  562. plt.show()
  563. def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
  564. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  565. N, K, H, W = line_logits.shape
  566. len_proposals = len(proposals)
  567. print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals},line_matched_idxs:{line_matched_idxs}')
  568. if H != W:
  569. raise ValueError(
  570. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  571. )
  572. discretization_size = H
  573. heatmaps = []
  574. gs_heatmaps = []
  575. valid = []
  576. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
  577. print(f'line_proposals_per_image:{proposals_per_image.shape}')
  578. print(f'gt_lines:{gt_lines}')
  579. if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
  580. kp = gt_kp_in_image[midx]
  581. gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
  582. gs_heatmaps.append(gs_heatmaps_per_img)
  583. # print(f'heatmaps_per_image:{heatmaps_per_image.shape}')
  584. # heatmaps.append(heatmaps_per_image.view(-1))
  585. # valid.append(valid_per_image.view(-1))
  586. # line_targets = torch.cat(heatmaps, dim=0)
  587. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  588. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
  589. # print(f'line_targets:{line_targets.shape},{line_targets}')
  590. # valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
  591. # valid = torch.where(valid)[0]
  592. # print(f' line_targets[valid]:{line_targets[valid]}')
  593. # torch.mean (in binary_cross_entropy_with_logits) doesn't
  594. # accept empty tensors, so handle it sepaartely
  595. # if line_targets.numel() == 0 or len(valid) == 0:
  596. # return line_logits.sum() * 0
  597. # line_logits = line_logits.view(N * K, H * W)
  598. # print(f'line_logits[valid]:{line_logits[valid].shape}')
  599. print(f'loss1 line_logits:{line_logits.shape}')
  600. line_logits = line_logits[:,1,:,:]
  601. # line_logits = line_logits.squeeze(1)
  602. print(f'loss2 line_logits:{line_logits.shape}')
  603. # line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
  604. line_loss = F.cross_entropy(line_logits, gs_heatmaps)
  605. return line_loss
  606. def compute_mask_loss(feature_logits, proposals, gt_, pos_matched_idxs):
  607. print(f'compute_arc_loss:{feature_logits.shape}')
  608. N, K, H, W = feature_logits.shape
  609. len_proposals = len(proposals)
  610. empty_count = 0
  611. non_empty_count = 0
  612. for prop in proposals:
  613. if prop.shape[0] == 0:
  614. empty_count += 1
  615. else:
  616. non_empty_count += 1
  617. print(f"Empty proposals count: {empty_count}")
  618. print(f"Non-empty proposals count: {non_empty_count}")
  619. print(f'starte to compute_point_loss')
  620. print(f'compute_point_loss line_logits.shape:{feature_logits.shape},len_proposals:{len_proposals}')
  621. if H != W:
  622. raise ValueError(
  623. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  624. )
  625. discretization_size = H
  626. gs_heatmaps = []
  627. # print(f'point_matched_idxs:{point_matched_idxs}')
  628. print(f'gt_masks:{gt_[0].shape}')
  629. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
  630. # [
  631. # (Tensor(38, 4), Tensor(1, 57, 2), Tensor(38, 1)),
  632. # (Tensor(65, 4), Tensor(1, 74, 2), Tensor(65, 1))
  633. # ]
  634. print(f'proposals_per_image:{proposals_per_image.shape}')
  635. kp = gt_kp_in_image[midx]
  636. t_h, t_w = kp.shape[-2:]
  637. print(f't_h:{t_h}, t_w:{t_w}')
  638. print(f'gt_kp_in_image:{gt_kp_in_image.shape}')
  639. if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
  640. gs_heatmaps_per_img = align_masks(kp, proposals_per_image, discretization_size)
  641. gs_heatmaps.append(gs_heatmaps_per_img)
  642. if len(gs_heatmaps)>0:
  643. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  644. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.shape}')
  645. line_logits = feature_logits.squeeze(1)
  646. print(f'mask shape:{line_logits.shape}')
  647. # line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
  648. # line_loss = F.cross_entropy(line_logits, gs_heatmaps)
  649. line_loss=combined_loss(line_logits, gs_heatmaps)
  650. else:
  651. line_loss=100
  652. print("d")
  653. return line_loss
  def align_masks(keypoints, rois, heatmap_size):
      """Resize whole ground-truth masks onto a square ``heatmap_size`` grid.

      Args:
          keypoints: tensor whose last two dims are (t_h, t_w). It is
              unsqueezed at dim 1 and bilinearly resized, so a 3-D
              (B, t_h, t_w) mask stack is presumably expected -- TODO confirm.
          rois: only its shape is printed. NOTE(review): the boxes are NOT used
              to crop, so each mask is resized as a whole rather than aligned
              to its ROI.
          heatmap_size: side length of the square output.

      Returns:
          Float tensor (B, heatmap_size, heatmap_size).
      """
      print(f'rois:{rois.shape}')
      print(f'heatmap_size:{heatmap_size}')
      print(f'keypoints.shape:{keypoints.shape}')
      _, src_width = keypoints.shape[-2:]
      scale = heatmap_size / src_width
      print(f'scale:{scale}')
      # The scaled "coordinates" below only feed the debug count printed next.
      scaled_x = (keypoints[..., 0] * scale).unsqueeze(1)
      _scaled_y = (keypoints[..., 1] * scale).unsqueeze(1)
      num_points = scaled_x.shape[2]
      print(f'num_points:{num_points}')
      as_image = keypoints.unsqueeze(1).float()  # add a channel dim for interpolate
      resized = F.interpolate(
          as_image,
          size=(heatmap_size, heatmap_size),
          mode='bilinear',
          align_corners=False,
      )
      resized_mask = resized.squeeze(1)  # drop the channel dim again
      print(f'resized_mask:{resized_mask.shape}')
      return resized_mask
  def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
      # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
      """Cross-entropy loss between point-head logits and target heatmaps.

      Args:
          line_logits: 4-D logits (N, K, H, W) with H == W; only channel 0 is
              used.
          proposals: per-image proposal boxes.
          gt_points: per-image ground-truth points, indexed by the matches.
          point_matched_idxs: per-image indices selecting the GT point matched
              to each proposal.

      Returns:
          Scalar loss. When there are no images or the targets are empty, the
          result is ``line_logits.sum() * 0``. That is a zero that stays attached
          to the autograd graph, instead of a ``torch.cat([])`` crash or a NaN
          mean over an empty batch.

      Raises:
          ValueError: if H != W.
      """
      _, _, H, W = line_logits.shape
      len_proposals = len(proposals)
      empty_count = sum(1 for prop in proposals if prop.shape[0] == 0)
      non_empty_count = len_proposals - empty_count
      print(f"Empty proposals count: {empty_count}")
      print(f"Non-empty proposals count: {non_empty_count}")
      print(f'starte to compute_point_loss')
      print(f'compute_point_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
      if H != W:
          raise ValueError(
              f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
          )
      discretization_size = H
      gs_heatmaps = []
      for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_points, point_matched_idxs):
          print(f'proposals_per_image:{proposals_per_image.shape}')
          kp = gt_kp_in_image[midx]
          gs_heatmaps.append(single_point_to_heatmap(kp, proposals_per_image, discretization_size))
      # torch.cat([]) raises. An empty target would make the mean cross-entropy
      # NaN. In both cases, return a graph-connected zero instead.
      if not gs_heatmaps:
          return line_logits.sum() * 0
      gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
      if gs_heatmaps.numel() == 0:
          return line_logits.sum() * 0
      print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
      line_logits = line_logits[:, 0]
      print(f'single_point_logits:{line_logits.shape}')
      # NOTE(review): with 3-D [N, H, W] inputs, F.cross_entropy treats dim 1
      # (H) as the class axis -- confirm this is the intended formulation.
      return F.cross_entropy(line_logits, gs_heatmaps)
  def compute_circle_loss(circle_logits, proposals, gt_circles, circle_matched_idxs):
      # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
      """Cross-entropy loss between circle-head logits and 4-point target heatmaps.

      Args:
          circle_logits: 4-D logits (N, K, H, W) with H == W; only channel 0 is
              used.
          proposals: per-image proposal boxes.
          gt_circles: per-image ground-truth circle points, indexed by the
              matches.
          circle_matched_idxs: per-image indices selecting the GT entry matched
              to each proposal.

      Returns:
          Scalar loss. When there are no images or the targets are empty, the
          result is ``circle_logits.sum() * 0``. That is a zero that stays
          attached to the autograd graph, instead of a ``torch.cat([])`` crash
          or a NaN mean over an empty batch.

      Raises:
          ValueError: if H != W.
      """
      _, _, H, W = circle_logits.shape
      len_proposals = len(proposals)
      empty_count = sum(1 for prop in proposals if prop.shape[0] == 0)
      non_empty_count = len_proposals - empty_count
      print(f"Empty proposals count: {empty_count}")
      print(f"Non-empty proposals count: {non_empty_count}")
      print(f'starte to compute_circle_loss')
      print(f'compute_circle_loss circle_logits.shape:{circle_logits.shape},len_proposals:{len_proposals}')
      if H != W:
          # The message previously named `line_logits`; this function's input is circle_logits.
          raise ValueError(
              f"circle_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
          )
      discretization_size = H
      gs_heatmaps = []
      for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_circles, circle_matched_idxs):
          print(f'proposals_per_image:{proposals_per_image.shape}')
          kp = gt_kp_in_image[midx]
          gs_heatmaps.append(
              points_to_heatmap(kp, proposals_per_image, num_points=4, heatmap_size=discretization_size)
          )
      # torch.cat([]) raises. An empty target would make the mean cross-entropy
      # NaN. In both cases, return a graph-connected zero instead.
      if not gs_heatmaps:
          return circle_logits.sum() * 0
      gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
      if gs_heatmaps.numel() == 0:
          return circle_logits.sum() * 0
      print(f'gs_heatmaps:{gs_heatmaps.shape}, circle_logits.shape:{circle_logits.squeeze(1).shape}')
      circle_logits = circle_logits[:, 0]
      print(f'circle_logits:{circle_logits.shape}')
      # NOTE(review): with 3-D [N, H, W] inputs, F.cross_entropy treats dim 1
      # (H) as the class axis -- confirm this is the intended formulation.
      return F.cross_entropy(circle_logits, gs_heatmaps)
  def lines_to_boxes(lines, img_size=511):
      """Convert line segments into padded, clamped axis-aligned boxes.

      Args:
          lines: tensor read as ``lines[i, k, c]``. Only endpoints k = 0, 1 and
              coordinates c = 0 (x), 1 (y) are used, so extra trailing values
              (e.g. a visibility flag) are ignored.
          img_size: upper clamp bound applied to every box coordinate.

      Returns:
          (N, 4) tensor ``[x_min, y_min, x_max, y_max]``. Each side is pushed
          outward by one pixel and then clamped to ``[0, img_size]``.
      """
      endpoints = torch.stack((lines[:, 0], lines[:, 1]), dim=1)  # (N, 2, C)
      xy = endpoints[..., [0, 1]]  # (N, 2, 2): (endpoint, [x, y])
      lower = xy.min(dim=1).values - 1  # (N, 2): x_min - 1, y_min - 1
      upper = xy.max(dim=1).values + 1  # (N, 2): x_max + 1, y_max + 1
      return torch.cat((lower, upper), dim=1).clamp(min=0, max=img_size)
  def box_iou_pairwise(box1, box2):
      """Compute the IoU of each index-aligned pair of boxes.

      Only pairs ``(box1[i], box2[i])`` are compared; there is no cross product.
      The longer input is truncated to the length of the shorter one.

      Args:
          box1: (N, 4) boxes ``[x_min, y_min, x_max, y_max]``.
          box2: (M, 4) boxes in the same format.

      Returns:
          (min(N, M),) tensor of IoUs. 1e-6 is added to the union to avoid
          division by zero.
      """
      count = min(len(box1), len(box2))
      a = box1[:count]
      b = box2[:count]

      def _area(boxes):
          return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

      # Overlap width/height: (min of right/bottom) - (max of left/top), floored at 0.
      overlap = (torch.min(a[:, 2:], b[:, 2:]) - torch.max(a[:, :2], b[:, :2])).clamp(min=0)
      intersection = overlap[:, 0] * overlap[:, 1]
      union = _area(a) + _area(b) - intersection
      return intersection / (union + 1e-6)
  def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta=1.0, gamma=1.0):
      """Weighted sum of box-IoU, length and direction losses for predicted lines.

      Args:
          x: heatmaps for all images, concatenated on dim 0. The rows are split
              per image using the box counts in ``boxes``.
          boxes: list of per-image box tensors; ``box.size(0)`` of each gives
              that image's share of ``x``.
          gt_lines: per-image ground-truth lines (the original note described
              them as [N, 2, 3] with visibility -- TODO confirm).
          matched_idx: per-image indices selecting the GT line for each box.
          img_size: clamp bound passed to ``lines_to_boxes``.
          alpha: weight of the (normalized) IoU term.
          beta: weight of the (normalized) length term.
          gamma: weight of the (normalized) direction term.

      Returns:
          Mean of the concatenated per-line losses over all images that have
          both predictions and GT lines, or ``None`` if no image qualifies.
      """
      per_image_maps = x.split([b.size(0) for b in boxes], dim=0)
      losses = []
      for maps, rois, gt_line, mid in zip(per_image_maps, boxes, gt_lines, matched_idx):
          pred_lines, _ = heatmaps_to_lines(maps, rois)
          gt_line_points = gt_line[mid]
          if len(pred_lines) == 0 or len(gt_line_points) == 0:
              continue

          # Box-IoU term (index-aligned pairs).
          iou_term = 1.0 - box_iou_pairwise(
              lines_to_boxes(pred_lines, img_size),
              lines_to_boxes(gt_line_points, img_size),
          )
          # Length term.
          len_term = F.l1_loss(line_length(pred_lines), line_length(gt_line_points), reduction='none')
          # Direction term.
          ang_term = angle_loss_cosine(line_direction(pred_lines), line_direction(gt_line_points))

          losses.append(
              alpha * normalize_tensor(iou_term)
              + beta * normalize_tensor(len_term)
              + gamma * normalize_tensor(ang_term)
          )
      return torch.mean(torch.cat(losses)) if losses else None
  def point_inference(x, point_boxes):
      # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
      """Decode one point per box from the per-image slices of ``x``.

      ``x`` is split on dim 0 using each image's box count. Each slice is
      decoded by ``heatmaps_to_points(..., num_points=1)``. The decoded points
      gain a new dim 1 before being collected.

      Returns:
          (list of per-image point tensors, list of per-image score tensors).
      """
      counts = [b.size(0) for b in point_boxes]
      probs, scores = [], []
      for maps, rois in zip(x.split(counts, dim=0), point_boxes):
          prob, score = heatmaps_to_points(maps, rois, num_points=1)
          probs.append(prob.unsqueeze(1))
          scores.append(score)
      return probs, scores
  def circle_inference(x, point_boxes):
      # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
      """Decode four circle points per box from the per-image slices of ``x``.

      ``x`` is split on dim 0 using each image's box count. Each slice is
      decoded by ``heatmaps_to_circle_points(..., num_points=4)``. The decoded
      points gain a new dim 1 before being collected.

      Returns:
          (list of per-image point tensors, list of per-image score tensors).
      """
      counts = [b.size(0) for b in point_boxes]
      probs, scores = [], []
      for maps, rois in zip(x.split(counts, dim=0), point_boxes):
          prob, score = heatmaps_to_circle_points(maps, rois, num_points=4)
          probs.append(prob.unsqueeze(1))
          scores.append(score)
      return probs, scores
  def line_inference(x, line_boxes):
      # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
      """Decode line segments from the per-image slices of ``x``.

      ``x`` is split on dim 0 using each image's box count, and each slice is
      decoded by ``heatmaps_to_lines``. The original author noted slices shaped
      like [1, 3, 1024, 1024] and per-image [1, 4] detector boxes -- presumably
      typical, not required.

      Returns:
          (list of per-image line tensors, list of per-image score tensors).
      """
      counts = [b.size(0) for b in line_boxes]
      lines_probs, lines_scores = [], []
      for maps, rois in zip(x.split(counts, dim=0), line_boxes):
          prob, score = heatmaps_to_lines(maps, rois)
          lines_probs.append(prob)
          lines_scores.append(score)
      return lines_probs, lines_scores
  def arc_inference(x, arc_boxes, th):
      # type: (Tensor, List[Tensor], float) -> Tuple[List[Tensor], List[Tensor], List[List[Tensor]]]
      """Decode arc masks per box and collect above-threshold pixel coordinates.

      Step 1: ``x`` is split on dim 0 using each image's box count. Each slice
      goes through ``heatmaps_to_arc``, and its masks gain a new dim 1.

      Step 2: for every box of ``arc_boxes[0]`` and every row ``m`` of ``x``,
      the pixels of ``x[m]`` inside the box (all channels) whose value exceeds
      ``th`` are gathered as global (x, y) coordinates.

      NOTE(review): only the first image's boxes are scanned, and every row of
      ``x`` (i.e. every ROI map across all images, not the image batch) is
      tested. Channel identity is dropped, so a pixel can appear once per
      channel. Empty placeholders are float32 on CPU while hits are int64 on
      ``x.device``. Confirm that all of this is intended.

      Returns:
          (per-image mask list, per-image score list, nested list
          ``results[map_idx][box_idx]`` of (K, 2) coordinate tensors).
      """
      probs, scores = [], []
      print(f'arc_boxes:{len(arc_boxes)}')
      counts = [b.size(0) for b in arc_boxes]
      print(f'arc boxes_per_image:{counts}')
      for maps, rois in zip(x.split(counts, dim=0), arc_boxes):
          prob, score = heatmaps_to_arc(maps, rois)
          probs.append(prob.unsqueeze(1))
          scores.append(score)
      print(f'points_probs shape:{torch.cat(probs).shape}')

      num_maps = x.shape[0]
      first_image_boxes = arc_boxes[0]
      results = [[torch.empty(0, 2) for _ in range(len(first_image_boxes))] for _ in range(num_maps)]
      for box_idx, box in enumerate(first_image_boxes):
          left, top, right, bottom = (int(v) for v in box.tolist())
          left = max(0, left)
          top = max(0, top)
          right = min(x.shape[3], right)
          bottom = min(x.shape[2], bottom)
          for map_idx in range(num_maps):
              hits = torch.nonzero(x[map_idx, :, top:bottom, left:right] > th)
              if hits.numel() == 0:
                  continue
              # nonzero yields (channel, y, x); keep (x, y) and shift to global coords.
              xy = hits[:, [2, 1]]
              xy[:, 0] += left
              xy[:, 1] += top
              results[map_idx][box_idx] = xy
      print(f're:{results}')
      return probs, scores, results
  908. import torch.nn.functional as F
  def heatmaps_to_arc(maps, rois, threshold=0.5, output_size=(128, 128)):
      """Threshold the first heatmap channel inside each ROI and paste it back.

      For each ROI, channel 0 of ``maps[i]`` is processed in four steps:
      1. Crop it to the box and resize the crop to ``output_size``.
      2. Apply a sigmoid and binarize at ``threshold``.
      3. Resize the binary crop back to the box size.
      4. Write it into an otherwise-zero full-size mask.

      Args:
          maps: [N, C, H, W] heatmaps; only channel 0 is used.
          rois: [N, 4] boxes (x1, y1, x2, y2) in heatmap pixel coordinates,
              truncated to integers with ``.long()``. Boxes that become empty
              after clamping are skipped.
          threshold: sigmoid-probability threshold for binarization.
          output_size: common size every crop is resized to before thresholding.

      Returns:
          masks: [N, 1, H, W] float32. Inside each valid ROI the values are the
              bilinearly resampled 0/1 mask, so they lie in [0, 1] and may be
              fractional at edges. Everything else is 0.
          scores: [N, 1] float32. Each entry is the sum of the pasted mask
              values, roughly the number of foreground pixels. Skipped ROIs
              score 0.
      """
      N, _, H, W = maps.shape
      masks = torch.zeros((N, 1, H, W), dtype=torch.float32, device=maps.device)
      scores = torch.zeros((N, 1), dtype=torch.float32, device=maps.device)
      point_maps = maps[:, 0]  # first channel only -> [N, H, W]
      print(f"==> heatmaps_to_arc: maps.shape = {maps.shape}, rois.shape = {rois.shape}")
      for i in range(N):
          x1, y1, x2, y2 = rois[i].long()
          # x1/y1 index a pixel, so they stop at W-1/H-1. x2/y2 are exclusive
          # slice ends, so they may reach W/H. Clamping them to W-1/H-1 would
          # silently drop the last column/row of any box touching the border.
          x1 = x1.clamp(0, W - 1)
          x2 = x2.clamp(0, W)
          y1 = y1.clamp(0, H - 1)
          y2 = y2.clamp(0, H)
          print(f"[{i}] roi: ({x1.item()}, {y1.item()}, {x2.item()}, {y2.item()})")
          if x2 <= x1 or y2 <= y1:
              print(f" Skipped invalid ROI at index {i}")
              continue
          roi_map = point_maps[i, y1:y2, x1:x2]  # [h, w]
          print(f" roi_map.shape: {roi_map.shape}")
          if roi_map.numel() == 0:
              print(f" Skipped empty ROI at index {i}")
              continue
          # Resize every crop to a common size before thresholding.
          roi_map_resized = F.interpolate(
              roi_map.unsqueeze(0).unsqueeze(0),
              size=output_size,
              mode='bilinear',
              align_corners=False,
          )  # [1, 1, out_h, out_w]
          print(f" roi_map_resized.shape: {roi_map_resized.shape}")
          # Sigmoid + threshold; no non-maximum suppression is applied here.
          nms_roi = torch.sigmoid(roi_map_resized)
          bin_mask = (nms_roi >= threshold).float()  # [1, 1, out_h, out_w]
          print(f" bin_mask.sum(): {bin_mask.sum().item()}")
          h = int((y2 - y1).item())
          w = int((x2 - x1).item())
          assert bin_mask.dim() == 4, f"Expected 4D tensor [1, 1, H, W], got {bin_mask.shape}"
          # Upsample back to the original ROI size.
          bin_mask_original_size = F.interpolate(
              bin_mask,
              size=(h, w),
              mode='bilinear',
              align_corners=False,
          )[0]  # [1, h, w]
          # Index the channel explicitly. A bare .squeeze() also drops h or w
          # when either is 1, and a [h] tensor cannot be assigned into an
          # [h, 1] slice.
          masks[i, 0, y1:y2, x1:x2] = bin_mask_original_size[0]
          scores[i] = bin_mask_original_size.sum()
          print(f" bin_mask_original_size.shape: {bin_mask_original_size.shape}, sum: {scores[i].item()}")
      print(f"==> Done. Total valid masks: {(scores > 0).sum().item()} / {N}")
      return masks, scores