head_losses.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285
  1. import torch
  2. from matplotlib import pyplot as plt
  3. import torch.nn.functional as F
  4. from torch import nn
  5. from torch.cuda import device
  6. class DiceLoss(nn.Module):
  7. def __init__(self, smooth=1.):
  8. super(DiceLoss, self).__init__()
  9. self.smooth = smooth
  10. def forward(self, logits, targets):
  11. probs = torch.sigmoid(logits)
  12. probs = probs.view(-1)
  13. targets = targets.view(-1).float()
  14. intersection = (probs * targets).sum()
  15. dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
  16. return 1. - dice
  17. bce_loss = nn.BCEWithLogitsLoss()
  18. dice_loss = DiceLoss()
  19. def combined_loss(preds, targets, alpha=0.5):
  20. bce = bce_loss(preds, targets)
  21. d = dice_loss(preds, targets)
  22. return alpha * bce + (1 - alpha) * d
  23. def features_align(features, proposals, img_size):
  24. print(f'features_align features:{features.shape},proposals:{len(proposals)}')
  25. align_feat_list = []
  26. for feat, proposals_per_img in zip(features, proposals):
  27. print(f'feature_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
  28. if proposals_per_img.shape[0]>0:
  29. feat = feat.unsqueeze(0)
  30. for proposal in proposals_per_img:
  31. align_feat = torch.zeros_like(feat)
  32. # print(f'align_feat:{align_feat.shape}')
  33. x1, y1, x2, y2 = map(lambda v: int(v.item()), proposal)
  34. # 将每个proposal框内的部分赋值到align_feats对应位置
  35. align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]
  36. align_feat_list.append(align_feat)
  37. # print(f'align_feat_list:{align_feat_list}')
  38. if len(align_feat_list) > 0:
  39. feats_tensor = torch.cat(align_feat_list)
  40. print(f'align features :{feats_tensor.shape}')
  41. else:
  42. feats_tensor = None
  43. return feats_tensor
  44. def normalize_tensor(t):
  45. return (t - t.min()) / (t.max() - t.min() + 1e-6)
  46. def line_length(lines):
  47. """
  48. 计算每条线段的长度
  49. lines: [N, 2, 2] 表示 N 条线段,每条线段由两个点组成
  50. 返回: [N]
  51. """
  52. return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
  53. def line_direction(lines):
  54. """
  55. 计算每条线段的单位方向向量
  56. lines: [N, 2, 2]
  57. 返回: [N, 2] 单位方向向量
  58. """
  59. vec = lines[:, 1] - lines[:, 0]
  60. return F.normalize(vec, dim=-1)
  61. def angle_loss_cosine(pred_dir, gt_dir):
  62. """
  63. 使用 cosine similarity 计算方向差异
  64. pred_dir: [N, 2]
  65. gt_dir: [N, 2]
  66. 返回: [N]
  67. """
  68. cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
  69. return 1.0 - cos_sim # 或者 torch.acos(cos_sim) / pi 也可
  70. def line_length(lines):
  71. """
  72. 计算每条线段的长度
  73. lines: [N, 2, 2] 表示 N 条线段,每条线段由两个点组成
  74. 返回: [N]
  75. """
  76. return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
  77. def line_direction(lines):
  78. """
  79. 计算每条线段的单位方向向量
  80. lines: [N, 2, 2]
  81. 返回: [N, 2] 单位方向向量
  82. """
  83. vec = lines[:, 1] - lines[:, 0]
  84. return F.normalize(vec, dim=-1)
  85. def angle_loss_cosine(pred_dir, gt_dir):
  86. """
  87. 使用 cosine similarity 计算方向差异
  88. pred_dir: [N, 2]
  89. gt_dir: [N, 2]
  90. 返回: [N]
  91. """
  92. cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
  93. return 1.0 - cos_sim # 或者 torch.acos(cos_sim) / pi 也可
  94. def single_point_to_heatmap(keypoints, rois, heatmap_size):
  95. # type: (Tensor, Tensor, int) -> Tensor
  96. print(f'rois:{rois.shape}')
  97. print(f'heatmap_size:{heatmap_size}')
  98. print(f'keypoints.shape:{keypoints.shape}')
  99. # batch_size, num_keypoints, _ = keypoints.shape
  100. x = keypoints[..., 0].unsqueeze(1)
  101. y = keypoints[..., 1].unsqueeze(1)
  102. gs = generate_gaussian_heatmaps(x, y,num_points=1, heatmap_size=heatmap_size, sigma=1.0)
  103. # show_heatmap(gs[0],'target')
  104. all_roi_heatmap = []
  105. for roi, heatmap in zip(rois, gs):
  106. # show_heatmap(heatmap, 'target')
  107. # print(f'heatmap:{heatmap.shape}')
  108. heatmap = heatmap.unsqueeze(0)
  109. x1, y1, x2, y2 = map(int, roi)
  110. roi_heatmap = torch.zeros_like(heatmap)
  111. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  112. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  113. all_roi_heatmap.append(roi_heatmap)
  114. all_roi_heatmap = torch.cat(all_roi_heatmap)
  115. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  116. return all_roi_heatmap
  117. def points_to_heatmap(keypoints, rois,num_points=2, heatmap_size=(512,512)):
  118. # type: (Tensor, Tensor, int) -> Tensor
  119. print(f'rois:{rois.shape}')
  120. print(f'heatmap_size:{heatmap_size}')
  121. print(f'keypoints.shape:{keypoints.shape}')
  122. # batch_size, num_keypoints, _ = keypoints.shape
  123. x = keypoints[..., 0].unsqueeze(1)
  124. y = keypoints[..., 1].unsqueeze(1)
  125. gs = generate_gaussian_heatmaps(x, y,num_points=num_points, heatmap_size=heatmap_size, sigma=2.0)
  126. # show_heatmap(gs[0],'target')
  127. all_roi_heatmap = []
  128. for roi, heatmap in zip(rois, gs):
  129. # show_heatmap(heatmap, 'target')
  130. # print(f'heatmap:{heatmap.shape}')
  131. heatmap = heatmap.unsqueeze(0)
  132. x1, y1, x2, y2 = map(int, roi)
  133. roi_heatmap = torch.zeros_like(heatmap)
  134. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  135. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  136. all_roi_heatmap.append(roi_heatmap)
  137. all_roi_heatmap = torch.cat(all_roi_heatmap)
  138. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  139. return all_roi_heatmap
  140. def line_points_to_heatmap(keypoints, rois, heatmap_size):
  141. # type: (Tensor, Tensor, int) -> Tensor
  142. print(f'rois:{rois.shape}')
  143. print(f'heatmap_size:{heatmap_size}')
  144. print(f'keypoints.shape:{keypoints.shape}')
  145. # batch_size, num_keypoints, _ = keypoints.shape
  146. x = keypoints[..., 0]
  147. y = keypoints[..., 1]
  148. gs = generate_gaussian_heatmaps(x, y, heatmap_size,num_points=2, sigma=1.0)
  149. # show_heatmap(gs[0],'target')
  150. all_roi_heatmap = []
  151. for roi, heatmap in zip(rois, gs):
  152. # print(f'heatmap:{heatmap.shape}')
  153. # show_heatmap(heatmap,'target')
  154. heatmap = heatmap.unsqueeze(0)
  155. x1, y1, x2, y2 = map(int, roi)
  156. roi_heatmap = torch.zeros_like(heatmap)
  157. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  158. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  159. all_roi_heatmap.append(roi_heatmap)
  160. if len(all_roi_heatmap) > 0:
  161. all_roi_heatmap = torch.cat(all_roi_heatmap)
  162. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  163. else:
  164. all_roi_heatmap = None
  165. return all_roi_heatmap
