head_losses.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681
  1. import torch
  2. from matplotlib import pyplot as plt
  3. import torch.nn.functional as F
  def features_align(features, proposals, img_size):
      """Build one masked copy of the image feature map per proposal box.

      Each output slice equals its image's feature map inside the inclusive box
      ``[x1, x2] x [y1, y2]`` and is zero everywhere else.

      Args:
          features: per-image feature maps, iterated in lockstep with ``proposals``; each
              element gets a leading batch dim and is then sliced as ``[:, :, rows, cols]``,
              i.e. presumably (C, H, W) per image.
          proposals: per-image tensors whose rows are boxes (x1, y1, x2, y2); coordinates
              are truncated to int and used directly as feature-map indices.
          img_size: accepted for interface compatibility; not used.

      Returns:
          Tensor of shape (total_proposals, C, H, W), or None if no image has proposals.
      """
      print(f'lines_features_align features:{features.shape},proposals:{len(proposals)}')
      masked = []
      for image_feat, image_boxes in zip(features, proposals):
          print(f'lines_features_align feat:{image_feat.shape}, proposals_per_img:{image_boxes.shape}')
          if image_boxes.shape[0] <= 0:
              continue
          batched = image_feat.unsqueeze(0)
          for box in image_boxes:
              x1, y1, x2, y2 = (int(coord.item()) for coord in box)
              rows, cols = slice(y1, y2 + 1), slice(x1, x2 + 1)
              window = torch.zeros_like(batched)
              window[:, :, rows, cols] = batched[:, :, rows, cols]
              masked.append(window)
      if not masked:
          return None
      feats_tensor = torch.cat(masked)
      print(f'align features :{feats_tensor.shape}')
      return feats_tensor
  def normalize_tensor(t):
      """Min-max rescale ``t`` to roughly [0, 1].

      The 1e-6 added to the range keeps constant inputs finite (they map to all zeros).
      """
      lo = t.min()
      hi = t.max()
      return (t - lo) / (hi - lo + 1e-6)
  def line_length(lines):
      """Return the Euclidean length of every segment.

      Args:
          lines: tensor of shape [N, 2, D]; ``lines[:, 0]`` and ``lines[:, 1]`` are the two
              endpoints (documented as [N, 2, 2]; any trailing size D is accepted).

      Returns:
          Tensor of shape [N].
      """
      start, end = lines[:, 0], lines[:, 1]
      return (end - start).norm(dim=-1)
  def line_direction(lines):
      """Return the unit direction vector (start -> end) of every segment.

      Args:
          lines: tensor of shape [N, 2, D] (documented as [N, 2, 2]).

      Returns:
          Tensor of shape [N, D]. Zero-length segments give a zero vector, because
          ``F.normalize`` divides by ``max(norm, eps)``.
      """
      delta = lines[:, 1] - lines[:, 0]
      return F.normalize(delta, p=2.0, dim=-1)
  def angle_loss_cosine(pred_dir, gt_dir):
      """Direction mismatch measured as ``1 - cos(theta)`` between paired vectors.

      Args:
          pred_dir: [N, D] direction vectors.
          gt_dir: [N, D] direction vectors. The dot product equals the cosine only when
              both inputs are unit length.

      Returns:
          Tensor of shape [N] in [0, 2]: 0 for parallel, 1 for orthogonal and 2 for
          opposite directions. ``torch.acos(cos_sim) / pi`` would be an angle-based
          alternative.
      """
      dot = (pred_dir * gt_dir).sum(dim=-1)
      return 1.0 - torch.clamp(dot, min=-1.0, max=1.0)
  # NOTE: a second, byte-identical definition of line_length used to live here. It
  # silently re-bound the name and shadowed the earlier definition in this module, so
  # edits to that one had no effect. line_length is still defined once, earlier on.
  # NOTE: a second, byte-identical definition of line_direction used to live here. It
  # silently re-bound the name and shadowed the earlier definition in this module.
  # line_direction is still defined once, earlier on.
  # NOTE: a second, byte-identical definition of angle_loss_cosine used to live here. It
  # silently re-bound the name and shadowed the earlier definition in this module.
  # angle_loss_cosine is still defined once, earlier on.
  def single_point_to_heatmap(keypoints, rois, heatmap_size):
      # type: (Tensor, Tensor, int) -> Tensor
      """Render one Gaussian per keypoint and keep only the part inside its ROI.

      Args:
          keypoints: per-ROI keypoints; ``[..., 0]`` is x and ``[..., 1]`` is y. They are
              used directly as heatmap pixel coordinates (the same indices used to crop
              with ``rois``).
          rois: (N, 4) boxes (x1, y1, x2, y2); pixels outside the inclusive box are zeroed.
          heatmap_size: side length H = W of the square heatmaps.

      Returns:
          Tensor of shape (N, heatmap_size, heatmap_size). Empty input yields a
          (0, H, W) tensor instead of raising, so per-image results can always be
          concatenated by the caller.
      """
      print(f'rois:{rois.shape}')
      print(f'heatmap_size:{heatmap_size}')
      print(f'keypoints.shape:{keypoints.shape}')
      if len(rois) == 0 or len(keypoints) == 0:
          # torch.cat([]) raises, so return an explicit empty batch instead.
          return torch.zeros((0, heatmap_size, heatmap_size), device=keypoints.device)
      x = keypoints[..., 0].unsqueeze(1)
      y = keypoints[..., 1].unsqueeze(1)
      # Render on the keypoints' device rather than relying on a hard-coded default.
      gs = generate_gaussian_heatmaps(
          x, y, num_points=1, heatmap_size=heatmap_size, sigma=1.0, device=keypoints.device
      )
      all_roi_heatmap = []
      for roi, heatmap in zip(rois, gs):
          heatmap = heatmap.unsqueeze(0)
          x1, y1, x2, y2 = map(int, roi)
          roi_heatmap = torch.zeros_like(heatmap)
          # Keep only the inclusive ROI window of the Gaussian.
          roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
          all_roi_heatmap.append(roi_heatmap)
      all_roi_heatmap = torch.cat(all_roi_heatmap)
      print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
      return all_roi_heatmap
  def line_points_to_heatmap(keypoints, rois, heatmap_size):
      # type: (Tensor, Tensor, int) -> Tensor
      """Render both endpoint Gaussians of each line and crop them to that line's ROI.

      Args:
          keypoints: per-ROI line endpoints; ``[..., 0]`` / ``[..., 1]`` are x / y and two
              points per row are rendered (``num_points=2``).
          rois: (N, 4) boxes (x1, y1, x2, y2); pixels outside the inclusive box are zeroed.
          heatmap_size: side length of the square heatmaps.

      Returns:
          (N, heatmap_size, heatmap_size) tensor, or None when no ROI/heatmap pair exists.
      """
      print(f'rois:{rois.shape}')
      print(f'heatmap_size:{heatmap_size}')
      print(f'keypoints.shape:{keypoints.shape}')
      endpoint_maps = generate_gaussian_heatmaps(
          keypoints[..., 0], keypoints[..., 1], heatmap_size, num_points=2, sigma=1.0
      )

      def _crop(box, full_map):
          # Zero everything outside the inclusive box, then add a leading batch dim.
          left, top, right, bottom = (int(v) for v in box)
          cropped = torch.zeros_like(full_map)
          cropped[top:bottom + 1, left:right + 1] = full_map[top:bottom + 1, left:right + 1]
          return cropped.unsqueeze(0)

      cropped_maps = [_crop(box, full_map) for box, full_map in zip(rois, endpoint_maps)]
      if not cropped_maps:
          return None
      stacked = torch.cat(cropped_maps)
      print(f'all_roi_heatmap:{stacked.shape}')
      return stacked
  # Point-to-heatmap conversion adapted to the original (ROI-relative) structure;
  # intended for the variant that uses roi_pool.
  def line_points_to_heatmap_(keypoints, rois, heatmap_size):
      # type: (Tensor, Tensor, int) -> Tensor
      """Map keypoints into each ROI's heatmap frame and render their Gaussians.

      Keypoints are shifted by the ROI's top-left corner, scaled by
      ``heatmap_size / roi_extent`` and floored to integer cells. They are then rendered
      with ``generate_gaussian_heatmaps`` (two points per row, sigma 1.0).

      Args:
          keypoints: tensor whose ``[..., 0]`` / ``[..., 1]`` are x / y in the same frame
              as ``rois``; presumably (N, 2, >=2) line endpoints -- TODO confirm against caller.
          rois: (N, 4) boxes (x1, y1, x2, y2).
          heatmap_size: side length of the square output heatmaps.

      Returns:
          Tensor of shape (N, heatmap_size, heatmap_size).
      """
      print(f'rois:{rois.shape}')
      print(f'heatmap_size:{heatmap_size}')
      offset_x = rois[:, 0][:, None]
      offset_y = rois[:, 1][:, None]
      # NOTE(review): a zero-width/height ROI makes these scales infinite -- confirm ROIs
      # are never degenerate here.
      scale_x = (heatmap_size / (rois[:, 2] - rois[:, 0]))[:, None]
      scale_y = (heatmap_size / (rois[:, 3] - rois[:, 1]))[:, None]
      print(f'keypoints.shape:{keypoints.shape}')
      x = keypoints[..., 0]
      y = keypoints[..., 1]
      # Points exactly on the right/bottom ROI edge would land on index heatmap_size;
      # pin them to the last cell instead.
      x_boundary_inds = x == rois[:, 2][:, None]
      y_boundary_inds = y == rois[:, 3][:, None]
      x = ((x - offset_x) * scale_x).floor().long()
      y = ((y - offset_y) * scale_y).floor().long()
      x[x_boundary_inds] = heatmap_size - 1
      y[y_boundary_inds] = heatmap_size - 1
      # sigma must be passed by keyword: positionally, 1.0 would bind to num_points and
      # the call would fail. The blocking show_heatmap() debug call and the unused
      # visibility mask were removed from this training-time path.
      return generate_gaussian_heatmaps(
          x, y, heatmap_size, num_points=2, sigma=1.0, device=keypoints.device
      )
  def generate_gaussian_heatmaps(xs, ys, heatmap_size, num_points=2, sigma=2.0, device=None):
      """Render, for every row, the sum of isotropic Gaussians centred on its points.

      Args:
          xs (Tensor): x coordinates, shape (N, P). Dimensions after the first are
              flattened, so (N, 1, P) is accepted too. Only the first ``num_points``
              columns are used.
          ys (Tensor): y coordinates, same shape as ``xs``.
          heatmap_size (int): side length of the square heatmap (H = W).
          num_points (int): number of points per row to render.
          sigma (float): Gaussian standard deviation, in heatmap pixels.
          device: output device. ``None`` (the default) uses ``xs.device``; the former
              hard-coded ``'cuda'`` default broke CPU-only runs.

      Returns:
          Tensor of shape (N, H, W) in the default float dtype. Row i is
          ``sum_j exp(-((u - x_ij)^2 + (v - y_ij)^2) / (2 * sigma^2))`` over the pixel grid
          (u = column, v = row), with centres clamped to ``[0, heatmap_size - 1]``.
          Coordinates are detached, so no gradient flows into the targets.

      Raises:
          IndexError: if N > 0 and a row has fewer than ``num_points`` coordinates.
      """
      assert xs.shape == ys.shape, "x and y must have the same shape"
      print(f'xs:{xs.shape}')
      N = xs.shape[0]
      print(f'N:{N},num_points:{num_points}')
      if device is None:
          device = xs.device
      dtype = torch.get_default_dtype()
      if N == 0:
          return torch.zeros((0, heatmap_size, heatmap_size), device=device, dtype=dtype)
      # (N, P) views of the centres; detach() because these heatmaps are training targets.
      mu_x = xs.detach().flatten(1)
      mu_y = ys.detach().flatten(1)
      if mu_x.shape[1] < num_points:
          raise IndexError(f'expected at least {num_points} points per row, got {mu_x.shape[1]}')
      mu_x = mu_x[:, :num_points].to(device=device, dtype=dtype).clamp(0, heatmap_size - 1)
      mu_y = mu_y[:, :num_points].to(device=device, dtype=dtype).clamp(0, heatmap_size - 1)
      coords = torch.arange(heatmap_size, device=device, dtype=dtype)
      denom = 2 * sigma ** 2
      # The 2-D Gaussian is separable: exp(-(dx^2 + dy^2)/d) = exp(-dx^2/d) * exp(-dy^2/d).
      # Building 1-D profiles per point avoids an (N, P, H, W) intermediate and the
      # per-point .item() host syncs of an explicit Python loop.
      prof_x = torch.exp(-(coords - mu_x[..., None]) ** 2 / denom)  # (N, P, W)
      prof_y = torch.exp(-(coords - mu_y[..., None]) ** 2 / denom)  # (N, P, H)
      # Summing the outer products prof_y[:, j] x prof_x[:, j] over j is a batched matmul.
      # NOTE(review): if TF32 matmul is enabled globally on CUDA, this loses some precision.
      return torch.bmm(prof_y.transpose(1, 2), prof_x)
  def non_maximum_suppression(a):
      """Keep values equal to the maximum of their 3x3 neighbourhood; zero the rest.

      ``a`` is passed to ``F.max_pool2d`` (stride 1, padding 1), so it must be a 3-D or
      4-D tensor; the output has the same shape as ``a``.
      """
      neighbourhood_max = F.max_pool2d(a, kernel_size=3, stride=1, padding=1)
      keep = (a == neighbourhood_max).float()
      return a * keep
  def heatmaps_to_points(maps, rois):
      """Decode one peak per ROI from channel 0 of ``maps``.

      Args:
          maps: (R, C, H, W) heatmaps; row i belongs to ``rois[i]``.
          rois: per-ROI boxes; only ``len(rois)`` is used.

      Returns:
          (point_preds, point_end_scores): point_preds is (R, 2) float32 holding the
          (x, y) pixel indices of the top position after ``non_maximum_suppression``;
          point_end_scores is (R, 1) with the raw map value at that pixel.
      """
      num_rois = len(rois)
      coords = torch.zeros((num_rois, 2), dtype=torch.float32, device=maps.device)
      scores = torch.zeros((num_rois, 1), dtype=torch.float32, device=maps.device)
      print(f'heatmaps_to_lines:{maps.shape}')
      channel0 = maps[:, 0]
      print(f'point_map:{channel0.shape}')
      for idx in range(num_rois):
          single = channel0[idx].unsqueeze(0)  # (1, H, W)
          print(f'point_roi_map:{single.shape}')
          width = single.shape[2]
          _, flat_idx = non_maximum_suppression(single).reshape(1, -1).topk(k=1)
          print(f'point index:{flat_idx}')
          # flat index = y * width + x
          px = flat_idx % width
          py = torch.div(flat_idx - px, width, rounding_mode="floor")
          coords[idx, 0] = px
          coords[idx, 1] = py
          scores[idx, :] = single[torch.arange(1, device=single.device), py, px]
      return coords, scores
  def heatmaps_to_lines(maps, rois):
      """Decode two endpoint peaks per ROI from channel 1 of ``maps``.

      For each ROI, the channel-1 map goes through ``non_maximum_suppression`` and the
      two highest flattened positions are taken with ``topk(k=2)``.

      Args:
          maps: (R, C, H, W) heatmaps with C >= 2; row i belongs to ``rois[i]``.
          rois: per-ROI boxes; only ``len(rois)`` is used.

      Returns:
          (line_preds, line_end_scores):
              line_preds: (R, 2, 3) float32. Each endpoint row is (x, y, 1) in heatmap
                  pixel indices; the trailing 1 is a constant visibility flag.
              line_end_scores: (R, 2) raw map values at the two endpoints.
          NOTE(review): topk always returns two positions, so the second may be a
          suppressed (zeroed) cell, e.g. when fewer than two positive peaks survive NMS.
      """
      line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
      line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
      line_maps = maps[:, 1]
      for i in range(len(rois)):
          line_roi_map = line_maps[i].unsqueeze(0)  # (1, H, W)
          print(f'line_roi_map:{line_roi_map.shape}')
          # The flattened index is y * W + x, so decode with the width (last dim);
          # shape[1] is the height and is only correct for square maps.
          w = line_roi_map.shape[-1]
          flatten_line_roi_map = non_maximum_suppression(line_roi_map).reshape(1, -1)
          _, pos = torch.topk(flatten_line_roi_map, k=2)
          print(f'line index:{pos}')
          line_x = pos % w
          line_y = torch.div(pos - line_x, w, rounding_mode="floor")
          line_preds[i, 0, :] = line_x
          line_preds[i, 1, :] = line_y
          line_preds[i, 2, :] = 1
          line_end_scores[i, :] = line_roi_map[torch.arange(1, device=line_roi_map.device), line_y, line_x]
      return line_preds.permute(0, 2, 1), line_end_scores
  # Debug helper: display a heatmap.
  def show_heatmap(heatmap, title="Heatmap"):
      """Display a heatmap tensor with matplotlib.

      ``plt.show()`` is called at the end, which blocks in interactive backends.

      Args:
          heatmap (Tensor): tensor accepted by ``plt.imshow`` once converted, e.g. (H, W).
              It may live on any device and may require grad.
          title (str): figure title.
      """
      # detach() lets tensors that require grad be converted; cpu() moves GPU (CUDA or
      # other accelerator) tensors to host and is a no-op for CPU tensors.
      heatmap = heatmap.detach().cpu().numpy()
      plt.imshow(heatmap, cmap='hot', interpolation='nearest')
      plt.colorbar()
      plt.title(title)
      plt.show()
  def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
      # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
      """Heatmap loss for line endpoints, read from channel 1 of ``line_logits``.

      Args:
          line_logits: (N, K, H, W) logits with H == W; channel 1 holds the endpoint map.
          proposals: per-image (P_i, 4) boxes.
          gt_lines: per-image ground-truth lines, indexed by ``line_matched_idxs``.
          line_matched_idxs: per-image indices of the GT line matched to each proposal.

      Returns:
          Scalar ``F.cross_entropy`` between channel 1 and the rendered Gaussian targets.
          When no image contributes targets, returns ``line_logits.sum() * 0`` (a zero
          that keeps the autograd graph connected) instead of failing in ``torch.cat([])``.

      Raises:
          ValueError: if H != W.
      """
      N, K, H, W = line_logits.shape
      len_proposals = len(proposals)
      print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals},line_matched_idxs:{line_matched_idxs}')
      if H != W:
          raise ValueError(
              f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
          )
      discretization_size = H
      gs_heatmaps = []
      for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
          print(f'line_proposals_per_image:{proposals_per_image.shape}')
          print(f'gt_lines:{gt_lines}')
          if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
              kp = gt_kp_in_image[midx]
              gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
              gs_heatmaps.append(gs_heatmaps_per_img)
      if not gs_heatmaps:
          # torch.cat / F.cross_entropy cannot handle an empty batch.
          return line_logits.sum() * 0
      gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
      print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
      print(f'loss1 line_logits:{line_logits.shape}')
      line_logits = line_logits[:, 1, :, :]
      print(f'loss2 line_logits:{line_logits.shape}')
      # NOTE(review): with (N, H, W) input and float targets, F.cross_entropy treats dim 1
      # (the H axis) as the class dimension, so the softmax runs down each column. The
      # Gaussian targets are also not normalised to sum to 1 -- confirm this is intended.
      line_loss = F.cross_entropy(line_logits, gs_heatmaps)
      return line_loss
  def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
      """Heatmap loss for arc points, read from channel 0 of ``feature_logits``.

      Args:
          feature_logits: (N, K, H, W) logits with H == W; channel 0 holds the arc map.
          proposals: per-image (P_i, 4) boxes.
          gt_: per-image ground-truth arc points, indexed by ``pos_matched_idxs``.
          pos_matched_idxs: per-image indices of the GT entry matched to each proposal.

      Returns:
          Scalar ``F.cross_entropy`` between channel 0 and the rendered Gaussian targets.
          If no image has proposals, returns ``feature_logits.sum() * 0`` (a zero that
          keeps the autograd graph connected) instead of failing in ``torch.cat([])``.

      Raises:
          ValueError: if H != W.
      """
      print(f'compute_arc_loss:{feature_logits.shape}')
      N, K, H, W = feature_logits.shape
      len_proposals = len(proposals)
      empty_count = sum(1 for prop in proposals if prop.shape[0] == 0)
      non_empty_count = len_proposals - empty_count
      print(f"Empty proposals count: {empty_count}")
      print(f"Non-empty proposals count: {non_empty_count}")
      print(f'starte to compute_point_loss')
      print(f'compute_point_loss line_logits.shape:{feature_logits.shape},len_proposals:{len_proposals}')
      if H != W:
          raise ValueError(
              f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
          )
      discretization_size = H
      gs_heatmaps = []
      for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
          print(f'proposals_per_image:{proposals_per_image.shape}')
          if proposals_per_image.shape[0] == 0:
              # Nothing to render for this image; heatmap rendering of an empty set
              # would otherwise end in torch.cat([]).
              continue
          kp = gt_kp_in_image[midx]
          gs_heatmaps.append(arc_points_to_heatmap(kp, proposals_per_image, discretization_size))
      if not gs_heatmaps:
          return feature_logits.sum() * 0
      gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
      print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.squeeze(1).shape}')
      line_logits = feature_logits[:, 0]
      print(f'single_point_logits:{line_logits.shape}')
      # NOTE(review): F.cross_entropy uses dim 1 (the H axis) as the class dimension here,
      # with unnormalised Gaussian targets -- confirm this is the intended spatial loss.
      line_loss = F.cross_entropy(line_logits, gs_heatmaps)
      return line_loss
  def arc_points_to_heatmap(keypoints, rois, heatmap_size):
      """Render the Gaussians of each arc's sample points and crop them to its ROI.

      Args:
          keypoints: per-ROI points; ``[..., 0]`` is x and ``[..., 1]`` is y. Ten points
              per ROI are rendered, so this presumably expects (N, 10, >=2) -- TODO
              confirm against caller.
          rois: (N, 4) boxes (x1, y1, x2, y2); pixels outside the inclusive box are zeroed.
          heatmap_size: side length of the square heatmaps.

      Returns:
          (N, heatmap_size, heatmap_size) tensor; a (0, H, W) tensor for empty input.
      """
      print(f'rois:{rois.shape}')
      print(f'heatmap_size:{heatmap_size}')
      print(f'keypoints.shape:{keypoints.shape}')
      if len(rois) == 0 or len(keypoints) == 0:
          # torch.cat([]) raises, so return an explicit empty batch instead.
          return torch.zeros((0, heatmap_size, heatmap_size), device=keypoints.device)
      # No unsqueeze(1): generate_gaussian_heatmaps reads the points along dim 1, and a
      # singleton dim there left only one point for the 10 requested (IndexError).
      x = keypoints[..., 0]
      y = keypoints[..., 1]
      gs = generate_gaussian_heatmaps(
          x, y, num_points=10, heatmap_size=heatmap_size, sigma=1.0, device=keypoints.device
      )
      all_roi_heatmap = []
      for roi, heatmap in zip(rois, gs):
          # The per-ROI show_heatmap() debug calls were removed: plt.show() blocks training.
          print(f'heatmap:{heatmap.shape}')
          heatmap = heatmap.unsqueeze(0)
          x1, y1, x2, y2 = map(int, roi)
          roi_heatmap = torch.zeros_like(heatmap)
          roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
          all_roi_heatmap.append(roi_heatmap)
      all_roi_heatmap = torch.cat(all_roi_heatmap)
      print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
      return all_roi_heatmap
  def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
      # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
      """Heatmap loss for single points, read from channel 0 of ``line_logits``.

      Args:
          line_logits: (N, K, H, W) logits with H == W; channel 0 holds the point map.
          proposals: per-image (P_i, 4) boxes.
          gt_points: per-image ground-truth points, indexed by ``point_matched_idxs``.
          point_matched_idxs: per-image indices of the GT point matched to each proposal.

      Returns:
          Scalar ``F.cross_entropy`` between channel 0 and the rendered Gaussian targets.
          If no image has proposals, returns ``line_logits.sum() * 0`` (a zero that keeps
          the autograd graph connected) instead of failing in ``torch.cat([])``.

      Raises:
          ValueError: if H != W.
      """
      N, K, H, W = line_logits.shape
      len_proposals = len(proposals)
      empty_count = sum(1 for prop in proposals if prop.shape[0] == 0)
      non_empty_count = len_proposals - empty_count
      print(f"Empty proposals count: {empty_count}")
      print(f"Non-empty proposals count: {non_empty_count}")
      print(f'starte to compute_point_loss')
      print(f'compute_point_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
      if H != W:
          raise ValueError(
              f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
          )
      discretization_size = H
      gs_heatmaps = []
      for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_points, point_matched_idxs):
          print(f'proposals_per_image:{proposals_per_image.shape}')
          if proposals_per_image.shape[0] == 0:
              # Nothing to render for this image; heatmap rendering of an empty set
              # would otherwise end in torch.cat([]).
              continue
          kp = gt_kp_in_image[midx]
          gs_heatmaps.append(single_point_to_heatmap(kp, proposals_per_image, discretization_size))
      if not gs_heatmaps:
          return line_logits.sum() * 0
      gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
      print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
      line_logits = line_logits[:, 0]
      print(f'single_point_logits:{line_logits.shape}')
      # NOTE(review): F.cross_entropy uses dim 1 (the H axis) as the class dimension here,
      # with unnormalised Gaussian targets -- confirm this is the intended spatial loss.
      line_loss = F.cross_entropy(line_logits, gs_heatmaps)
      return line_loss
  def lines_to_boxes(lines, img_size=511):
      """Padded axis-aligned bounding box of each line segment.

      Args:
          lines: tensor of shape (N, 2, D), D >= 2: two endpoints per segment, with x at
              index 0 and y at index 1 of the last dim (further columns are ignored).
          img_size: upper clamp bound for every box coordinate (the lower bound is 0).

      Returns:
          Tensor of shape (N, 4) holding [x_min, y_min, x_max, y_max]. Each side is pushed
          out by 1 pixel and then clamped into [0, img_size].
      """
      endpoint_x = torch.stack((lines[:, 0, 0], lines[:, 1, 0]), dim=1)  # (N, 2)
      endpoint_y = torch.stack((lines[:, 0, 1], lines[:, 1, 1]), dim=1)  # (N, 2)
      padded = torch.stack(
          [
              endpoint_x.min(dim=1).values - 1,
              endpoint_y.min(dim=1).values - 1,
              endpoint_x.max(dim=1).values + 1,
              endpoint_y.max(dim=1).values + 1,
          ],
          dim=1,
      )
      return padded.clamp(min=0, max=img_size)
  def box_iou_pairwise(box1, box2):
      """IoU of each index-aligned box pair ``(box1[i], box2[i])``.

      Args:
          box1: (N, 4) boxes [x1, y1, x2, y2].
          box2: (M, 4) boxes [x1, y1, x2, y2].

      Returns:
          Tensor of shape (min(N, M),); extra rows in the longer input are ignored.
          A 1e-6 term in the denominator guards against zero-area unions.
      """
      n = min(len(box1), len(box2))
      a, b = box1[:n], box2[:n]
      top_left = torch.max(a[:, :2], b[:, :2])
      bottom_right = torch.min(a[:, 2:], b[:, 2:])
      overlap_wh = (bottom_right - top_left).clamp(min=0)
      inter = overlap_wh[:, 0] * overlap_wh[:, 1]
      area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
      area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
      return inter / (area_a + area_b - inter + 1e-6)
  def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta=1.0, gamma=1.0):
      """Geometric loss between decoded line predictions and their matched GT lines.

      Per image, lines decoded from ``x`` with ``heatmaps_to_lines`` are compared pairwise
      with ``gt_lines[i][matched_idx[i]]`` through three terms. Each term is min-max
      normalised within the image by ``normalize_tensor``, then combined as
      ``alpha * (1 - box IoU) + beta * |len_pred - len_gt| + gamma * (1 - cos angle)``.
      With a single line in an image, every normalised term is 0.

      Args:
          x: per-ROI heatmaps for all images, split along dim 0 by the per-image box
              counts. ``heatmaps_to_lines`` reads channel 1, so at least 2 channels are
              required.
          boxes: per-image (R_i, 4) boxes.
          gt_lines: per-image GT lines, documented as [*, 2, 3] (x, y, visibility).
          matched_idx: per-image indices selecting the GT line for each box.
          img_size: clamp bound passed to ``lines_to_boxes``.
          alpha: weight of the IoU term.
          beta: weight of the length term.
          gamma: weight of the direction term.

      Returns:
          Mean of all per-line totals, or None when no image has both predictions and GT.

      NOTE(review): ``heatmaps_to_lines`` writes topk indices into a freshly allocated
      zeros tensor, so the predicted lines carry no gradient from ``x``. This loss
      therefore cannot train the heatmap head -- confirm it is only used for monitoring.
      NOTE(review): predictions are heatmap pixel indices while GT lines are used as-is;
      this assumes both share one coordinate frame.
      """
      counts = [b.size(0) for b in boxes]
      per_image_terms = []
      for maps, rois, gt_all, match in zip(x.split(counts, dim=0), boxes, gt_lines, matched_idx):
          pred, _scores = heatmaps_to_lines(maps, rois)
          gt = gt_all[match]
          if len(pred) == 0 or len(gt) == 0:
              continue
          # 1 - IoU of the padded bounding boxes of each (pred, gt) pair
          box_term = 1.0 - box_iou_pairwise(lines_to_boxes(pred, img_size), lines_to_boxes(gt, img_size))
          # absolute difference of segment lengths
          len_term = F.l1_loss(line_length(pred), line_length(gt), reduction='none')
          # 1 - cos(angle) between segment directions
          dir_term = angle_loss_cosine(line_direction(pred), line_direction(gt))
          per_image_terms.append(
              alpha * normalize_tensor(box_term)
              + beta * normalize_tensor(len_term)
              + gamma * normalize_tensor(dir_term)
          )
      return torch.cat(per_image_terms).mean() if per_image_terms else None
  def point_inference(x, point_boxes):
      # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
      """Split batched heatmaps per image and decode one point per box.

      Args:
          x: heatmaps for all images, split along dim 0 by the per-image box counts.
          point_boxes: per-image (R_i, 4) boxes.

      Returns:
          (points_probs, points_scores): per-image lists. Each points_probs entry is
          (R_i, 1, 2) (x, y) and each points_scores entry is (R_i, 1).
      """
      chunks = x.split([b.size(0) for b in point_boxes], dim=0)
      decoded = [heatmaps_to_points(chunk, b) for chunk, b in zip(chunks, point_boxes)]
      points_probs = [coords.unsqueeze(1) for coords, _ in decoded]
      points_scores = [scores for _, scores in decoded]
      return points_probs, points_scores
  def line_inference(x, line_boxes):
      # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
      """Split batched heatmaps per image and decode two line endpoints per box.

      Args:
          x: heatmaps for all images, split along dim 0 by the per-image box counts.
          line_boxes: per-image (R_i, 4) boxes.

      Returns:
          (lines_probs, lines_scores): per-image lists of the (R_i, 2, 3) endpoint tensors
          and (R_i, 2) endpoint scores produced by ``heatmaps_to_lines``.
      """
      sizes = [b.size(0) for b in line_boxes]
      decoded = [heatmaps_to_lines(maps, rois) for maps, rois in zip(x.split(sizes, dim=0), line_boxes)]
      lines_probs = [lines for lines, _ in decoded]
      lines_scores = [scores for _, scores in decoded]
      return lines_probs, lines_scores