head_losses.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928
  1. import torch
  2. from matplotlib import pyplot as plt
  3. import torch.nn.functional as F
  4. from torch import nn
  5. class DiceLoss(nn.Module):
  6. def __init__(self, smooth=1.):
  7. super(DiceLoss, self).__init__()
  8. self.smooth = smooth
  9. def forward(self, logits, targets):
  10. probs = torch.sigmoid(logits)
  11. probs = probs.view(-1)
  12. targets = targets.view(-1).float()
  13. intersection = (probs * targets).sum()
  14. dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
  15. return 1. - dice
  16. bce_loss = nn.BCEWithLogitsLoss()
  17. dice_loss = DiceLoss()
  18. def combined_loss(preds, targets, alpha=0.5):
  19. bce = bce_loss(preds, targets)
  20. d = dice_loss(preds, targets)
  21. return alpha * bce + (1 - alpha) * d
  22. def features_align(features, proposals, img_size):
  23. print(f'lines_features_align features:{features.shape},proposals:{len(proposals)}')
  24. align_feat_list = []
  25. for feat, proposals_per_img in zip(features, proposals):
  26. print(f'lines_features_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
  27. if proposals_per_img.shape[0]>0:
  28. feat = feat.unsqueeze(0)
  29. for proposal in proposals_per_img:
  30. align_feat = torch.zeros_like(feat)
  31. # print(f'align_feat:{align_feat.shape}')
  32. x1, y1, x2, y2 = map(lambda v: int(v.item()), proposal)
  33. # 将每个proposal框内的部分赋值到align_feats对应位置
  34. align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]
  35. align_feat_list.append(align_feat)
  36. # print(f'align_feat_list:{align_feat_list}')
  37. if len(align_feat_list) > 0:
  38. feats_tensor = torch.cat(align_feat_list)
  39. print(f'align features :{feats_tensor.shape}')
  40. else:
  41. feats_tensor = None
  42. return feats_tensor
  43. def normalize_tensor(t):
  44. return (t - t.min()) / (t.max() - t.min() + 1e-6)
  45. def line_length(lines):
  46. """
  47. 计算每条线段的长度
  48. lines: [N, 2, 2] 表示 N 条线段,每条线段由两个点组成
  49. 返回: [N]
  50. """
  51. return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
  52. def line_direction(lines):
  53. """
  54. 计算每条线段的单位方向向量
  55. lines: [N, 2, 2]
  56. 返回: [N, 2] 单位方向向量
  57. """
  58. vec = lines[:, 1] - lines[:, 0]
  59. return F.normalize(vec, dim=-1)
  60. def angle_loss_cosine(pred_dir, gt_dir):
  61. """
  62. 使用 cosine similarity 计算方向差异
  63. pred_dir: [N, 2]
  64. gt_dir: [N, 2]
  65. 返回: [N]
  66. """
  67. cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
  68. return 1.0 - cos_sim # 或者 torch.acos(cos_sim) / pi 也可
  69. def line_length(lines):
  70. """
  71. 计算每条线段的长度
  72. lines: [N, 2, 2] 表示 N 条线段,每条线段由两个点组成
  73. 返回: [N]
  74. """
  75. return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
  76. def line_direction(lines):
  77. """
  78. 计算每条线段的单位方向向量
  79. lines: [N, 2, 2]
  80. 返回: [N, 2] 单位方向向量
  81. """
  82. vec = lines[:, 1] - lines[:, 0]
  83. return F.normalize(vec, dim=-1)
  84. def angle_loss_cosine(pred_dir, gt_dir):
  85. """
  86. 使用 cosine similarity 计算方向差异
  87. pred_dir: [N, 2]
  88. gt_dir: [N, 2]
  89. 返回: [N]
  90. """
  91. cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
  92. return 1.0 - cos_sim # 或者 torch.acos(cos_sim) / pi 也可
  93. def single_point_to_heatmap(keypoints, rois, heatmap_size):
  94. # type: (Tensor, Tensor, int) -> Tensor
  95. print(f'rois:{rois.shape}')
  96. print(f'heatmap_size:{heatmap_size}')
  97. print(f'keypoints.shape:{keypoints.shape}')
  98. # batch_size, num_keypoints, _ = keypoints.shape
  99. x = keypoints[..., 0].unsqueeze(1)
  100. y = keypoints[..., 1].unsqueeze(1)
  101. gs = generate_gaussian_heatmaps(x, y,num_points=1, heatmap_size=heatmap_size, sigma=1.0)
  102. # show_heatmap(gs[0],'target')
  103. all_roi_heatmap = []
  104. for roi, heatmap in zip(rois, gs):
  105. # show_heatmap(heatmap, 'target')
  106. # print(f'heatmap:{heatmap.shape}')
  107. heatmap = heatmap.unsqueeze(0)
  108. x1, y1, x2, y2 = map(int, roi)
  109. roi_heatmap = torch.zeros_like(heatmap)
  110. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  111. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  112. all_roi_heatmap.append(roi_heatmap)
  113. all_roi_heatmap = torch.cat(all_roi_heatmap)
  114. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  115. return all_roi_heatmap
  116. def line_points_to_heatmap(keypoints, rois, heatmap_size):
  117. # type: (Tensor, Tensor, int) -> Tensor
  118. print(f'rois:{rois.shape}')
  119. print(f'heatmap_size:{heatmap_size}')
  120. print(f'keypoints.shape:{keypoints.shape}')
  121. # batch_size, num_keypoints, _ = keypoints.shape
  122. x = keypoints[..., 0]
  123. y = keypoints[..., 1]
  124. gs = generate_gaussian_heatmaps(x, y, heatmap_size,num_points=2, sigma=1.0)
  125. # show_heatmap(gs[0],'target')
  126. all_roi_heatmap = []
  127. for roi, heatmap in zip(rois, gs):
  128. # print(f'heatmap:{heatmap.shape}')
  129. # show_heatmap(heatmap,'target')
  130. heatmap = heatmap.unsqueeze(0)
  131. x1, y1, x2, y2 = map(int, roi)
  132. roi_heatmap = torch.zeros_like(heatmap)
  133. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  134. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  135. all_roi_heatmap.append(roi_heatmap)
  136. if len(all_roi_heatmap) > 0:
  137. all_roi_heatmap = torch.cat(all_roi_heatmap)
  138. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  139. else:
  140. all_roi_heatmap = None
  141. return all_roi_heatmap
  142. """
  143. 修改适配的原结构的点 转热图,适用于带roi_pool版本的
  144. """
  145. def line_points_to_heatmap_(keypoints, rois, heatmap_size):
  146. # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
  147. print(f'rois:{rois.shape}')
  148. print(f'heatmap_size:{heatmap_size}')
  149. offset_x = rois[:, 0]
  150. offset_y = rois[:, 1]
  151. scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
  152. scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
  153. offset_x = offset_x[:, None]
  154. offset_y = offset_y[:, None]
  155. scale_x = scale_x[:, None]
  156. scale_y = scale_y[:, None]
  157. print(f'keypoints.shape:{keypoints.shape}')
  158. # batch_size, num_keypoints, _ = keypoints.shape
  159. x = keypoints[..., 0]
  160. y = keypoints[..., 1]
  161. # gs=generate_gaussian_heatmaps(x,y,512,1.0)
  162. # print(f'gs_heatmap shape:{gs.shape}')
  163. #
  164. # show_heatmap(gs[0],'target')
  165. x_boundary_inds = x == rois[:, 2][:, None]
  166. y_boundary_inds = y == rois[:, 3][:, None]
  167. x = (x - offset_x) * scale_x
  168. x = x.floor().long()
  169. y = (y - offset_y) * scale_y
  170. y = y.floor().long()
  171. x[x_boundary_inds] = heatmap_size - 1
  172. y[y_boundary_inds] = heatmap_size - 1
  173. # print(f'heatmaps x:{x}')
  174. # print(f'heatmaps y:{y}')
  175. valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
  176. vis = keypoints[..., 2] > 0
  177. valid = (valid_loc & vis).long()
  178. gs_heatmap = generate_gaussian_heatmaps(x, y, heatmap_size, 1.0)
  179. # show_heatmap(gs_heatmap[0], 'feature')
  180. # print(f'gs_heatmap:{gs_heatmap.shape}')
  181. #
  182. # lin_ind = y * heatmap_size + x
  183. # print(f'lin_ind:{lin_ind.shape}')
  184. # heatmaps = lin_ind * valid
  185. return gs_heatmap
  186. def generate_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
  187. """
  188. 为一组点生成并合并高斯热图。
  189. Args:
  190. xs (Tensor): 形状为 (N, 2) 的所有点的 x 坐标
  191. ys (Tensor): 形状为 (N, 2) 的所有点的 y 坐标
  192. heatmap_size (int): 热图大小 H=W
  193. sigma (float): 高斯核标准差
  194. device (str): 设备类型 ('cpu' or 'cuda')
  195. Returns:
  196. Tensor: 形状为 (H, W) 的合并后的热图
  197. """
  198. assert xs.shape == ys.shape, "x and y must have the same shape"
  199. print(f'xs:{xs.shape}')
  200. # xs=xs.squeeze(1)
  201. # ys = ys.squeeze(1)
  202. print(f'xs1:{xs.shape}')
  203. N = xs.shape[0]
  204. print(f'N:{N},num_points:{num_points}')
  205. # 创建网格
  206. grid_y, grid_x = torch.meshgrid(
  207. torch.arange(heatmap_size, device=device),
  208. torch.arange(heatmap_size, device=device),
  209. indexing='ij'
  210. )
  211. # print(f'heatmap_size:{heatmap_size}')
  212. # 初始化输出热图
  213. combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
  214. for i in range(N):
  215. heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
  216. for j in range(num_points):
  217. mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
  218. mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
  219. # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
  220. # 计算距离平方
  221. dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
  222. # 计算高斯分布
  223. heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
  224. heatmap+=heatmap1
  225. # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
  226. # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
  227. #
  228. # # 计算距离平方
  229. # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
  230. #
  231. # # 计算高斯分布
  232. # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
  233. #
  234. # heatmap = heatmap1 + heatmap2
  235. # 将当前热图累加到结果中
  236. combined_heatmap[i] = heatmap
  237. return combined_heatmap
  238. def generate_mask_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
  239. """
  240. 为一组点生成并合并高斯热图。
  241. Args:
  242. xs (Tensor): 形状为 (N, 2) 的所有点的 x 坐标
  243. ys (Tensor): 形状为 (N, 2) 的所有点的 y 坐标
  244. heatmap_size (int): 热图大小 H=W
  245. sigma (float): 高斯核标准差
  246. device (str): 设备类型 ('cpu' or 'cuda')
  247. Returns:
  248. Tensor: 形状为 (H, W) 的合并后的热图
  249. """
  250. assert xs.shape == ys.shape, "x and y must have the same shape"
  251. print(f'xs:{xs.shape}')
  252. xs=xs.squeeze(1)
  253. ys = ys.squeeze(1)
  254. print(f'xs1:{xs.shape}')
  255. N = xs.shape[0]
  256. print(f'N:{N},num_points:{num_points}')
  257. # 创建网格
  258. grid_y, grid_x = torch.meshgrid(
  259. torch.arange(heatmap_size, device=device),
  260. torch.arange(heatmap_size, device=device),
  261. indexing='ij'
  262. )
  263. # print(f'heatmap_size:{heatmap_size}')
  264. # 初始化输出热图
  265. combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
  266. for i in range(N):
  267. heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
  268. for j in range(num_points):
  269. mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
  270. mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
  271. # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
  272. # 计算距离平方
  273. dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
  274. # 计算高斯分布
  275. heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
  276. heatmap+=heatmap1
  277. # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
  278. # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
  279. #
  280. # # 计算距离平方
  281. # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
  282. #
  283. # # 计算高斯分布
  284. # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
  285. #
  286. # heatmap = heatmap1 + heatmap2
  287. # 将当前热图累加到结果中
  288. combined_heatmap[i] = heatmap
  289. return combined_heatmap
  290. def non_maximum_suppression(a):
  291. ap = F.max_pool2d(a, 3, stride=1, padding=1)
  292. mask = (a == ap).float().clamp(min=0.0)
  293. return a * mask
  294. def heatmaps_to_points(maps, rois):
  295. point_preds = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
  296. point_end_scores = torch.zeros((len(rois), 1), dtype=torch.float32, device=maps.device)
  297. print(f'heatmaps_to_lines:{maps.shape}')
  298. point_maps=maps[:,0]
  299. print(f'point_map:{point_maps.shape}')
  300. for i in range(len(rois)):
  301. point_roi_map = point_maps[i].unsqueeze(0)
  302. print(f'point_roi_map:{point_roi_map.shape}')
  303. # roi_map_probs = scores_to_probs(roi_map.copy())
  304. w = point_roi_map.shape[2]
  305. flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
  306. point_score, point_index = torch.topk(flatten_point_roi_map, k=1)
  307. print(f'point index:{point_index}')
  308. # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  309. point_x =point_index % w
  310. point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
  311. point_preds[i, 0,] = point_x
  312. point_preds[i, 1,] = point_y
  313. point_end_scores[i, :] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
  314. return point_preds,point_end_scores
  315. def heatmaps_to_lines(maps, rois):
  316. line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
  317. line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
  318. line_maps=maps[:,1]
  319. for i in range(len(rois)):
  320. line_roi_map = line_maps[i].unsqueeze(0)
  321. print(f'line_roi_map:{line_roi_map.shape}')
  322. # roi_map_probs = scores_to_probs(roi_map.copy())
  323. w = line_roi_map.shape[1]
  324. flatten_line_roi_map = non_maximum_suppression(line_roi_map).reshape(1, -1)
  325. line_score, line_index = torch.topk(flatten_line_roi_map, k=2)
  326. print(f'line index:{line_index}')
  327. # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  328. pos = line_index
  329. line_x = pos % w
  330. line_y = torch.div(pos - line_x, w, rounding_mode="floor")
  331. line_preds[i, 0, :] = line_x
  332. line_preds[i, 1, :] = line_y
  333. line_preds[i, 2, :] = 1
  334. line_end_scores[i, :] = line_roi_map[torch.arange(1, device=line_roi_map.device), line_y, line_x]
  335. return line_preds.permute(0, 2, 1), line_end_scores
  336. # 显示热图的函数
  337. def show_heatmap(heatmap, title="Heatmap"):
  338. """
  339. 使用 matplotlib 显示热图。
  340. Args:
  341. heatmap (Tensor): 要显示的热图张量
  342. title (str): 图表标题
  343. """
  344. # 如果在 GPU 上,首先将其移动到 CPU 并转换为 numpy 数组
  345. if heatmap.is_cuda:
  346. heatmap = heatmap.cpu().numpy()
  347. else:
  348. heatmap = heatmap.numpy()
  349. plt.imshow(heatmap, cmap='hot', interpolation='nearest')
  350. plt.colorbar()
  351. plt.title(title)
  352. plt.show()
  353. def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
  354. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  355. N, K, H, W = line_logits.shape
  356. len_proposals = len(proposals)
  357. print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals},line_matched_idxs:{line_matched_idxs}')
  358. if H != W:
  359. raise ValueError(
  360. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  361. )
  362. discretization_size = H
  363. heatmaps = []
  364. gs_heatmaps = []
  365. valid = []
  366. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
  367. print(f'line_proposals_per_image:{proposals_per_image.shape}')
  368. print(f'gt_lines:{gt_lines}')
  369. if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
  370. kp = gt_kp_in_image[midx]
  371. gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
  372. gs_heatmaps.append(gs_heatmaps_per_img)
  373. # print(f'heatmaps_per_image:{heatmaps_per_image.shape}')
  374. # heatmaps.append(heatmaps_per_image.view(-1))
  375. # valid.append(valid_per_image.view(-1))
  376. # line_targets = torch.cat(heatmaps, dim=0)
  377. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  378. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
  379. # print(f'line_targets:{line_targets.shape},{line_targets}')
  380. # valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
  381. # valid = torch.where(valid)[0]
  382. # print(f' line_targets[valid]:{line_targets[valid]}')
  383. # torch.mean (in binary_cross_entropy_with_logits) doesn't
  384. # accept empty tensors, so handle it sepaartely
  385. # if line_targets.numel() == 0 or len(valid) == 0:
  386. # return line_logits.sum() * 0
  387. # line_logits = line_logits.view(N * K, H * W)
  388. # print(f'line_logits[valid]:{line_logits[valid].shape}')
  389. print(f'loss1 line_logits:{line_logits.shape}')
  390. line_logits = line_logits[:,1,:,:]
  391. print(f'loss2 line_logits:{line_logits.shape}')
  392. # line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
  393. line_loss = F.cross_entropy(line_logits, gs_heatmaps)
  394. return line_loss
  395. def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
  396. print(f'compute_arc_loss:{feature_logits.shape}')
  397. N, K, H, W = feature_logits.shape
  398. len_proposals = len(proposals)
  399. empty_count = 0
  400. non_empty_count = 0
  401. for prop in proposals:
  402. if prop.shape[0] == 0:
  403. empty_count += 1
  404. else:
  405. non_empty_count += 1
  406. print(f"Empty proposals count: {empty_count}")
  407. print(f"Non-empty proposals count: {non_empty_count}")
  408. print(f'starte to compute_point_loss')
  409. print(f'compute_point_loss line_logits.shape:{feature_logits.shape},len_proposals:{len_proposals}')
  410. if H != W:
  411. raise ValueError(
  412. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  413. )
  414. discretization_size = H
  415. gs_heatmaps = []
  416. # print(f'point_matched_idxs:{point_matched_idxs}')
  417. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
  418. # [
  419. # (Tensor(38, 4), Tensor(1, 57, 2), Tensor(38, 1)),
  420. # (Tensor(65, 4), Tensor(1, 74, 2), Tensor(65, 1))
  421. # ]
  422. print(f'proposals_per_image:{proposals_per_image.shape}')
  423. kp = gt_kp_in_image[midx]
  424. # print(f'gt_kp_in_image:{gt_kp_in_image}')
  425. if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
  426. gs_heatmaps_per_img = arc_points_to_heatmap(kp, proposals_per_image, discretization_size)
  427. gs_heatmaps.append(gs_heatmaps_per_img)
  428. if len(gs_heatmaps)>0:
  429. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  430. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.shape}')
  431. line_logits = feature_logits.squeeze(1)
  432. print(f'single_point_logits:{line_logits.shape}')
  433. # line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
  434. # line_loss = F.cross_entropy(line_logits, gs_heatmaps)
  435. line_loss=combined_loss(line_logits, gs_heatmaps)
  436. else:
  437. line_loss=100
  438. print("d")
  439. return line_loss
  440. def arc_points_to_heatmap(keypoints, rois, heatmap_size):
  441. print(f'rois:{rois.shape}')
  442. print(f'heatmap_size:{heatmap_size}')
  443. print(f'keypoints.shape:{keypoints.shape}')
  444. # batch_size, num_keypoints, _ = keypoints.shape
  445. x = keypoints[..., 0].unsqueeze(1)
  446. y = keypoints[..., 1].unsqueeze(1)
  447. num_points=x.shape[2]
  448. print(f'num_points:{num_points}')
  449. gs = generate_mask_gaussian_heatmaps(x, y, num_points=num_points, heatmap_size=heatmap_size, sigma=2.0)
  450. # show_heatmap(gs[0],'target')
  451. all_roi_heatmap = []
  452. for roi, heatmap in zip(rois, gs):
  453. # show_heatmap(heatmap, 'target')
  454. print(f'heatmap:{heatmap.shape}')
  455. heatmap = heatmap.unsqueeze(0)
  456. x1, y1, x2, y2 = map(int, roi)
  457. roi_heatmap = torch.zeros_like(heatmap)
  458. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  459. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  460. all_roi_heatmap.append(roi_heatmap)
  461. all_roi_heatmap = torch.cat(all_roi_heatmap)
  462. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  463. return all_roi_heatmap
  464. def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
  465. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  466. N, K, H, W = line_logits.shape
  467. len_proposals = len(proposals)
  468. empty_count = 0
  469. non_empty_count = 0
  470. for prop in proposals:
  471. if prop.shape[0] == 0:
  472. empty_count += 1
  473. else:
  474. non_empty_count += 1
  475. print(f"Empty proposals count: {empty_count}")
  476. print(f"Non-empty proposals count: {non_empty_count}")
  477. print(f'starte to compute_point_loss')
  478. print(f'compute_point_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
  479. if H != W:
  480. raise ValueError(
  481. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  482. )
  483. discretization_size = H
  484. gs_heatmaps = []
  485. # print(f'point_matched_idxs:{point_matched_idxs}')
  486. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_points, point_matched_idxs):
  487. print(f'proposals_per_image:{proposals_per_image.shape}')
  488. kp = gt_kp_in_image[midx]
  489. # print(f'gt_kp_in_image:{gt_kp_in_image}')
  490. gs_heatmaps_per_img = single_point_to_heatmap(kp, proposals_per_image, discretization_size)
  491. gs_heatmaps.append(gs_heatmaps_per_img)
  492. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  493. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
  494. line_logits = line_logits[:,0]
  495. print(f'single_point_logits:{line_logits.shape}')
  496. line_loss = F.cross_entropy(line_logits, gs_heatmaps)
  497. return line_loss
  498. def lines_to_boxes(lines, img_size=511):
  499. """
  500. 输入:
  501. lines: Tensor of shape (N, 2, 2),表示 N 条线段,每个线段有两个端点 (x, y)
  502. img_size: int,图像尺寸,用于 clamp 边界
  503. 输出:
  504. boxes: Tensor of shape (N, 4),表示 N 个包围盒 [x_min, y_min, x_max, y_max]
  505. """
  506. # 提取所有线段的两个端点
  507. p1 = lines[:, 0] # (N, 2)
  508. p2 = lines[:, 1] # (N, 2)
  509. # 每条线段的 x 和 y 坐标
  510. x_coords = torch.stack([p1[:, 0], p2[:, 0]], dim=1) # (N, 2)
  511. y_coords = torch.stack([p1[:, 1], p2[:, 1]], dim=1) # (N, 2)
  512. # 计算包围盒边界
  513. x_min = x_coords.min(dim=1).values
  514. y_min = y_coords.min(dim=1).values
  515. x_max = x_coords.max(dim=1).values
  516. y_max = y_coords.max(dim=1).values
  517. # 扩展边界并限制在图像范围内
  518. x_min = (x_min - 1).clamp(min=0, max=img_size)
  519. y_min = (y_min - 1).clamp(min=0, max=img_size)
  520. x_max = (x_max + 1).clamp(min=0, max=img_size)
  521. y_max = (y_max + 1).clamp(min=0, max=img_size)
  522. # 合成包围盒
  523. boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1) # (N, 4)
  524. return boxes
  525. def box_iou_pairwise(box1, box2):
  526. """
  527. 输入:
  528. box1: shape (N, 4)
  529. box2: shape (M, 4)
  530. 输出:
  531. ious: shape (min(N, M), ), 只计算 i = j 的配对
  532. """
  533. N = min(len(box1), len(box2))
  534. lt = torch.max(box1[:N, :2], box2[:N, :2]) # 左上角
  535. rb = torch.min(box1[:N, 2:], box2[:N, 2:]) # 右下角
  536. wh = (rb - lt).clamp(min=0) # 宽高
  537. inter_area = wh[:, 0] * wh[:, 1] # 交集面积
  538. area1 = (box1[:N, 2] - box1[:N, 0]) * (box1[:N, 3] - box1[:N, 1])
  539. area2 = (box2[:N, 2] - box2[:N, 0]) * (box2[:N, 3] - box2[:N, 1])
  540. union_area = area1 + area2 - inter_area
  541. ious = inter_area / (union_area + 1e-6)
  542. return ious
  543. def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta=1.0, gamma=1.0):
  544. """
  545. Args:
  546. x: [N,1,H,W] 热力图
  547. boxes: [N,4] 框坐标
  548. gt_lines: [N,2,3] GT线段(含可见性)
  549. matched_idx: 匹配 index
  550. img_size: 图像尺寸
  551. alpha: IoU 损失权重
  552. beta: 长度损失权重
  553. gamma: 方向角度损失权重
  554. """
  555. losses = []
  556. boxes_per_image = [box.size(0) for box in boxes]
  557. x2 = x.split(boxes_per_image, dim=0)
  558. for xx, bb, gt_line, mid in zip(x2, boxes, gt_lines, matched_idx):
  559. p_prob, _ = heatmaps_to_lines(xx, bb)
  560. pred_lines = p_prob
  561. gt_line_points = gt_line[mid]
  562. if len(pred_lines) == 0 or len(gt_line_points) == 0:
  563. continue
  564. # IoU 损失
  565. pred_boxes = lines_to_boxes(pred_lines, img_size)
  566. gt_boxes = lines_to_boxes(gt_line_points, img_size)
  567. ious = box_iou_pairwise(pred_boxes, gt_boxes)
  568. iou_loss = 1.0 - ious # [N]
  569. # 长度损失
  570. pred_len = line_length(pred_lines)
  571. gt_len = line_length(gt_line_points)
  572. length_diff = F.l1_loss(pred_len, gt_len, reduction='none') # [N]
  573. # 方向角度损失
  574. pred_dir = line_direction(pred_lines)
  575. gt_dir = line_direction(gt_line_points)
  576. ang_loss = angle_loss_cosine(pred_dir, gt_dir) # [N]
  577. # 归一化每一项损失
  578. norm_iou = normalize_tensor(iou_loss)
  579. norm_len = normalize_tensor(length_diff)
  580. norm_ang = normalize_tensor(ang_loss)
  581. total = alpha * norm_iou + beta * norm_len + gamma * norm_ang
  582. losses.append(total)
  583. if not losses:
  584. return None
  585. return torch.mean(torch.cat(losses))
  586. def point_inference(x, point_boxes):
  587. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  588. points_probs = []
  589. points_scores = []
  590. boxes_per_image = [box.size(0) for box in point_boxes]
  591. x2 = x.split(boxes_per_image, dim=0)
  592. for xx, bb in zip(x2, point_boxes):
  593. point_prob,point_scores = heatmaps_to_points(xx, bb)
  594. points_probs.append(point_prob.unsqueeze(1))
  595. points_scores.append(point_scores)
  596. return points_probs,points_scores
  597. def line_inference(x, line_boxes):
  598. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  599. lines_probs = []
  600. lines_scores = []
  601. boxes_per_image = [box.size(0) for box in line_boxes]
  602. x2 = x.split(boxes_per_image, dim=0)
  603. # x2:tuple 2 x2[0]:[1,3,1024,1024]
  604. # line_box: list:2 [1,4] [1.4] fasterrcnn kuang
  605. for xx, bb in zip(x2, line_boxes):
  606. line_prob, line_scores, = heatmaps_to_lines(xx, bb)
  607. lines_probs.append(line_prob)
  608. lines_scores.append(line_scores)
  609. return lines_probs, lines_scores
  610. def arc_inference(x, arc_boxes,th):
  611. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  612. points_probs = []
  613. points_scores = []
  614. print(f'arc_boxes:{len(arc_boxes)}')
  615. boxes_per_image = [box.size(0) for box in arc_boxes]
  616. print(f'arc boxes_per_image:{boxes_per_image}')
  617. x2 = x.split(boxes_per_image, dim=0)
  618. for xx, bb in zip(x2, arc_boxes):
  619. point_prob,point_scores = heatmaps_to_arc(xx, bb)
  620. points_probs.append(point_prob.unsqueeze(1))
  621. points_scores.append(point_scores)
  622. points_probs_tensor=torch.cat(points_probs)
  623. print(f'points_probs shape:{points_probs_tensor.shape}')
  624. feature_logits = x
  625. batch_size = feature_logits.shape[0]
  626. num_proposals = len(arc_boxes[0])
  627. results = [[torch.empty(0, 2) for _ in range(num_proposals)] for _ in range(batch_size)]
  628. proposals_list = arc_boxes[0] # [[tensor(...)]]
  629. for proposal_idx, proposal in enumerate(proposals_list):
  630. coords = proposal.tolist()
  631. x1, y1, x2, y2 = map(int, coords)
  632. x1 = max(0, x1)
  633. y1 = max(0, y1)
  634. x2 = min(feature_logits.shape[3], x2)
  635. y2 = min(feature_logits.shape[2], y2)
  636. for batch_idx in range(batch_size):
  637. region = feature_logits[batch_idx, :, y1:y2, x1:x2]
  638. mask = region > th
  639. coords = torch.nonzero(mask)
  640. if coords.numel() > 0:
  641. # 取 (y, x),然后转换为全局坐标 (x, y)
  642. local_coords = coords[:, [2, 1]] # (x, y)
  643. local_coords[:, 0] += x1
  644. local_coords[:, 1] += y1
  645. results[batch_idx][proposal_idx] = local_coords
  646. print(f're:{results}')
  647. return points_probs,points_scores,results
  648. import torch.nn.functional as F
  649. def heatmaps_to_arc(maps, rois, threshold=0.1, output_size=(128, 128)):
  650. """
  651. Args:
  652. maps: [N, 3, H, W] - full heatmaps
  653. rois: [N, 4] - bounding boxes
  654. threshold: float - binarization threshold
  655. output_size: resized size for uniform NMS
  656. Returns:
  657. masks: [N, 1, H, W] - binary mask aligned with input map
  658. scores: [N, 1] - count of non-zero pixels in each mask
  659. """
  660. N, _, H, W = maps.shape
  661. masks = torch.zeros((N, 1, H, W), dtype=torch.float32, device=maps.device)
  662. scores = torch.zeros((N, 1), dtype=torch.float32, device=maps.device)
  663. point_maps = maps[:, 0] # È¡µÚÒ»¸öͨµÀ [N, H, W]
  664. print(f"==> heatmaps_to_arc: maps.shape = {maps.shape}, rois.shape = {rois.shape}")
  665. for i in range(N):
  666. x1, y1, x2, y2 = rois[i].long()
  667. x1 = x1.clamp(0, W - 1)
  668. x2 = x2.clamp(0, W - 1)
  669. y1 = y1.clamp(0, H - 1)
  670. y2 = y2.clamp(0, H - 1)
  671. print(f"[{i}] roi: ({x1.item()}, {y1.item()}, {x2.item()}, {y2.item()})")
  672. if x2 <= x1 or y2 <= y1:
  673. print(f" Skipped invalid ROI at index {i}")
  674. continue
  675. roi_map = point_maps[i, y1:y2, x1:x2] # [h, w]
  676. print(f" roi_map.shape: {roi_map.shape}")
  677. if roi_map.numel() == 0:
  678. print(f" Skipped empty ROI at index {i}")
  679. continue
  680. # resize to uniform size
  681. roi_map_resized = F.interpolate(
  682. roi_map.unsqueeze(0).unsqueeze(0),
  683. size=output_size,
  684. mode='bilinear',
  685. align_corners=False
  686. ) # [1, 1, H, W]
  687. print(f" roi_map_resized.shape: {roi_map_resized.shape}")
  688. # NMS + threshold
  689. nms_roi = non_maximum_suppression(roi_map_resized) # shape: [1, H, W]
  690. bin_mask = (nms_roi > threshold).float() # shape: [1, H, W]
  691. print(f" bin_mask.sum(): {bin_mask.sum().item()}")
  692. # resize back to original roi size
  693. h = int((y2 - y1).item())
  694. w = int((x2 - x1).item())
  695. # È·±£ bin_mask ÊÇ [1, 128, 128]
  696. assert bin_mask.dim() == 4, f"Expected 3D tensor [1, H, W], got {bin_mask.shape}"
  697. # ÉϲÉÑù»Ø ROI ԭʼ´óС
  698. bin_mask_original_size = F.interpolate(
  699. # bin_mask.unsqueeze(0), # ? [1, 1, 128, 128]
  700. bin_mask, # ? [1, 1, 128, 128]
  701. size=(h, w),
  702. mode='bilinear',
  703. align_corners=False
  704. )[0] # ? [1, h, w]
  705. masks[i, 0, y1:y2, x1:x2] = bin_mask_original_size.squeeze()
  706. scores[i] = bin_mask_original_size.sum()
  707. print(f" bin_mask_original_size.shape: {bin_mask_original_size.shape}, sum: {scores[i].item()}")
  708. print(f"==> Done. Total valid masks: {(scores > 0).sum().item()} / {N}")
  709. return masks, scores