"""
Point-to-heatmap conversion adapted from the original structure; intended for the version that uses roi_pool.
"""
  169. def line_points_to_heatmap_(keypoints, rois, heatmap_size):
  170. # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
  171. print(f'rois:{rois.shape}')
  172. print(f'heatmap_size:{heatmap_size}')
  173. offset_x = rois[:, 0]
  174. offset_y = rois[:, 1]
  175. scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
  176. scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
  177. offset_x = offset_x[:, None]
  178. offset_y = offset_y[:, None]
  179. scale_x = scale_x[:, None]
  180. scale_y = scale_y[:, None]
  181. print(f'keypoints.shape:{keypoints.shape}')
  182. # batch_size, num_keypoints, _ = keypoints.shape
  183. x = keypoints[..., 0]
  184. y = keypoints[..., 1]
  185. # gs=generate_gaussian_heatmaps(x,y,512,1.0)
  186. # print(f'gs_heatmap shape:{gs.shape}')
  187. #
  188. # show_heatmap(gs[0],'target')
  189. x_boundary_inds = x == rois[:, 2][:, None]
  190. y_boundary_inds = y == rois[:, 3][:, None]
  191. x = (x - offset_x) * scale_x
  192. x = x.floor().long()
  193. y = (y - offset_y) * scale_y
  194. y = y.floor().long()
  195. x[x_boundary_inds] = heatmap_size - 1
  196. y[y_boundary_inds] = heatmap_size - 1
  197. # print(f'heatmaps x:{x}')
  198. # print(f'heatmaps y:{y}')
  199. valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
  200. vis = keypoints[..., 2] > 0
  201. valid = (valid_loc & vis).long()
  202. gs_heatmap = generate_gaussian_heatmaps(x, y, heatmap_size, 1.0)
  203. # show_heatmap(gs_heatmap[0], 'feature')
  204. # print(f'gs_heatmap:{gs_heatmap.shape}')
  205. #
  206. # lin_ind = y * heatmap_size + x
  207. # print(f'lin_ind:{lin_ind.shape}')
  208. # heatmaps = lin_ind * valid
  209. return gs_heatmap
  210. def generate_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
  211. """
  212. 为一组点生成并合并高斯热图。
  213. Args:
  214. xs (Tensor): 形状为 (N, 2) 的所有点的 x 坐标
  215. ys (Tensor): 形状为 (N, 2) 的所有点的 y 坐标
  216. heatmap_size (int): 热图大小 H=W
  217. sigma (float): 高斯核标准差
  218. device (str): 设备类型 ('cpu' or 'cuda')
  219. Returns:
  220. Tensor: 形状为 (H, W) 的合并后的热图
  221. """
  222. assert xs.shape == ys.shape, "x and y must have the same shape"
  223. print(f'xs:{xs.shape}')
  224. xs=xs.squeeze(1)
  225. ys = ys.squeeze(1)
  226. print(f'xs1:{xs.shape}')
  227. N = xs.shape[0]
  228. print(f'N:{N},num_points:{num_points}')
  229. # 创建网格
  230. grid_y, grid_x = torch.meshgrid(
  231. torch.arange(heatmap_size, device=device),
  232. torch.arange(heatmap_size, device=device),
  233. indexing='ij'
  234. )
  235. # print(f'heatmap_size:{heatmap_size}')
  236. # 初始化输出热图
  237. combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
  238. for i in range(N):
  239. heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
  240. for j in range(num_points):
  241. mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
  242. mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
  243. # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
  244. # 计算距离平方
  245. dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
  246. # 计算高斯分布
  247. heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
  248. heatmap+=heatmap1
  249. # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
  250. # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
  251. #
  252. # # 计算距离平方
  253. # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
  254. #
  255. # # 计算高斯分布
  256. # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
  257. #
  258. # heatmap = heatmap1 + heatmap2
  259. # 将当前热图累加到结果中
  260. combined_heatmap[i] = heatmap
  261. return combined_heatmap
  262. def generate_mask_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
  263. """
  264. 为一组点生成并合并高斯热图。
  265. Args:
  266. xs (Tensor): 形状为 (N, 2) 的所有点的 x 坐标
  267. ys (Tensor): 形状为 (N, 2) 的所有点的 y 坐标
  268. heatmap_size (int): 热图大小 H=W
  269. sigma (float): 高斯核标准差
  270. device (str): 设备类型 ('cpu' or 'cuda')
  271. Returns:
  272. Tensor: 形状为 (H, W) 的合并后的热图
  273. """
  274. assert xs.shape == ys.shape, "x and y must have the same shape"
  275. print(f'xs:{xs.shape}')
  276. xs=xs.squeeze(1)
  277. ys = ys.squeeze(1)
  278. print(f'xs1:{xs.shape}')
  279. N = xs.shape[0]
  280. print(f'N:{N},num_points:{num_points}')
  281. # 创建网格
  282. grid_y, grid_x = torch.meshgrid(
  283. torch.arange(heatmap_size, device=device),
  284. torch.arange(heatmap_size, device=device),
  285. indexing='ij'
  286. )
  287. # print(f'heatmap_size:{heatmap_size}')
  288. # 初始化输出热图
  289. combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
  290. for i in range(N):
  291. heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
  292. for j in range(num_points):
  293. mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
  294. mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
  295. # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
  296. # 计算距离平方
  297. dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
  298. # 计算高斯分布
  299. heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
  300. heatmap+=heatmap1
  301. # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
  302. # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
  303. #
  304. # # 计算距离平方
  305. # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
  306. #
  307. # # 计算高斯分布
  308. # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
  309. #
  310. # heatmap = heatmap1 + heatmap2
  311. # 将当前热图累加到结果中
  312. combined_heatmap[i] = heatmap
  313. return combined_heatmap
  314. def non_maximum_suppression(a):
  315. ap = F.max_pool2d(a, 3, stride=1, padding=1)
  316. mask = (a == ap).float().clamp(min=0.0)
  317. return a * mask
