head_losses.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023
  1. import torch
  2. from matplotlib import pyplot as plt
  3. import torch.nn.functional as F
  4. from torch import nn
  5. class DiceLoss(nn.Module):
  6. def __init__(self, smooth=1.):
  7. super(DiceLoss, self).__init__()
  8. self.smooth = smooth
  9. def forward(self, logits, targets):
  10. probs = torch.sigmoid(logits)
  11. probs = probs.view(-1)
  12. targets = targets.view(-1).float()
  13. intersection = (probs * targets).sum()
  14. dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
  15. return 1. - dice
  16. bce_loss = nn.BCEWithLogitsLoss()
  17. dice_loss = DiceLoss()
  18. def combined_loss(preds, targets, alpha=0.5):
  19. bce = bce_loss(preds, targets)
  20. d = dice_loss(preds, targets)
  21. return alpha * bce + (1 - alpha) * d
  22. def features_align(features, proposals, img_size):
  23. print(f'lines_features_align features:{features.shape},proposals:{len(proposals)}')
  24. align_feat_list = []
  25. for feat, proposals_per_img in zip(features, proposals):
  26. print(f'lines_features_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
  27. if proposals_per_img.shape[0]>0:
  28. feat = feat.unsqueeze(0)
  29. for proposal in proposals_per_img:
  30. align_feat = torch.zeros_like(feat)
  31. # print(f'align_feat:{align_feat.shape}')
  32. x1, y1, x2, y2 = map(lambda v: int(v.item()), proposal)
  33. # 将每个proposal框内的部分赋值到align_feats对应位置
  34. align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]
  35. align_feat_list.append(align_feat)
  36. # print(f'align_feat_list:{align_feat_list}')
  37. if len(align_feat_list) > 0:
  38. feats_tensor = torch.cat(align_feat_list)
  39. print(f'align features :{feats_tensor.shape}')
  40. else:
  41. feats_tensor = None
  42. return feats_tensor
  43. def normalize_tensor(t):
  44. return (t - t.min()) / (t.max() - t.min() + 1e-6)
  45. def line_length(lines):
  46. """
  47. 计算每条线段的长度
  48. lines: [N, 2, 2] 表示 N 条线段,每条线段由两个点组成
  49. 返回: [N]
  50. """
  51. return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
  52. def line_direction(lines):
  53. """
  54. 计算每条线段的单位方向向量
  55. lines: [N, 2, 2]
  56. 返回: [N, 2] 单位方向向量
  57. """
  58. vec = lines[:, 1] - lines[:, 0]
  59. return F.normalize(vec, dim=-1)
  60. def angle_loss_cosine(pred_dir, gt_dir):
  61. """
  62. 使用 cosine similarity 计算方向差异
  63. pred_dir: [N, 2]
  64. gt_dir: [N, 2]
  65. 返回: [N]
  66. """
  67. cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
  68. return 1.0 - cos_sim # 或者 torch.acos(cos_sim) / pi 也可
  69. def line_length(lines):
  70. """
  71. 计算每条线段的长度
  72. lines: [N, 2, 2] 表示 N 条线段,每条线段由两个点组成
  73. 返回: [N]
  74. """
  75. return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
  76. def line_direction(lines):
  77. """
  78. 计算每条线段的单位方向向量
  79. lines: [N, 2, 2]
  80. 返回: [N, 2] 单位方向向量
  81. """
  82. vec = lines[:, 1] - lines[:, 0]
  83. return F.normalize(vec, dim=-1)
  84. def angle_loss_cosine(pred_dir, gt_dir):
  85. """
  86. 使用 cosine similarity 计算方向差异
  87. pred_dir: [N, 2]
  88. gt_dir: [N, 2]
  89. 返回: [N]
  90. """
  91. cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
  92. return 1.0 - cos_sim # 或者 torch.acos(cos_sim) / pi 也可
  93. def single_point_to_heatmap(keypoints, rois, heatmap_size):
  94. # type: (Tensor, Tensor, int) -> Tensor
  95. print(f'rois:{rois.shape}')
  96. print(f'heatmap_size:{heatmap_size}')
  97. print(f'keypoints.shape:{keypoints.shape}')
  98. # batch_size, num_keypoints, _ = keypoints.shape
  99. x = keypoints[..., 0].unsqueeze(1)
  100. y = keypoints[..., 1].unsqueeze(1)
  101. gs = generate_gaussian_heatmaps(x, y,num_points=1, heatmap_size=heatmap_size, sigma=1.0)
  102. # show_heatmap(gs[0],'target')
  103. all_roi_heatmap = []
  104. for roi, heatmap in zip(rois, gs):
  105. # show_heatmap(heatmap, 'target')
  106. # print(f'heatmap:{heatmap.shape}')
  107. heatmap = heatmap.unsqueeze(0)
  108. x1, y1, x2, y2 = map(int, roi)
  109. roi_heatmap = torch.zeros_like(heatmap)
  110. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  111. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  112. all_roi_heatmap.append(roi_heatmap)
  113. all_roi_heatmap = torch.cat(all_roi_heatmap)
  114. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  115. return all_roi_heatmap
  116. def points_to_heatmap(keypoints, rois,num_points=2, heatmap_size=(512,512)):
  117. # type: (Tensor, Tensor, int) -> Tensor
  118. print(f'rois:{rois.shape}')
  119. print(f'heatmap_size:{heatmap_size}')
  120. print(f'keypoints.shape:{keypoints.shape}')
  121. # batch_size, num_keypoints, _ = keypoints.shape
  122. x = keypoints[..., 0].unsqueeze(1)
  123. y = keypoints[..., 1].unsqueeze(1)
  124. gs = generate_gaussian_heatmaps(x, y,num_points=num_points, heatmap_size=heatmap_size, sigma=1.0)
  125. # show_heatmap(gs[0],'target')
  126. all_roi_heatmap = []
  127. for roi, heatmap in zip(rois, gs):
  128. # show_heatmap(heatmap, 'target')
  129. # print(f'heatmap:{heatmap.shape}')
  130. heatmap = heatmap.unsqueeze(0)
  131. x1, y1, x2, y2 = map(int, roi)
  132. roi_heatmap = torch.zeros_like(heatmap)
  133. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  134. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  135. all_roi_heatmap.append(roi_heatmap)
  136. all_roi_heatmap = torch.cat(all_roi_heatmap)
  137. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  138. return all_roi_heatmap
  139. def line_points_to_heatmap(keypoints, rois, heatmap_size):
  140. # type: (Tensor, Tensor, int) -> Tensor
  141. print(f'rois:{rois.shape}')
  142. print(f'heatmap_size:{heatmap_size}')
  143. print(f'keypoints.shape:{keypoints.shape}')
  144. # batch_size, num_keypoints, _ = keypoints.shape
  145. x = keypoints[..., 0]
  146. y = keypoints[..., 1]
  147. gs = generate_gaussian_heatmaps(x, y, heatmap_size,num_points=2, sigma=1.0)
  148. # show_heatmap(gs[0],'target')
  149. all_roi_heatmap = []
  150. for roi, heatmap in zip(rois, gs):
  151. # print(f'heatmap:{heatmap.shape}')
  152. # show_heatmap(heatmap,'target')
  153. heatmap = heatmap.unsqueeze(0)
  154. x1, y1, x2, y2 = map(int, roi)
  155. roi_heatmap = torch.zeros_like(heatmap)
  156. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  157. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  158. all_roi_heatmap.append(roi_heatmap)
  159. if len(all_roi_heatmap) > 0:
  160. all_roi_heatmap = torch.cat(all_roi_heatmap)
  161. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  162. else:
  163. all_roi_heatmap = None
  164. return all_roi_heatmap
  165. """
  166. 修改适配的原结构的点 转热图,适用于带roi_pool版本的
  167. """
  168. def line_points_to_heatmap_(keypoints, rois, heatmap_size):
  169. # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
  170. print(f'rois:{rois.shape}')
  171. print(f'heatmap_size:{heatmap_size}')
  172. offset_x = rois[:, 0]
  173. offset_y = rois[:, 1]
  174. scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
  175. scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
  176. offset_x = offset_x[:, None]
  177. offset_y = offset_y[:, None]
  178. scale_x = scale_x[:, None]
  179. scale_y = scale_y[:, None]
  180. print(f'keypoints.shape:{keypoints.shape}')
  181. # batch_size, num_keypoints, _ = keypoints.shape
  182. x = keypoints[..., 0]
  183. y = keypoints[..., 1]
  184. # gs=generate_gaussian_heatmaps(x,y,512,1.0)
  185. # print(f'gs_heatmap shape:{gs.shape}')
  186. #
  187. # show_heatmap(gs[0],'target')
  188. x_boundary_inds = x == rois[:, 2][:, None]
  189. y_boundary_inds = y == rois[:, 3][:, None]
  190. x = (x - offset_x) * scale_x
  191. x = x.floor().long()
  192. y = (y - offset_y) * scale_y
  193. y = y.floor().long()
  194. x[x_boundary_inds] = heatmap_size - 1
  195. y[y_boundary_inds] = heatmap_size - 1
  196. # print(f'heatmaps x:{x}')
  197. # print(f'heatmaps y:{y}')
  198. valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
  199. vis = keypoints[..., 2] > 0
  200. valid = (valid_loc & vis).long()
  201. gs_heatmap = generate_gaussian_heatmaps(x, y, heatmap_size, 1.0)
  202. # show_heatmap(gs_heatmap[0], 'feature')
  203. # print(f'gs_heatmap:{gs_heatmap.shape}')
  204. #
  205. # lin_ind = y * heatmap_size + x
  206. # print(f'lin_ind:{lin_ind.shape}')
  207. # heatmaps = lin_ind * valid
  208. return gs_heatmap
  209. def generate_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
  210. """
  211. 为一组点生成并合并高斯热图。
  212. Args:
  213. xs (Tensor): 形状为 (N, 2) 的所有点的 x 坐标
  214. ys (Tensor): 形状为 (N, 2) 的所有点的 y 坐标
  215. heatmap_size (int): 热图大小 H=W
  216. sigma (float): 高斯核标准差
  217. device (str): 设备类型 ('cpu' or 'cuda')
  218. Returns:
  219. Tensor: 形状为 (H, W) 的合并后的热图
  220. """
  221. assert xs.shape == ys.shape, "x and y must have the same shape"
  222. print(f'xs:{xs.shape}')
  223. # xs=xs.squeeze(1)
  224. # ys = ys.squeeze(1)
  225. print(f'xs1:{xs.shape}')
  226. N = xs.shape[0]
  227. print(f'N:{N},num_points:{num_points}')
  228. # 创建网格
  229. grid_y, grid_x = torch.meshgrid(
  230. torch.arange(heatmap_size, device=device),
  231. torch.arange(heatmap_size, device=device),
  232. indexing='ij'
  233. )
  234. # print(f'heatmap_size:{heatmap_size}')
  235. # 初始化输出热图
  236. combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
  237. for i in range(N):
  238. heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
  239. for j in range(num_points):
  240. mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
  241. mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
  242. # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
  243. # 计算距离平方
  244. dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
  245. # 计算高斯分布
  246. heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
  247. heatmap+=heatmap1
  248. # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
  249. # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
  250. #
  251. # # 计算距离平方
  252. # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
  253. #
  254. # # 计算高斯分布
  255. # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
  256. #
  257. # heatmap = heatmap1 + heatmap2
  258. # 将当前热图累加到结果中
  259. combined_heatmap[i] = heatmap
  260. return combined_heatmap
  261. def generate_mask_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
  262. """
  263. 为一组点生成并合并高斯热图。
  264. Args:
  265. xs (Tensor): 形状为 (N, 2) 的所有点的 x 坐标
  266. ys (Tensor): 形状为 (N, 2) 的所有点的 y 坐标
  267. heatmap_size (int): 热图大小 H=W
  268. sigma (float): 高斯核标准差
  269. device (str): 设备类型 ('cpu' or 'cuda')
  270. Returns:
  271. Tensor: 形状为 (H, W) 的合并后的热图
  272. """
  273. assert xs.shape == ys.shape, "x and y must have the same shape"
  274. print(f'xs:{xs.shape}')
  275. xs=xs.squeeze(1)
  276. ys = ys.squeeze(1)
  277. print(f'xs1:{xs.shape}')
  278. N = xs.shape[0]
  279. print(f'N:{N},num_points:{num_points}')
  280. # 创建网格
  281. grid_y, grid_x = torch.meshgrid(
  282. torch.arange(heatmap_size, device=device),
  283. torch.arange(heatmap_size, device=device),
  284. indexing='ij'
  285. )
  286. # print(f'heatmap_size:{heatmap_size}')
  287. # 初始化输出热图
  288. combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
  289. for i in range(N):
  290. heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
  291. for j in range(num_points):
  292. mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
  293. mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
  294. # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
  295. # 计算距离平方
  296. dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
  297. # 计算高斯分布
  298. heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
  299. heatmap+=heatmap1
  300. # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
  301. # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
  302. #
  303. # # 计算距离平方
  304. # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
  305. #
  306. # # 计算高斯分布
  307. # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
  308. #
  309. # heatmap = heatmap1 + heatmap2
  310. # 将当前热图累加到结果中
  311. combined_heatmap[i] = heatmap
  312. return combined_heatmap
  313. def non_maximum_suppression(a):
  314. ap = F.max_pool2d(a, 3, stride=1, padding=1)
  315. mask = (a == ap).float().clamp(min=0.0)
  316. return a * mask
  317. def heatmaps_to_points(maps, rois,num_points=2):
  318. point_preds = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
  319. point_end_scores = torch.zeros((len(rois), 1), dtype=torch.float32, device=maps.device)
  320. print(f'heatmaps_to_lines:{maps.shape}')
  321. point_maps=maps[:,0]
  322. print(f'point_map:{point_maps.shape}')
  323. for i in range(len(rois)):
  324. point_roi_map = point_maps[i].unsqueeze(0)
  325. print(f'point_roi_map:{point_roi_map.shape}')
  326. # roi_map_probs = scores_to_probs(roi_map.copy())
  327. w = point_roi_map.shape[2]
  328. flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
  329. point_score, point_index = torch.topk(flatten_point_roi_map, k=num_points)
  330. print(f'point index:{point_index}')
  331. # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  332. point_x =point_index % w
  333. point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
  334. point_preds[i, 0,] = point_x
  335. point_preds[i, 1,] = point_y
  336. point_end_scores[i, :] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
  337. return point_preds,point_end_scores
  338. def heatmaps_to_lines(maps, rois):
  339. line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
  340. line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
  341. line_maps=maps[:,1]
  342. # line_maps = maps.squeeze(1)
  343. for i in range(len(rois)):
  344. line_roi_map = line_maps[i].unsqueeze(0)
  345. print(f'line_roi_map:{line_roi_map.shape}')
  346. # roi_map_probs = scores_to_probs(roi_map.copy())
  347. w = line_roi_map.shape[1]
  348. flatten_line_roi_map = non_maximum_suppression(line_roi_map).reshape(1, -1)
  349. line_score, line_index = torch.topk(flatten_line_roi_map, k=2)
  350. print(f'line index:{line_index}')
  351. # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  352. pos = line_index
  353. line_x = pos % w
  354. line_y = torch.div(pos - line_x, w, rounding_mode="floor")
  355. line_preds[i, 0, :] = line_x
  356. line_preds[i, 1, :] = line_y
  357. line_preds[i, 2, :] = 1
  358. line_end_scores[i, :] = line_roi_map[torch.arange(1, device=line_roi_map.device), line_y, line_x]
  359. return line_preds.permute(0, 2, 1), line_end_scores
  360. # 显示热图的函数
  361. def show_heatmap(heatmap, title="Heatmap"):
  362. """
  363. 使用 matplotlib 显示热图。
  364. Args:
  365. heatmap (Tensor): 要显示的热图张量
  366. title (str): 图表标题
  367. """
  368. # 如果在 GPU 上,首先将其移动到 CPU 并转换为 numpy 数组
  369. if heatmap.is_cuda:
  370. heatmap = heatmap.cpu().numpy()
  371. else:
  372. heatmap = heatmap.numpy()
  373. plt.imshow(heatmap, cmap='hot', interpolation='nearest')
  374. plt.colorbar()
  375. plt.title(title)
  376. plt.show()
  377. def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
  378. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  379. N, K, H, W = line_logits.shape
  380. len_proposals = len(proposals)
  381. print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals},line_matched_idxs:{line_matched_idxs}')
  382. if H != W:
  383. raise ValueError(
  384. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  385. )
  386. discretization_size = H
  387. heatmaps = []
  388. gs_heatmaps = []
  389. valid = []
  390. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
  391. print(f'line_proposals_per_image:{proposals_per_image.shape}')
  392. print(f'gt_lines:{gt_lines}')
  393. if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
  394. kp = gt_kp_in_image[midx]
  395. gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
  396. gs_heatmaps.append(gs_heatmaps_per_img)
  397. # print(f'heatmaps_per_image:{heatmaps_per_image.shape}')
  398. # heatmaps.append(heatmaps_per_image.view(-1))
  399. # valid.append(valid_per_image.view(-1))
  400. # line_targets = torch.cat(heatmaps, dim=0)
  401. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  402. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
  403. # print(f'line_targets:{line_targets.shape},{line_targets}')
  404. # valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
  405. # valid = torch.where(valid)[0]
  406. # print(f' line_targets[valid]:{line_targets[valid]}')
  407. # torch.mean (in binary_cross_entropy_with_logits) doesn't
  408. # accept empty tensors, so handle it sepaartely
  409. # if line_targets.numel() == 0 or len(valid) == 0:
  410. # return line_logits.sum() * 0
  411. # line_logits = line_logits.view(N * K, H * W)
  412. # print(f'line_logits[valid]:{line_logits[valid].shape}')
  413. print(f'loss1 line_logits:{line_logits.shape}')
  414. line_logits = line_logits[:,1,:,:]
  415. # line_logits = line_logits.squeeze(1)
  416. print(f'loss2 line_logits:{line_logits.shape}')
  417. # line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
  418. line_loss = F.cross_entropy(line_logits, gs_heatmaps)
  419. return line_loss
  420. def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
  421. print(f'compute_arc_loss:{feature_logits.shape}')
  422. N, K, H, W = feature_logits.shape
  423. len_proposals = len(proposals)
  424. empty_count = 0
  425. non_empty_count = 0
  426. for prop in proposals:
  427. if prop.shape[0] == 0:
  428. empty_count += 1
  429. else:
  430. non_empty_count += 1
  431. print(f"Empty proposals count: {empty_count}")
  432. print(f"Non-empty proposals count: {non_empty_count}")
  433. print(f'starte to compute_point_loss')
  434. print(f'compute_point_loss line_logits.shape:{feature_logits.shape},len_proposals:{len_proposals}')
  435. if H != W:
  436. raise ValueError(
  437. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  438. )
  439. discretization_size = H
  440. gs_heatmaps = []
  441. # print(f'point_matched_idxs:{point_matched_idxs}')
  442. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
  443. # [
  444. # (Tensor(38, 4), Tensor(1, 57, 2), Tensor(38, 1)),
  445. # (Tensor(65, 4), Tensor(1, 74, 2), Tensor(65, 1))
  446. # ]
  447. print(f'proposals_per_image:{proposals_per_image.shape}')
  448. kp = gt_kp_in_image[midx]
  449. # print(f'gt_kp_in_image:{gt_kp_in_image}')
  450. if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
  451. gs_heatmaps_per_img = arc_points_to_heatmap(kp, proposals_per_image, discretization_size)
  452. gs_heatmaps.append(gs_heatmaps_per_img)
  453. if len(gs_heatmaps)>0:
  454. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  455. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.shape}')
  456. line_logits = feature_logits.squeeze(1)
  457. print(f'single_point_logits:{line_logits.shape}')
  458. # line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
  459. # line_loss = F.cross_entropy(line_logits, gs_heatmaps)
  460. line_loss=combined_loss(line_logits, gs_heatmaps)
  461. else:
  462. line_loss=100
  463. print("d")
  464. return line_loss
  465. def arc_points_to_heatmap(keypoints, rois, heatmap_size):
  466. print(f'rois:{rois.shape}')
  467. print(f'heatmap_size:{heatmap_size}')
  468. print(f'keypoints.shape:{keypoints.shape}')
  469. # batch_size, num_keypoints, _ = keypoints.shape
  470. x = keypoints[..., 0].unsqueeze(1)
  471. y = keypoints[..., 1].unsqueeze(1)
  472. num_points=x.shape[2]
  473. print(f'num_points:{num_points}')
  474. gs = generate_mask_gaussian_heatmaps(x, y, num_points=num_points, heatmap_size=heatmap_size, sigma=2.0)
  475. # show_heatmap(gs[0],'target')
  476. all_roi_heatmap = []
  477. for roi, heatmap in zip(rois, gs):
  478. # show_heatmap(heatmap, 'target')
  479. print(f'heatmap:{heatmap.shape}')
  480. heatmap = heatmap.unsqueeze(0)
  481. x1, y1, x2, y2 = map(int, roi)
  482. roi_heatmap = torch.zeros_like(heatmap)
  483. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  484. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  485. all_roi_heatmap.append(roi_heatmap)
  486. all_roi_heatmap = torch.cat(all_roi_heatmap)
  487. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  488. return all_roi_heatmap
  489. def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
  490. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  491. N, K, H, W = line_logits.shape
  492. len_proposals = len(proposals)
  493. empty_count = 0
  494. non_empty_count = 0
  495. for prop in proposals:
  496. if prop.shape[0] == 0:
  497. empty_count += 1
  498. else:
  499. non_empty_count += 1
  500. print(f"Empty proposals count: {empty_count}")
  501. print(f"Non-empty proposals count: {non_empty_count}")
  502. print(f'starte to compute_point_loss')
  503. print(f'compute_point_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
  504. if H != W:
  505. raise ValueError(
  506. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  507. )
  508. discretization_size = H
  509. gs_heatmaps = []
  510. # print(f'point_matched_idxs:{point_matched_idxs}')
  511. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_points, point_matched_idxs):
  512. print(f'proposals_per_image:{proposals_per_image.shape}')
  513. kp = gt_kp_in_image[midx]
  514. # print(f'gt_kp_in_image:{gt_kp_in_image}')
  515. gs_heatmaps_per_img = single_point_to_heatmap(kp, proposals_per_image, discretization_size)
  516. gs_heatmaps.append(gs_heatmaps_per_img)
  517. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  518. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
  519. line_logits = line_logits[:,0]
  520. print(f'single_point_logits:{line_logits.shape}')
  521. line_loss = F.cross_entropy(line_logits, gs_heatmaps)
  522. return line_loss
  523. def compute_circle_loss(circle_logits, proposals, gt_circles, circle_matched_idxs):
  524. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  525. N, K, H, W = circle_logits.shape
  526. len_proposals = len(proposals)
  527. empty_count = 0
  528. non_empty_count = 0
  529. for prop in proposals:
  530. if prop.shape[0] == 0:
  531. empty_count += 1
  532. else:
  533. non_empty_count += 1
  534. print(f"Empty proposals count: {empty_count}")
  535. print(f"Non-empty proposals count: {non_empty_count}")
  536. print(f'starte to compute_circle_loss')
  537. print(f'compute_circle_loss circle_logits.shape:{circle_logits.shape},len_proposals:{len_proposals}')
  538. if H != W:
  539. raise ValueError(
  540. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  541. )
  542. discretization_size = H
  543. gs_heatmaps = []
  544. # print(f'point_matched_idxs:{point_matched_idxs}')
  545. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_circles, circle_matched_idxs):
  546. print(f'proposals_per_image:{proposals_per_image.shape}')
  547. kp = gt_kp_in_image[midx]
  548. # print(f'gt_kp_in_image:{gt_kp_in_image}')
  549. gs_heatmaps_per_img = points_to_heatmap(kp, proposals_per_image,num_points=4, heatmap_size=discretization_size)
  550. gs_heatmaps.append(gs_heatmaps_per_img)
  551. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  552. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{circle_logits.squeeze(1).shape}')
  553. circle_logits = circle_logits[:, 0]
  554. print(f'circle_logits:{circle_logits.shape}')
  555. circle_loss = F.cross_entropy(circle_logits, gs_heatmaps)
  556. return circle_loss
  557. def lines_to_boxes(lines, img_size=511):
  558. """
  559. 输入:
  560. lines: Tensor of shape (N, 2, 2),表示 N 条线段,每个线段有两个端点 (x, y)
  561. img_size: int,图像尺寸,用于 clamp 边界
  562. 输出:
  563. boxes: Tensor of shape (N, 4),表示 N 个包围盒 [x_min, y_min, x_max, y_max]
  564. """
  565. # 提取所有线段的两个端点
  566. p1 = lines[:, 0] # (N, 2)
  567. p2 = lines[:, 1] # (N, 2)
  568. # 每条线段的 x 和 y 坐标
  569. x_coords = torch.stack([p1[:, 0], p2[:, 0]], dim=1) # (N, 2)
  570. y_coords = torch.stack([p1[:, 1], p2[:, 1]], dim=1) # (N, 2)
  571. # 计算包围盒边界
  572. x_min = x_coords.min(dim=1).values
  573. y_min = y_coords.min(dim=1).values
  574. x_max = x_coords.max(dim=1).values
  575. y_max = y_coords.max(dim=1).values
  576. # 扩展边界并限制在图像范围内
  577. x_min = (x_min - 1).clamp(min=0, max=img_size)
  578. y_min = (y_min - 1).clamp(min=0, max=img_size)
  579. x_max = (x_max + 1).clamp(min=0, max=img_size)
  580. y_max = (y_max + 1).clamp(min=0, max=img_size)
  581. # 合成包围盒
  582. boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1) # (N, 4)
  583. return boxes
  584. def box_iou_pairwise(box1, box2):
  585. """
  586. 输入:
  587. box1: shape (N, 4)
  588. box2: shape (M, 4)
  589. 输出:
  590. ious: shape (min(N, M), ), 只计算 i = j 的配对
  591. """
  592. N = min(len(box1), len(box2))
  593. lt = torch.max(box1[:N, :2], box2[:N, :2]) # 左上角
  594. rb = torch.min(box1[:N, 2:], box2[:N, 2:]) # 右下角
  595. wh = (rb - lt).clamp(min=0) # 宽高
  596. inter_area = wh[:, 0] * wh[:, 1] # 交集面积
  597. area1 = (box1[:N, 2] - box1[:N, 0]) * (box1[:N, 3] - box1[:N, 1])
  598. area2 = (box2[:N, 2] - box2[:N, 0]) * (box2[:N, 3] - box2[:N, 1])
  599. union_area = area1 + area2 - inter_area
  600. ious = inter_area / (union_area + 1e-6)
  601. return ious
  602. def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta=1.0, gamma=1.0):
  603. """
  604. Args:
  605. x: [N,1,H,W] 热力图
  606. boxes: [N,4] 框坐标
  607. gt_lines: [N,2,3] GT线段(含可见性)
  608. matched_idx: 匹配 index
  609. img_size: 图像尺寸
  610. alpha: IoU 损失权重
  611. beta: 长度损失权重
  612. gamma: 方向角度损失权重
  613. """
  614. losses = []
  615. boxes_per_image = [box.size(0) for box in boxes]
  616. x2 = x.split(boxes_per_image, dim=0)
  617. for xx, bb, gt_line, mid in zip(x2, boxes, gt_lines, matched_idx):
  618. p_prob, _ = heatmaps_to_lines(xx, bb)
  619. pred_lines = p_prob
  620. gt_line_points = gt_line[mid]
  621. if len(pred_lines) == 0 or len(gt_line_points) == 0:
  622. continue
  623. # IoU 损失
  624. pred_boxes = lines_to_boxes(pred_lines, img_size)
  625. gt_boxes = lines_to_boxes(gt_line_points, img_size)
  626. ious = box_iou_pairwise(pred_boxes, gt_boxes)
  627. iou_loss = 1.0 - ious # [N]
  628. # 长度损失
  629. pred_len = line_length(pred_lines)
  630. gt_len = line_length(gt_line_points)
  631. length_diff = F.l1_loss(pred_len, gt_len, reduction='none') # [N]
  632. # 方向角度损失
  633. pred_dir = line_direction(pred_lines)
  634. gt_dir = line_direction(gt_line_points)
  635. ang_loss = angle_loss_cosine(pred_dir, gt_dir) # [N]
  636. # 归一化每一项损失
  637. norm_iou = normalize_tensor(iou_loss)
  638. norm_len = normalize_tensor(length_diff)
  639. norm_ang = normalize_tensor(ang_loss)
  640. total = alpha * norm_iou + beta * norm_len + gamma * norm_ang
  641. losses.append(total)
  642. if not losses:
  643. return None
  644. return torch.mean(torch.cat(losses))
  645. def point_inference(x, point_boxes):
  646. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  647. points_probs = []
  648. points_scores = []
  649. boxes_per_image = [box.size(0) for box in point_boxes]
  650. x2 = x.split(boxes_per_image, dim=0)
  651. for xx, bb in zip(x2, point_boxes):
  652. point_prob,point_scores = heatmaps_to_points(xx, bb,num_points=1)
  653. points_probs.append(point_prob.unsqueeze(1))
  654. points_scores.append(point_scores)
  655. return points_probs,points_scores
  656. def circle_inference(x, point_boxes):
  657. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  658. points_probs = []
  659. points_scores = []
  660. boxes_per_image = [box.size(0) for box in point_boxes]
  661. x2 = x.split(boxes_per_image, dim=0)
  662. for xx, bb in zip(x2, point_boxes):
  663. point_prob,point_scores = heatmaps_to_points(xx, bb,num_points=4)
  664. points_probs.append(point_prob.unsqueeze(1))
  665. points_scores.append(point_scores)
  666. return points_probs,points_scores
  667. def line_inference(x, line_boxes):
  668. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  669. lines_probs = []
  670. lines_scores = []
  671. boxes_per_image = [box.size(0) for box in line_boxes]
  672. x2 = x.split(boxes_per_image, dim=0)
  673. # x2:tuple 2 x2[0]:[1,3,1024,1024]
  674. # line_box: list:2 [1,4] [1.4] fasterrcnn kuang
  675. for xx, bb in zip(x2, line_boxes):
  676. line_prob, line_scores, = heatmaps_to_lines(xx, bb)
  677. lines_probs.append(line_prob)
  678. lines_scores.append(line_scores)
  679. return lines_probs, lines_scores
  680. def arc_inference(x, arc_boxes,th):
  681. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  682. points_probs = []
  683. points_scores = []
  684. print(f'arc_boxes:{len(arc_boxes)}')
  685. boxes_per_image = [box.size(0) for box in arc_boxes]
  686. print(f'arc boxes_per_image:{boxes_per_image}')
  687. x2 = x.split(boxes_per_image, dim=0)
  688. for xx, bb in zip(x2, arc_boxes):
  689. point_prob,point_scores = heatmaps_to_arc(xx, bb)
  690. points_probs.append(point_prob.unsqueeze(1))
  691. points_scores.append(point_scores)
  692. points_probs_tensor=torch.cat(points_probs)
  693. print(f'points_probs shape:{points_probs_tensor.shape}')
  694. feature_logits = x
  695. batch_size = feature_logits.shape[0]
  696. num_proposals = len(arc_boxes[0])
  697. results = [[torch.empty(0, 2) for _ in range(num_proposals)] for _ in range(batch_size)]
  698. proposals_list = arc_boxes[0] # [[tensor(...)]]
  699. for proposal_idx, proposal in enumerate(proposals_list):
  700. coords = proposal.tolist()
  701. x1, y1, x2, y2 = map(int, coords)
  702. x1 = max(0, x1)
  703. y1 = max(0, y1)
  704. x2 = min(feature_logits.shape[3], x2)
  705. y2 = min(feature_logits.shape[2], y2)
  706. for batch_idx in range(batch_size):
  707. region = feature_logits[batch_idx, :, y1:y2, x1:x2]
  708. mask = region > th
  709. coords = torch.nonzero(mask)
  710. if coords.numel() > 0:
  711. # 取 (y, x),然后转换为全局坐标 (x, y)
  712. local_coords = coords[:, [2, 1]] # (x, y)
  713. local_coords[:, 0] += x1
  714. local_coords[:, 1] += y1
  715. results[batch_idx][proposal_idx] = local_coords
  716. print(f're:{results}')
  717. return points_probs,points_scores,results
  718. import torch.nn.functional as F
  719. def heatmaps_to_arc(maps, rois, threshold=0.1, output_size=(128, 128)):
  720. """
  721. Args:
  722. maps: [N, 3, H, W] - full heatmaps
  723. rois: [N, 4] - bounding boxes
  724. threshold: float - binarization threshold
  725. output_size: resized size for uniform NMS
  726. Returns:
  727. masks: [N, 1, H, W] - binary mask aligned with input map
  728. scores: [N, 1] - count of non-zero pixels in each mask
  729. """
  730. N, _, H, W = maps.shape
  731. masks = torch.zeros((N, 1, H, W), dtype=torch.float32, device=maps.device)
  732. scores = torch.zeros((N, 1), dtype=torch.float32, device=maps.device)
  733. point_maps = maps[:, 0] # È¡µÚÒ»¸öͨµÀ [N, H, W]
  734. print(f"==> heatmaps_to_arc: maps.shape = {maps.shape}, rois.shape = {rois.shape}")
  735. for i in range(N):
  736. x1, y1, x2, y2 = rois[i].long()
  737. x1 = x1.clamp(0, W - 1)
  738. x2 = x2.clamp(0, W - 1)
  739. y1 = y1.clamp(0, H - 1)
  740. y2 = y2.clamp(0, H - 1)
  741. print(f"[{i}] roi: ({x1.item()}, {y1.item()}, {x2.item()}, {y2.item()})")
  742. if x2 <= x1 or y2 <= y1:
  743. print(f" Skipped invalid ROI at index {i}")
  744. continue
  745. roi_map = point_maps[i, y1:y2, x1:x2] # [h, w]
  746. print(f" roi_map.shape: {roi_map.shape}")
  747. if roi_map.numel() == 0:
  748. print(f" Skipped empty ROI at index {i}")
  749. continue
  750. # resize to uniform size
  751. roi_map_resized = F.interpolate(
  752. roi_map.unsqueeze(0).unsqueeze(0),
  753. size=output_size,
  754. mode='bilinear',
  755. align_corners=False
  756. ) # [1, 1, H, W]
  757. print(f" roi_map_resized.shape: {roi_map_resized.shape}")
  758. # NMS + threshold
  759. nms_roi = non_maximum_suppression(roi_map_resized) # shape: [1, H, W]
  760. bin_mask = (nms_roi > threshold).float() # shape: [1, H, W]
  761. print(f" bin_mask.sum(): {bin_mask.sum().item()}")
  762. # resize back to original roi size
  763. h = int((y2 - y1).item())
  764. w = int((x2 - x1).item())
  765. # È·±£ bin_mask ÊÇ [1, 128, 128]
  766. assert bin_mask.dim() == 4, f"Expected 3D tensor [1, H, W], got {bin_mask.shape}"
  767. # ÉϲÉÑù»Ø ROI ԭʼ´óС
  768. bin_mask_original_size = F.interpolate(
  769. # bin_mask.unsqueeze(0), # ? [1, 1, 128, 128]
  770. bin_mask, # ? [1, 1, 128, 128]
  771. size=(h, w),
  772. mode='bilinear',
  773. align_corners=False
  774. )[0] # ? [1, h, w]
  775. masks[i, 0, y1:y2, x1:x2] = bin_mask_original_size.squeeze()
  776. scores[i] = bin_mask_original_size.sum()
  777. print(f" bin_mask_original_size.shape: {bin_mask_original_size.shape}, sum: {scores[i].item()}")
  778. print(f"==> Done. Total valid masks: {(scores > 0).sum().item()} / {N}")
  779. return masks, scores