def heatmaps_to_points(maps, rois,num_points=2):
    """Decode peak locations and their scores from channel 0 of ``maps``.

    For every ROI index, channel 0 of the corresponding map goes through
    ``non_maximum_suppression``, is flattened, and the ``num_points`` highest
    entries are converted back to (x, y) grid coordinates.

    Args:
        maps: Tensor indexed as ``maps[:, 0]`` to obtain one 2-D map per ROI.
        rois: Only ``len(rois)`` is used here.
        num_points: ``k`` passed to ``torch.topk``.

    Returns:
        ``(point_preds, point_end_scores)`` with shapes ``(len(rois), 2)`` and
        ``(len(rois), 1)``.
    """
    point_preds = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
    point_end_scores = torch.zeros((len(rois), 1), dtype=torch.float32, device=maps.device)
    print(f'heatmaps_to_lines:{maps.shape}')
    point_maps=maps[:,0]
    print(f'point_map:{point_maps.shape}')
    for i in range(len(rois)):
        # Add a leading dim so the map is (1, H, W) for the pooling-based NMS.
        point_roi_map = point_maps[i].unsqueeze(0)
        print(f'point_roi_map:{point_roi_map.shape}')
        # roi_map_probs = scores_to_probs(roi_map.copy())
        w = point_roi_map.shape[2]
        flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
        point_score, point_index = torch.topk(flatten_point_roi_map, k=num_points)
        print(f'point index:{point_index}')
        # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
        # Flat index -> column (x) and row (y).
        point_x =point_index % w
        point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
        # NOTE(review): point_preds[i, 0] is a single element while point_index
        # holds num_points entries, so these assignments look like they only
        # work for num_points == 1 — confirm against the callers.
        point_preds[i, 0,] = point_x
        point_preds[i, 1,] = point_y
        # Scores are read from the raw map, not the NMS-suppressed one.
        point_end_scores[i, :] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
    return point_preds,point_end_scores
# Split into 4 blocks
  340. def find_max_heat_point_in_each_part(feature_map, box):
  341. """
  342. 在给定的特征图上,根据box中心点往上移3,往右移3作为新的中心点,
  343. 并将特征图划分为4个部分,之后在每个部分中找到热度值最大的点。
  344. Args:
  345. feature_map (torch.Tensor): 形状为 [C, H, W] 的特征图
  346. box (torch.Tensor): 形状为 [4] 的边界框 [x_min, y_min, x_max, y_max]
  347. Returns:
  348. list: 每个区域中热度最高的点的位置和其对应的热度值 [(y1, x1, heat1), ..., (y4, x4, heat4)]
  349. """
  350. device = feature_map.device
  351. C, H, W = feature_map.shape
  352. # 计算box的中心点(cx, cy)
  353. cx = (box[0] + box[2]) // 2
  354. cy = (box[1] + box[3]) // 2
  355. # 偏移中心点
  356. new_cx = min(max(cx + 3, 0), W - 1) # 向右移3
  357. new_cy = min(max(cy - 3, 0), H - 1) # 向上移3
  358. # 创建坐标网格
  359. y_coords, x_coords = torch.meshgrid(
  360. torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij'
  361. )
  362. # 划分四个区域
  363. mask_q1 = (y_coords < new_cy) & (x_coords < new_cx) # 左上
  364. mask_q2 = (y_coords < new_cy) & (x_coords >= new_cx) # 右上
  365. mask_q3 = (y_coords >= new_cy) & (x_coords < new_cx) # 左下
  366. mask_q4 = (y_coords >= new_cy) & (x_coords >= new_cx) # 右下
  367. # def process_region(mask):
  368. # region = feature_map[:, :, mask].squeeze()
  369. # if len(region.shape) == 0: # 如果区域为空,则跳过
  370. # return None, None
  371. # # 找到最大热度值的点及其位置
  372. # (y, x), heat_val = non_maximum_suppression(region[0])
  373. # # 将相对坐标转换回全局坐标
  374. # y_global = y + torch.where(mask)[0].min().item()
  375. # x_global = x + torch.where(mask)[1].min().item()
  376. # return (y_global, x_global), heat_val
  377. #
  378. # results = []
  379. # for mask in [mask_q1, mask_q2, mask_q3, mask_q4]:
  380. # point, heat_val = process_region(mask)
  381. # if point is not None:
  382. # # results.append((point[0], point[1], heat_val))
  383. # results.append((point[0], point[1]))
  384. # else:
  385. # results.append(None)
  386. masks = [mask_q1, mask_q2, mask_q3, mask_q4]
  387. results = []
  388. # 假设使用第一个通道作为热力图
  389. heatmap = feature_map[0] # [H, W]
  390. def process_region(mask):
  391. # 应用 mask,只保留该区域
  392. masked_heatmap = heatmap.clone() # 复制以避免修改原数据
  393. masked_heatmap[~mask] = 0 # 非区域置0
  394. def non_maximum_suppression_2d(heatmap, kernel_size=3):
  395. """
  396. 对 2D 热力图做非极大值抑制,保留局部最大值点。
  397. Args:
  398. heatmap (torch.Tensor): [H, W],输入热力图
  399. kernel_size (int): 池化窗口大小,用于比较是否为局部最大值
  400. Returns:
  401. torch.Tensor: 与 heatmap 同形状的 mask,局部最大值位置为 True
  402. """
  403. pad = (kernel_size - 1) // 2
  404. max_pool = torch.nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)
  405. maxima = max_pool(heatmap.unsqueeze(0)).squeeze(0)
  406. # 局部最大值且值大于0
  407. peaks = (heatmap == maxima) & (heatmap > 0)
  408. return peaks
  409. # 1. 先做 NMS 得到候选局部极大值点
  410. nms_mask = non_maximum_suppression_2d(masked_heatmap, kernel_size=3) # [H, W] bool
  411. candidate_peaks = masked_heatmap * nms_mask.float() # 只保留 NMS 后的峰值
  412. # 2. 找出所有候选点中值最大的一个
  413. if candidate_peaks.max() <= 0:
  414. return None
  415. # 找到最大值的位置
  416. max_val, max_idx = torch.max(candidate_peaks.view(-1), dim=0)
  417. y, x = divmod(max_idx.item(), W)
  418. return (x, y) # 返回 (y, x)
  419. for mask in masks:
  420. point = process_region(mask)
  421. results.append(point)
  422. return results
  423. def non_maximum_suppression_2d(heatmap, kernel_size=3):
  424. pad = (kernel_size - 1) // 2
  425. max_pool = torch.nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)
  426. maxima = max_pool(heatmap.unsqueeze(0)).squeeze(0)
  427. peaks = (heatmap == maxima) & (heatmap > 0)
  428. return peaks
  429. def find_max_heat_point_in_edge_centers(feature_map, box):
  430. device = feature_map.device
  431. C, H, W = feature_map.shape
  432. # ¼ÆËã box ÖÐÐÄ
  433. cx = (box[0] + box[2]) / 2
  434. cy = (box[1] + box[3]) / 2
  435. # ¸ù¾Ý box ¿í¸ß¼ÆËã¾Å¹¬¸ñ·Ö½çÏß
  436. box_width = box[2] - box[0]
  437. box_height = box[3] - box[1]
  438. x_left = cx - box_width / 6
  439. x_right = cx + box_width / 6
  440. y_top = cy - box_height / 6
  441. y_bottom = cy + box_height / 6
  442. # ´´½¨Íø¸ñ
  443. y_coords, x_coords = torch.meshgrid(
  444. torch.arange(H, device=device),
  445. torch.arange(W, device=device),
  446. indexing='ij'
  447. )
  448. # ¶¨ÒåËĸö¡°±ßÖС±ÇøÓòµÄ mask
  449. mask1 = (x_coords < x_left) & (y_coords < y_top)
  450. mask_top_middle = (x_coords >= x_left) & (x_coords < x_right) & (y_coords < y_top)
  451. mask3 = (x_coords >= x_right) & (y_coords < y_top)
  452. mask_left_middle = (x_coords < x_left) & (y_coords >= y_top) & (y_coords < y_bottom)
  453. mask_right_middle = (x_coords >= x_right) & (y_coords >= y_top) & (y_coords < y_bottom)
  454. mask4 = (x_coords < x_left) & (y_coords >= y_bottom)
  455. mask_bottom_middle = (x_coords >= x_left) & (x_coords < x_right) & (y_coords >= y_bottom)
  456. mask_right_bottom = (x_coords >= x_right) & (y_coords >= y_bottom)
  457. # masks = [
  458. # # mask1,
  459. # mask_top_middle,
  460. # # mask3,
  461. # mask_left_middle,
  462. # mask_right_middle,
  463. # # mask4,
  464. # mask_bottom_middle,
  465. # mask_right_bottom
  466. # ]
  467. masks = [
  468. mask_top_middle,
  469. mask_right_middle,
  470. mask_bottom_middle,
  471. mask_left_middle
  472. ]
  473. # ʹÓõÚÒ»¸öͨµÀ×÷ΪÈÈÁ¦Í¼
  474. heatmap = feature_map[0] # [H, W]
  475. results = []
  476. for mask in masks:
  477. masked_heatmap = heatmap.clone()
  478. masked_heatmap[~mask] = 0 # ·ÇÄ¿±êÇøÓòÖà 0
  479. # # NMS ÒÖÖÆ
  480. # nms_mask = non_maximum_suppression_2d(masked_heatmap, kernel_size=3)
  481. # candidate_peaks = masked_heatmap * nms_mask.float()
  482. #
  483. # if candidate_peaks.max() <= 0:
  484. # results.append(None)
  485. # continue
  486. #
  487. # # ÕÒ×î´óֵλÖÃ
  488. # max_val, max_idx = torch.max(candidate_peaks.view(-1), dim=0)
  489. # y, x = divmod(max_idx.item(), W)
  490. flatten_point_roi_map = masked_heatmap.reshape(1, -1)
  491. point_score, point_index = torch.topk(flatten_point_roi_map, k=1)
  492. point_x =point_index % W
  493. point_y = torch.div(point_index - point_x, W, rounding_mode="floor")
  494. results.append((point_x, point_y))
  495. return results # [(y_top, x_top), (y_right, x_right), (y_bottom, x_bottom), (y_left, x_left)]
def heatmaps_to_circle_points(maps, rois,num_points=2):
    """Decode four points per ROI from channel 0 of ``maps``.

    For every ROI, channel 0 of its map is passed through
    ``non_maximum_suppression`` and then to
    ``find_max_heat_point_in_edge_centers`` together with the ROI box.

    Args:
        maps: Tensor indexed as ``maps[:, 0]`` to obtain one 2-D map per ROI.
        rois: Tensor of boxes (``rois.shape`` is printed); ``rois[i]`` is
            forwarded as the box.
        num_points: Not used by the active code path.

    Returns:
        ``(point_preds, point_end_scores)`` with shapes ``(len(rois), 4, 2)``
        and ``(len(rois), 4, 1)``.
    """
    point_preds = torch.zeros((len(rois), 4, 2), dtype=torch.float32, device=maps.device)
    point_end_scores = torch.zeros((len(rois),4, 1), dtype=torch.float32, device=maps.device)
    print(f'rois in heatmaps_to_circle_points:{type(rois), rois.shape}') # <class 'torch.Tensor'>
    print(f'heatmaps_to_lines:{maps.shape}')
    point_maps=maps[:,0]
    print(f'point_map:{point_maps.shape}')
    for i in range(len(rois)):
        # Add a leading dim so the map is (1, H, W).
        point_roi_map = point_maps[i].unsqueeze(0)
        print(f'point_roi_map:{point_roi_map.shape}')
        roi1=rois[i]
        result_points = find_max_heat_point_in_edge_centers(non_maximum_suppression(point_roi_map), roi1)
        # NOTE(review): result_points is a list of tuples of index tensors;
        # whether torch.tensor(...) yields something assignable to a (4, 2)
        # slot depends on the shape of those tensors — confirm.
        point_preds[i, :]=torch.tensor(result_points)
        point_x = [point[0] for point in result_points]
        point_y = [point[1] for point in result_points]
        # Scores are read from the raw map, not the NMS-suppressed one.
        point_end_scores[i, :,0] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
    return point_preds,point_end_scores
  524. def heatmaps_to_lines(maps, rois):
  525. line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
  526. line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
  527. line_maps=maps[:,1]
  528. # line_maps = maps.squeeze(1)
  529. for i in range(len(rois)):
  530. line_roi_map = line_maps[i].unsqueeze(0)
  531. print(f'line_roi_map:{line_roi_map.shape}')
  532. # roi_map_probs = scores_to_probs(roi_map.copy())
  533. w = line_roi_map.shape[1]
  534. flatten_line_roi_map = non_maximum_suppression(line_roi_map).reshape(1, -1)
  535. line_score, line_index = torch.topk(flatten_line_roi_map, k=2)
  536. print(f'line index:{line_index}')
  537. # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  538. pos = line_index
  539. line_x = pos % w
  540. line_y = torch.div(pos - line_x, w, rounding_mode="floor")
  541. line_preds[i, 0, :] = line_x
  542. line_preds[i, 1, :] = line_y
  543. line_preds[i, 2, :] = 1
  544. line_end_scores[i, :] = line_roi_map[torch.arange(1, device=line_roi_map.device), line_y, line_x]
  545. return line_preds.permute(0, 2, 1), line_end_scores
# Function for displaying a heatmap
  547. def show_heatmap(heatmap, title="Heatmap"):
  548. """
  549. 使用 matplotlib 显示热图。
  550. Args:
  551. heatmap (Tensor): 要显示的热图张量
  552. title (str): 图表标题
  553. """
  554. # 如果在 GPU 上,首先将其移动到 CPU 并转换为 numpy 数组
  555. if heatmap.is_cuda:
  556. heatmap = heatmap.cpu().numpy()
  557. else:
  558. heatmap = heatmap.numpy()
  559. plt.imshow(heatmap, cmap='hot', interpolation='nearest')
  560. plt.colorbar()
  561. plt.title(title)
  562. plt.show()
  563. def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
  564. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  565. N, K, H, W = line_logits.shape
  566. len_proposals = len(proposals)
  567. print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals},line_matched_idxs:{line_matched_idxs}')
  568. if H != W:
  569. raise ValueError(
  570. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  571. )
  572. discretization_size = H
  573. heatmaps = []
  574. gs_heatmaps = []
  575. valid = []
  576. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
  577. print(f'line_proposals_per_image:{proposals_per_image.shape}')
  578. print(f'gt_lines:{gt_lines}')
  579. if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
  580. kp = gt_kp_in_image[midx]
  581. gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
  582. gs_heatmaps.append(gs_heatmaps_per_img)
  583. # print(f'heatmaps_per_image:{heatmaps_per_image.shape}')
  584. # heatmaps.append(heatmaps_per_image.view(-1))
  585. # valid.append(valid_per_image.view(-1))
  586. # line_targets = torch.cat(heatmaps, dim=0)
  587. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  588. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
  589. # print(f'line_targets:{line_targets.shape},{line_targets}')
  590. # valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
  591. # valid = torch.where(valid)[0]
  592. # print(f' line_targets[valid]:{line_targets[valid]}')
  593. # torch.mean (in binary_cross_entropy_with_logits) doesn't
  594. # accept empty tensors, so handle it sepaartely
  595. # if line_targets.numel() == 0 or len(valid) == 0:
  596. # return line_logits.sum() * 0
  597. # line_logits = line_logits.view(N * K, H * W)
  598. # print(f'line_logits[valid]:{line_logits[valid].shape}')
  599. print(f'loss1 line_logits:{line_logits.shape}')
  600. line_logits = line_logits[:,1,:,:]
  601. # line_logits = line_logits.squeeze(1)
  602. print(f'loss2 line_logits:{line_logits.shape}')
  603. # line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
  604. line_loss = F.cross_entropy(line_logits, gs_heatmaps)
  605. return line_loss
def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
    """Compute the combined BCE + Dice loss between logits and arc heatmaps.

    Args:
        feature_logits: Tensor of shape ``(N, K, H, W)`` with ``H == W``;
            dim 1 is squeezed before the loss is computed.
        proposals: Per-image proposal boxes.
        gt_: Per-image ground-truth tensors, indexed with the matched indices.
        pos_matched_idxs: Per-image indices into ``gt_``.

    Returns:
        The value of ``combined_loss`` when at least one image produced
        target heatmaps; otherwise the Python int ``100``.

    Raises:
        ValueError: If the logit maps are not square.
    """
    print(f'compute_arc_loss:{feature_logits.shape}')
    N, K, H, W = feature_logits.shape
    len_proposals = len(proposals)
    # Count empty / non-empty proposal sets; used for logging only.
    empty_count = 0
    non_empty_count = 0
    for prop in proposals:
        if prop.shape[0] == 0:
            empty_count += 1
        else:
            non_empty_count += 1
    print(f"Empty proposals count: {empty_count}")
    print(f"Non-empty proposals count: {non_empty_count}")
    print(f'starte to compute_point_loss')
    print(f'compute_point_loss line_logits.shape:{feature_logits.shape},len_proposals:{len_proposals}')
    if H != W:
        raise ValueError(
            f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
        )
    discretization_size = H
    gs_heatmaps = []
    # print(f'point_matched_idxs:{point_matched_idxs}')
    # NOTE(review): indexing gt_[0] raises when gt_ is empty — confirm callers
    # never pass an empty list.
    print(f'gt_masks:{gt_[0].shape}')
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
        # Example of the zipped shapes left by the original author:
        # [
        # (Tensor(38, 4), Tensor(1, 57, 2), Tensor(38, 1)),
        # (Tensor(65, 4), Tensor(1, 74, 2), Tensor(65, 1))
        # ]
        print(f'proposals_per_image:{proposals_per_image.shape}')
        # The ground truth is indexed before the emptiness check below.
        kp = gt_kp_in_image[midx]
        # t_h / t_w are only printed.
        t_h, t_w = kp.shape[-2:]
        print(f't_h:{t_h}, t_w:{t_w}')
        print(f'gt_kp_in_image:{gt_kp_in_image.shape}')
        if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
            gs_heatmaps_per_img = arc_points_to_heatmap(kp, proposals_per_image, discretization_size)
            gs_heatmaps.append(gs_heatmaps_per_img)
    if len(gs_heatmaps)>0:
        gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
        print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.shape}')
        line_logits = feature_logits.squeeze(1)
        print(f'single_point_logits:{line_logits.shape}')
        # line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
        # line_loss = F.cross_entropy(line_logits, gs_heatmaps)
        line_loss=combined_loss(line_logits, gs_heatmaps)
    else:
        # NOTE(review): this fallback is a plain int, not a tensor, so it
        # carries no gradient — looks like a debugging placeholder; confirm
        # the intended value for "no targets".
        line_loss=100
        print("d")
    return line_loss
  654. def arc_points_to_heatmap(keypoints, rois, heatmap_size):
  655. print(f'rois:{rois.shape}')
  656. print(f'heatmap_size:{heatmap_size}')
  657. print(f'keypoints.shape:{keypoints.shape}')
  658. # batch_size, num_keypoints, _ = keypoints.shape
  659. t_h, t_w = keypoints.shape[-2:]
  660. scale=heatmap_size/t_w
  661. print(f'scale:{scale}')
  662. x = keypoints[..., 0]*scale
  663. y = keypoints[..., 1]*scale
  664. x = x.unsqueeze(1)
  665. y = y.unsqueeze(1)
  666. num_points=x.shape[2]
  667. print(f'num_points:{num_points}')
  668. gs = generate_mask_gaussian_heatmaps(x, y, num_points=num_points, heatmap_size=heatmap_size, sigma=10)
  669. print(f'gs max :{gs.max()},gs.shape:{gs.shape}')
  670. # show_heatmap(gs[0],'target')
  671. all_roi_heatmap = []
  672. for roi, heatmap in zip(rois, gs):
  673. show_heatmap(heatmap, 'target')
  674. print(f'heatmap.shape:{heatmap.shape}')
  675. heatmap = heatmap.unsqueeze(0)
  676. x1, y1, x2, y2 = map(int, roi)
  677. roi_heatmap = torch.zeros_like(heatmap)
  678. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  679. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  680. all_roi_heatmap.append(roi_heatmap)
  681. all_roi_heatmap = torch.cat(all_roi_heatmap)
  682. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  683. return all_roi_heatmap
  684. def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
  685. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  686. N, K, H, W = line_logits.shape
  687. len_proposals = len(proposals)
  688. empty_count = 0
  689. non_empty_count = 0
  690. for prop in proposals:
  691. if prop.shape[0] == 0:
  692. empty_count += 1
  693. else:
  694. non_empty_count += 1
  695. print(f"Empty proposals count: {empty_count}")
  696. print(f"Non-empty proposals count: {non_empty_count}")
  697. print(f'starte to compute_point_loss')
  698. print(f'compute_point_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
  699. if H != W:
  700. raise ValueError(
  701. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  702. )
  703. discretization_size = H
  704. gs_heatmaps = []
  705. # print(f'point_matched_idxs:{point_matched_idxs}')
  706. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_points, point_matched_idxs):
  707. print(f'proposals_per_image:{proposals_per_image.shape}')
  708. kp = gt_kp_in_image[midx]
  709. # print(f'gt_kp_in_image:{gt_kp_in_image}')
  710. gs_heatmaps_per_img = single_point_to_heatmap(kp, proposals_per_image, discretization_size)
  711. gs_heatmaps.append(gs_heatmaps_per_img)
  712. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  713. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
  714. line_logits = line_logits[:,0]
  715. print(f'single_point_logits:{line_logits.shape}')
  716. line_loss = F.cross_entropy(line_logits, gs_heatmaps)
  717. return line_loss
  718. def compute_circle_loss(circle_logits, proposals, gt_circles, circle_matched_idxs):
  719. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  720. N, K, H, W = circle_logits.shape
  721. len_proposals = len(proposals)
  722. empty_count = 0
  723. non_empty_count = 0
  724. for prop in proposals:
  725. if prop.shape[0] == 0:
  726. empty_count += 1
  727. else:
  728. non_empty_count += 1
  729. print(f"Empty proposals count: {empty_count}")
  730. print(f"Non-empty proposals count: {non_empty_count}")
  731. print(f'starte to compute_circle_loss')
  732. print(f'compute_circle_loss circle_logits.shape:{circle_logits.shape},len_proposals:{len_proposals}')
  733. if H != W:
  734. raise ValueError(
  735. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  736. )
  737. discretization_size = H
  738. gs_heatmaps = []
  739. # print(f'point_matched_idxs:{point_matched_idxs}')
  740. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_circles, circle_matched_idxs):
  741. print(f'proposals_per_image:{proposals_per_image.shape}')
  742. kp = gt_kp_in_image[midx]
  743. # print(f'gt_kp_in_image:{gt_kp_in_image}')
  744. gs_heatmaps_per_img = points_to_heatmap(kp, proposals_per_image,num_points=4, heatmap_size=discretization_size)
  745. gs_heatmaps.append(gs_heatmaps_per_img)
  746. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  747. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{circle_logits.squeeze(1).shape}')
  748. circle_logits = circle_logits[:, 0]
  749. print(f'circle_logits:{circle_logits.shape}')
  750. circle_loss = F.cross_entropy(circle_logits, gs_heatmaps)
  751. return circle_loss
  752. def lines_to_boxes(lines, img_size=511):
  753. """
  754. 输入:
  755. lines: Tensor of shape (N, 2, 2),表示 N 条线段,每个线段有两个端点 (x, y)
  756. img_size: int,图像尺寸,用于 clamp 边界
  757. 输出:
  758. boxes: Tensor of shape (N, 4),表示 N 个包围盒 [x_min, y_min, x_max, y_max]
  759. """
  760. # 提取所有线段的两个端点
  761. p1 = lines[:, 0] # (N, 2)
  762. p2 = lines[:, 1] # (N, 2)
  763. # 每条线段的 x 和 y 坐标
  764. x_coords = torch.stack([p1[:, 0], p2[:, 0]], dim=1) # (N, 2)
  765. y_coords = torch.stack([p1[:, 1], p2[:, 1]], dim=1) # (N, 2)
  766. # 计算包围盒边界
  767. x_min = x_coords.min(dim=1).values
  768. y_min = y_coords.min(dim=1).values
  769. x_max = x_coords.max(dim=1).values
  770. y_max = y_coords.max(dim=1).values
  771. # 扩展边界并限制在图像范围内
  772. x_min = (x_min - 1).clamp(min=0, max=img_size)
  773. y_min = (y_min - 1).clamp(min=0, max=img_size)
  774. x_max = (x_max + 1).clamp(min=0, max=img_size)
  775. y_max = (y_max + 1).clamp(min=0, max=img_size)
  776. # 合成包围盒
  777. boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1) # (N, 4)
  778. return boxes
  779. def box_iou_pairwise(box1, box2):
  780. """
  781. 输入:
  782. box1: shape (N, 4)
  783. box2: shape (M, 4)
  784. 输出:
  785. ious: shape (min(N, M), ), 只计算 i = j 的配对
  786. """
  787. N = min(len(box1), len(box2))
  788. lt = torch.max(box1[:N, :2], box2[:N, :2]) # 左上角
  789. rb = torch.min(box1[:N, 2:], box2[:N, 2:]) # 右下角
  790. wh = (rb - lt).clamp(min=0) # 宽高
  791. inter_area = wh[:, 0] * wh[:, 1] # 交集面积
  792. area1 = (box1[:N, 2] - box1[:N, 0]) * (box1[:N, 3] - box1[:N, 1])
  793. area2 = (box2[:N, 2] - box2[:N, 0]) * (box2[:N, 3] - box2[:N, 1])
  794. union_area = area1 + area2 - inter_area
  795. ious = inter_area / (union_area + 1e-6)
  796. return ious
  797. def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta=1.0, gamma=1.0):
  798. """
  799. Args:
  800. x: [N,1,H,W] 热力图
  801. boxes: [N,4] 框坐标
  802. gt_lines: [N,2,3] GT线段(含可见性)
  803. matched_idx: 匹配 index
  804. img_size: 图像尺寸
  805. alpha: IoU 损失权重
  806. beta: 长度损失权重
  807. gamma: 方向角度损失权重
  808. """
  809. losses = []
  810. boxes_per_image = [box.size(0) for box in boxes]
  811. x2 = x.split(boxes_per_image, dim=0)
  812. for xx, bb, gt_line, mid in zip(x2, boxes, gt_lines, matched_idx):
  813. p_prob, _ = heatmaps_to_lines(xx, bb)
  814. pred_lines = p_prob
  815. gt_line_points = gt_line[mid]
  816. if len(pred_lines) == 0 or len(gt_line_points) == 0:
  817. continue
  818. # IoU 损失
  819. pred_boxes = lines_to_boxes(pred_lines, img_size)
  820. gt_boxes = lines_to_boxes(gt_line_points, img_size)
  821. ious = box_iou_pairwise(pred_boxes, gt_boxes)
  822. iou_loss = 1.0 - ious # [N]
  823. # 长度损失
  824. pred_len = line_length(pred_lines)
  825. gt_len = line_length(gt_line_points)
  826. length_diff = F.l1_loss(pred_len, gt_len, reduction='none') # [N]
  827. # 方向角度损失
  828. pred_dir = line_direction(pred_lines)
  829. gt_dir = line_direction(gt_line_points)
  830. ang_loss = angle_loss_cosine(pred_dir, gt_dir) # [N]
  831. # 归一化每一项损失
  832. norm_iou = normalize_tensor(iou_loss)
  833. norm_len = normalize_tensor(length_diff)
  834. norm_ang = normalize_tensor(ang_loss)
  835. total = alpha * norm_iou + beta * norm_len + gamma * norm_ang
  836. losses.append(total)
  837. if not losses:
  838. return None
  839. return torch.mean(torch.cat(losses))
  840. def point_inference(x, point_boxes):
  841. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  842. points_probs = []
  843. points_scores = []
  844. boxes_per_image = [box.size(0) for box in point_boxes]
  845. x2 = x.split(boxes_per_image, dim=0)
  846. for xx, bb in zip(x2, point_boxes):
  847. point_prob,point_scores = heatmaps_to_points(xx, bb,num_points=1)
  848. points_probs.append(point_prob.unsqueeze(1))
  849. points_scores.append(point_scores)
  850. return points_probs,points_scores
  851. def circle_inference(x, point_boxes):
  852. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  853. points_probs = []
  854. points_scores = []
  855. boxes_per_image = [box.size(0) for box in point_boxes]
  856. x2 = x.split(boxes_per_image, dim=0)
  857. for xx, bb in zip(x2, point_boxes):
  858. point_prob,point_scores = heatmaps_to_circle_points(xx, bb,num_points=4)
  859. points_probs.append(point_prob.unsqueeze(1))
  860. points_scores.append(point_scores)
  861. return points_probs,points_scores
  862. def line_inference(x, line_boxes):
  863. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  864. lines_probs = []
  865. lines_scores = []
  866. boxes_per_image = [box.size(0) for box in line_boxes]
  867. x2 = x.split(boxes_per_image, dim=0)
  868. # x2:tuple 2 x2[0]:[1,3,1024,1024]
  869. # line_box: list:2 [1,4] [1.4] fasterrcnn kuang
  870. for xx, bb in zip(x2, line_boxes):
  871. line_prob, line_scores, = heatmaps_to_lines(xx, bb)
  872. lines_probs.append(line_prob)
  873. lines_scores.append(line_scores)
  874. return lines_probs, lines_scores
  875. def arc_inference(x, arc_boxes,th):
  876. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  877. points_probs = []
  878. points_scores = []
  879. print(f'arc_boxes:{len(arc_boxes)}')
  880. boxes_per_image = [box.size(0) for box in arc_boxes]
  881. print(f'arc boxes_per_image:{boxes_per_image}')
  882. x2 = x.split(boxes_per_image, dim=0)
  883. for xx, bb in zip(x2, arc_boxes):
  884. point_prob,point_scores = heatmaps_to_arc(xx, bb)
  885. points_probs.append(point_prob.unsqueeze(1))
  886. points_scores.append(point_scores)
  887. points_probs_tensor=torch.cat(points_probs)
  888. print(f'points_probs shape:{points_probs_tensor.shape}')
  889. feature_logits = x
  890. batch_size = feature_logits.shape[0]
  891. num_proposals = len(arc_boxes[0])
  892. results = [[torch.empty(0, 2) for _ in range(num_proposals)] for _ in range(batch_size)]
  893. proposals_list = arc_boxes[0] # [[tensor(...)]]
  894. for proposal_idx, proposal in enumerate(proposals_list):
  895. coords = proposal.tolist()
  896. x1, y1, x2, y2 = map(int, coords)
  897. x1 = max(0, x1)
  898. y1 = max(0, y1)
  899. x2 = min(feature_logits.shape[3], x2)
  900. y2 = min(feature_logits.shape[2], y2)
  901. for batch_idx in range(batch_size):
  902. region = feature_logits[batch_idx, :, y1:y2, x1:x2]
  903. mask = region > th
  904. coords = torch.nonzero(mask)
  905. if coords.numel() > 0:
  906. # 取 (y, x),然后转换为全局坐标 (x, y)
  907. local_coords = coords[:, [2, 1]] # (x, y)
  908. local_coords[:, 0] += x1
  909. local_coords[:, 1] += y1
  910. results[batch_idx][proposal_idx] = local_coords
  911. print(f're:{results}')
  912. return points_probs,points_scores,results
  913. import torch.nn.functional as F
  914. def heatmaps_to_arc(maps, rois, threshold=0, output_size=(128, 128)):
  915. """
  916. Args:
  917. maps: [N, 3, H, W] - full heatmaps
  918. rois: [N, 4] - bounding boxes
  919. threshold: float - binarization threshold
  920. output_size: resized size for uniform NMS
  921. Returns:
  922. masks: [N, 1, H, W] - binary mask aligned with input map
  923. scores: [N, 1] - count of non-zero pixels in each mask
  924. """
  925. N, _, H, W = maps.shape
  926. masks = torch.zeros((N, 1, H, W), dtype=torch.float32, device=maps.device)
  927. scores = torch.zeros((N, 1), dtype=torch.float32, device=maps.device)
  928. point_maps = maps[:, 0] # È¡µÚÒ»¸öͨµÀ [N, H, W]
  929. print(f"==> heatmaps_to_arc: maps.shape = {maps.shape}, rois.shape = {rois.shape}")
  930. for i in range(N):
  931. x1, y1, x2, y2 = rois[i].long()
  932. x1 = x1.clamp(0, W - 1)
  933. x2 = x2.clamp(0, W - 1)
  934. y1 = y1.clamp(0, H - 1)
  935. y2 = y2.clamp(0, H - 1)
  936. print(f"[{i}] roi: ({x1.item()}, {y1.item()}, {x2.item()}, {y2.item()})")
  937. if x2 <= x1 or y2 <= y1:
  938. print(f" Skipped invalid ROI at index {i}")
  939. continue
  940. roi_map = point_maps[i, y1:y2, x1:x2] # [h, w]
  941. print(f" roi_map.shape: {roi_map.shape}")
  942. if roi_map.numel() == 0:
  943. print(f" Skipped empty ROI at index {i}")
  944. continue
  945. # resize to uniform size
  946. roi_map_resized = F.interpolate(
  947. roi_map.unsqueeze(0).unsqueeze(0),
  948. size=output_size,
  949. mode='bilinear',
  950. align_corners=False
  951. ) # [1, 1, H, W]
  952. print(f" roi_map_resized.shape: {roi_map_resized.shape}")
  953. # NMS + threshold
  954. nms_roi = non_maximum_suppression(roi_map_resized) # shape: [1, H, W]
  955. bin_mask = (nms_roi >= threshold).float() # shape: [1, H, W]
  956. print(f" bin_mask.sum(): {bin_mask.sum().item()}")
  957. # resize back to original roi size
  958. h = int((y2 - y1).item())
  959. w = int((x2 - x1).item())
  960. # È·±£ bin_mask ÊÇ [1, 128, 128]
  961. assert bin_mask.dim() == 4, f"Expected 3D tensor [1, H, W], got {bin_mask.shape}"
  962. # ÉϲÉÑù»Ø ROI ԭʼ´óС
  963. bin_mask_original_size = F.interpolate(
  964. # bin_mask.unsqueeze(0), # ? [1, 1, 128, 128]
  965. bin_mask, # ? [1, 1, 128, 128]
  966. size=(h, w),
  967. mode='bilinear',
  968. align_corners=False
  969. )[0] # ? [1, h, w]
  970. masks[i, 0, y1:y2, x1:x2] = bin_mask_original_size.squeeze()
  971. scores[i] = bin_mask_original_size.sum()
  972. # plt.figure(figsize=(6, 6))
  973. # plt.imshow(masks[i, 0].cpu().numpy(), cmap='gray')
  974. # plt.title(f"Mask {i}, score={scores[i].item():.1f}")
  975. # plt.axis('off')
  976. # plt.show()
  977. print(f" bin_mask_original_size.shape: {bin_mask_original_size.shape}, sum: {scores[i].item()}")
  978. print(f"==> Done. Total valid masks: {(scores > 0).sum().item()} / {N}")
  979. return masks, scores