head_losses.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313
  1. import torch
  2. from matplotlib import pyplot as plt
  3. import torch.nn.functional as F
  4. from torch import nn
  5. from torch.cuda import device
  6. class DiceLoss(nn.Module):
  7. def __init__(self, smooth=1.):
  8. super(DiceLoss, self).__init__()
  9. self.smooth = smooth
  10. def forward(self, logits, targets):
  11. probs = torch.sigmoid(logits)
  12. probs = probs.view(-1)
  13. targets = targets.view(-1).float()
  14. intersection = (probs * targets).sum()
  15. dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
  16. return 1. - dice
  17. bce_loss = nn.BCEWithLogitsLoss()
  18. dice_loss = DiceLoss()
  19. def combined_loss(preds, targets, alpha=0.5):
  20. bce = bce_loss(preds, targets)
  21. d = dice_loss(preds, targets)
  22. return alpha * bce + (1 - alpha) * d
  23. def features_align(features, proposals, img_size):
  24. print(f'features_align features:{features.shape},proposals:{len(proposals)}')
  25. align_feat_list = []
  26. for feat, proposals_per_img in zip(features, proposals):
  27. print(f'feature_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
  28. if proposals_per_img.shape[0]>0:
  29. feat = feat.unsqueeze(0)
  30. for proposal in proposals_per_img:
  31. align_feat = torch.zeros_like(feat)
  32. # print(f'align_feat:{align_feat.shape}')
  33. x1, y1, x2, y2 = map(lambda v: int(v.item()), proposal)
  34. # 将每个proposal框内的部分赋值到align_feats对应位置
  35. align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]
  36. align_feat_list.append(align_feat)
  37. # print(f'align_feat_list:{align_feat_list}')
  38. if len(align_feat_list) > 0:
  39. feats_tensor = torch.cat(align_feat_list)
  40. print(f'align features :{feats_tensor.shape}')
  41. else:
  42. feats_tensor = None
  43. return feats_tensor
  44. def normalize_tensor(t):
  45. return (t - t.min()) / (t.max() - t.min() + 1e-6)
  46. def line_length(lines):
  47. """
  48. 计算每条线段的长度
  49. lines: [N, 2, 2] 表示 N 条线段,每条线段由两个点组成
  50. 返回: [N]
  51. """
  52. return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
  53. def line_direction(lines):
  54. """
  55. 计算每条线段的单位方向向量
  56. lines: [N, 2, 2]
  57. 返回: [N, 2] 单位方向向量
  58. """
  59. vec = lines[:, 1] - lines[:, 0]
  60. return F.normalize(vec, dim=-1)
  61. def angle_loss_cosine(pred_dir, gt_dir):
  62. """
  63. 使用 cosine similarity 计算方向差异
  64. pred_dir: [N, 2]
  65. gt_dir: [N, 2]
  66. 返回: [N]
  67. """
  68. cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
  69. return 1.0 - cos_sim # 或者 torch.acos(cos_sim) / pi 也可
  70. def line_length(lines):
  71. """
  72. 计算每条线段的长度
  73. lines: [N, 2, 2] 表示 N 条线段,每条线段由两个点组成
  74. 返回: [N]
  75. """
  76. return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
  77. def line_direction(lines):
  78. """
  79. 计算每条线段的单位方向向量
  80. lines: [N, 2, 2]
  81. 返回: [N, 2] 单位方向向量
  82. """
  83. vec = lines[:, 1] - lines[:, 0]
  84. return F.normalize(vec, dim=-1)
  85. def angle_loss_cosine(pred_dir, gt_dir):
  86. """
  87. 使用 cosine similarity 计算方向差异
  88. pred_dir: [N, 2]
  89. gt_dir: [N, 2]
  90. 返回: [N]
  91. """
  92. cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
  93. return 1.0 - cos_sim # 或者 torch.acos(cos_sim) / pi 也可
  94. def single_point_to_heatmap(keypoints, rois, heatmap_size):
  95. # type: (Tensor, Tensor, int) -> Tensor
  96. print(f'rois:{rois.shape}')
  97. print(f'heatmap_size:{heatmap_size}')
  98. print(f'keypoints.shape:{keypoints.shape}')
  99. # batch_size, num_keypoints, _ = keypoints.shape
  100. x = keypoints[..., 0].unsqueeze(1)
  101. y = keypoints[..., 1].unsqueeze(1)
  102. gs = generate_gaussian_heatmaps(x, y,num_points=1, heatmap_size=heatmap_size, sigma=1.0)
  103. # show_heatmap(gs[0],'target')
  104. all_roi_heatmap = []
  105. for roi, heatmap in zip(rois, gs):
  106. # show_heatmap(heatmap, 'target')
  107. # print(f'heatmap:{heatmap.shape}')
  108. heatmap = heatmap.unsqueeze(0)
  109. x1, y1, x2, y2 = map(int, roi)
  110. roi_heatmap = torch.zeros_like(heatmap)
  111. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  112. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  113. all_roi_heatmap.append(roi_heatmap)
  114. all_roi_heatmap = torch.cat(all_roi_heatmap)
  115. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  116. return all_roi_heatmap
  117. def arc_inference1(arc_equation, x, arc_boxes, th):
  118. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  119. points_probs = []
  120. points_scores = []
  121. print(f'arc_boxes:{len(arc_boxes)}')
  122. boxes_per_image = [box.size(0) for box in arc_boxes]
  123. print(f'arc boxes_per_image:{boxes_per_image}')
  124. x2 = x.split(boxes_per_image, dim=0)
  125. arc7 = arc_equation.split(boxes_per_image, dim=0)
  126. # print(f'arc7:{arc7}')
  127. for xx, bb in zip(x2, arc_boxes):
  128. point_prob, point_scores = heatmaps_to_arc(xx, bb)
  129. points_probs.append(point_prob.unsqueeze(1))
  130. points_scores.append(point_scores)
  131. return arc7, points_scores
  132. def points_to_heatmap(keypoints, rois,num_points=2, heatmap_size=(512,512)):
  133. # type: (Tensor, Tensor, int) -> Tensor
  134. print(f'rois:{rois.shape}')
  135. print(f'heatmap_size:{heatmap_size}')
  136. print(f'keypoints.shape:{keypoints.shape}')
  137. # batch_size, num_keypoints, _ = keypoints.shape
  138. x = keypoints[..., 0].unsqueeze(1)
  139. y = keypoints[..., 1].unsqueeze(1)
  140. gs = generate_gaussian_heatmaps(x, y,num_points=num_points, heatmap_size=heatmap_size, sigma=2.0)
  141. # show_heatmap(gs[0],'target')
  142. all_roi_heatmap = []
  143. for roi, heatmap in zip(rois, gs):
  144. # show_heatmap(heatmap, 'target')
  145. # print(f'heatmap:{heatmap.shape}')
  146. heatmap = heatmap.unsqueeze(0)
  147. x1, y1, x2, y2 = map(int, roi)
  148. roi_heatmap = torch.zeros_like(heatmap)
  149. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  150. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  151. all_roi_heatmap.append(roi_heatmap)
  152. all_roi_heatmap = torch.cat(all_roi_heatmap)
  153. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  154. return all_roi_heatmap
  155. def line_points_to_heatmap(keypoints, rois, heatmap_size):
  156. # type: (Tensor, Tensor, int) -> Tensor
  157. print(f'rois:{rois.shape}')
  158. print(f'heatmap_size:{heatmap_size}')
  159. print(f'keypoints.shape:{keypoints.shape}')
  160. # batch_size, num_keypoints, _ = keypoints.shape
  161. x = keypoints[..., 0]
  162. y = keypoints[..., 1]
  163. gs = generate_gaussian_heatmaps(x, y, heatmap_size,num_points=2, sigma=1.0)
  164. # show_heatmap(gs[0],'target')
  165. all_roi_heatmap = []
  166. for roi, heatmap in zip(rois, gs):
  167. # print(f'heatmap:{heatmap.shape}')
  168. # show_heatmap(heatmap,'target')
  169. heatmap = heatmap.unsqueeze(0)
  170. x1, y1, x2, y2 = map(int, roi)
  171. roi_heatmap = torch.zeros_like(heatmap)
  172. roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
  173. # show_heatmap(roi_heatmap[0],'roi_heatmap')
  174. all_roi_heatmap.append(roi_heatmap)
  175. if len(all_roi_heatmap) > 0:
  176. all_roi_heatmap = torch.cat(all_roi_heatmap)
  177. print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
  178. else:
  179. all_roi_heatmap = None
  180. return all_roi_heatmap
  181. """
  182. 修改适配的原结构的点 转热图,适用于带roi_pool版本的
  183. """
  184. def line_points_to_heatmap_(keypoints, rois, heatmap_size):
  185. # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
  186. print(f'rois:{rois.shape}')
  187. print(f'heatmap_size:{heatmap_size}')
  188. offset_x = rois[:, 0]
  189. offset_y = rois[:, 1]
  190. scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
  191. scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
  192. offset_x = offset_x[:, None]
  193. offset_y = offset_y[:, None]
  194. scale_x = scale_x[:, None]
  195. scale_y = scale_y[:, None]
  196. print(f'keypoints.shape:{keypoints.shape}')
  197. # batch_size, num_keypoints, _ = keypoints.shape
  198. x = keypoints[..., 0]
  199. y = keypoints[..., 1]
  200. # gs=generate_gaussian_heatmaps(x,y,512,1.0)
  201. # print(f'gs_heatmap shape:{gs.shape}')
  202. #
  203. # show_heatmap(gs[0],'target')
  204. x_boundary_inds = x == rois[:, 2][:, None]
  205. y_boundary_inds = y == rois[:, 3][:, None]
  206. x = (x - offset_x) * scale_x
  207. x = x.floor().long()
  208. y = (y - offset_y) * scale_y
  209. y = y.floor().long()
  210. x[x_boundary_inds] = heatmap_size - 1
  211. y[y_boundary_inds] = heatmap_size - 1
  212. # print(f'heatmaps x:{x}')
  213. # print(f'heatmaps y:{y}')
  214. valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
  215. vis = keypoints[..., 2] > 0
  216. valid = (valid_loc & vis).long()
  217. gs_heatmap = generate_gaussian_heatmaps(x, y, heatmap_size, 1.0)
  218. # show_heatmap(gs_heatmap[0], 'feature')
  219. # print(f'gs_heatmap:{gs_heatmap.shape}')
  220. #
  221. # lin_ind = y * heatmap_size + x
  222. # print(f'lin_ind:{lin_ind.shape}')
  223. # heatmaps = lin_ind * valid
  224. return gs_heatmap
  225. def generate_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
  226. """
  227. 为一组点生成并合并高斯热图。
  228. Args:
  229. xs (Tensor): 形状为 (N, 2) 的所有点的 x 坐标
  230. ys (Tensor): 形状为 (N, 2) 的所有点的 y 坐标
  231. heatmap_size (int): 热图大小 H=W
  232. sigma (float): 高斯核标准差
  233. device (str): 设备类型 ('cpu' or 'cuda')
  234. Returns:
  235. Tensor: 形状为 (H, W) 的合并后的热图
  236. """
  237. assert xs.shape == ys.shape, "x and y must have the same shape"
  238. print(f'xs:{xs.shape}')
  239. xs=xs.squeeze(1)
  240. ys = ys.squeeze(1)
  241. print(f'xs1:{xs.shape}')
  242. N = xs.shape[0]
  243. print(f'N:{N},num_points:{num_points}')
  244. # 创建网格
  245. grid_y, grid_x = torch.meshgrid(
  246. torch.arange(heatmap_size, device=device),
  247. torch.arange(heatmap_size, device=device),
  248. indexing='ij'
  249. )
  250. # print(f'heatmap_size:{heatmap_size}')
  251. # 初始化输出热图
  252. combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
  253. for i in range(N):
  254. heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
  255. for j in range(num_points):
  256. mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
  257. mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
  258. # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
  259. # 计算距离平方
  260. dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
  261. # 计算高斯分布
  262. heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
  263. heatmap+=heatmap1
  264. # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
  265. # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
  266. #
  267. # # 计算距离平方
  268. # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
  269. #
  270. # # 计算高斯分布
  271. # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
  272. #
  273. # heatmap = heatmap1 + heatmap2
  274. # 将当前热图累加到结果中
  275. combined_heatmap[i] = heatmap
  276. return combined_heatmap
  277. def generate_mask_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
  278. """
  279. 为一组点生成并合并高斯热图。
  280. Args:
  281. xs (Tensor): 形状为 (N, 2) 的所有点的 x 坐标
  282. ys (Tensor): 形状为 (N, 2) 的所有点的 y 坐标
  283. heatmap_size (int): 热图大小 H=W
  284. sigma (float): 高斯核标准差
  285. device (str): 设备类型 ('cpu' or 'cuda')
  286. Returns:
  287. Tensor: 形状为 (H, W) 的合并后的热图
  288. """
  289. assert xs.shape == ys.shape, "x and y must have the same shape"
  290. print(f'xs:{xs.shape}')
  291. xs=xs.squeeze(1)
  292. ys = ys.squeeze(1)
  293. print(f'xs1:{xs.shape}')
  294. N = xs.shape[0]
  295. print(f'N:{N},num_points:{num_points}')
  296. # 创建网格
  297. grid_y, grid_x = torch.meshgrid(
  298. torch.arange(heatmap_size, device=device),
  299. torch.arange(heatmap_size, device=device),
  300. indexing='ij'
  301. )
  302. # print(f'heatmap_size:{heatmap_size}')
  303. # 初始化输出热图
  304. combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
  305. for i in range(N):
  306. heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
  307. for j in range(num_points):
  308. mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
  309. mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
  310. # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
  311. # 计算距离平方
  312. dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
  313. # 计算高斯分布
  314. heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
  315. heatmap+=heatmap1
  316. # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
  317. # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
  318. #
  319. # # 计算距离平方
  320. # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
  321. #
  322. # # 计算高斯分布
  323. # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
  324. #
  325. # heatmap = heatmap1 + heatmap2
  326. # 将当前热图累加到结果中
  327. combined_heatmap[i] = heatmap
  328. return combined_heatmap
  329. def non_maximum_suppression(a):
  330. ap = F.max_pool2d(a, 3, stride=1, padding=1)
  331. mask = (a == ap).float().clamp(min=0.0)
  332. return a * mask
  333. def heatmaps_to_points(maps, rois,num_points=2):
  334. point_preds = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
  335. point_end_scores = torch.zeros((len(rois), 1), dtype=torch.float32, device=maps.device)
  336. print(f'heatmaps_to_lines:{maps.shape}')
  337. point_maps=maps[:,0]
  338. print(f'point_map:{point_maps.shape}')
  339. for i in range(len(rois)):
  340. point_roi_map = point_maps[i].unsqueeze(0)
  341. print(f'point_roi_map:{point_roi_map.shape}')
  342. # roi_map_probs = scores_to_probs(roi_map.copy())
  343. w = point_roi_map.shape[2]
  344. flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
  345. point_score, point_index = torch.topk(flatten_point_roi_map, k=num_points)
  346. print(f'point index:{point_index}')
  347. # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  348. point_x =point_index % w
  349. point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
  350. point_preds[i, 0,] = point_x
  351. point_preds[i, 1,] = point_y
  352. point_end_scores[i, :] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
  353. return point_preds,point_end_scores
  354. # 分4块
  355. def find_max_heat_point_in_each_part(feature_map, box):
  356. """
  357. 在给定的特征图上,根据box中心点往上移3,往右移3作为新的中心点,
  358. 并将特征图划分为4个部分,之后在每个部分中找到热度值最大的点。
  359. Args:
  360. feature_map (torch.Tensor): 形状为 [C, H, W] 的特征图
  361. box (torch.Tensor): 形状为 [4] 的边界框 [x_min, y_min, x_max, y_max]
  362. Returns:
  363. list: 每个区域中热度最高的点的位置和其对应的热度值 [(y1, x1, heat1), ..., (y4, x4, heat4)]
  364. """
  365. device = feature_map.device
  366. C, H, W = feature_map.shape
  367. # 计算box的中心点(cx, cy)
  368. cx = (box[0] + box[2]) // 2
  369. cy = (box[1] + box[3]) // 2
  370. # 偏移中心点
  371. new_cx = min(max(cx + 3, 0), W - 1) # 向右移3
  372. new_cy = min(max(cy - 3, 0), H - 1) # 向上移3
  373. # 创建坐标网格
  374. y_coords, x_coords = torch.meshgrid(
  375. torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij'
  376. )
  377. # 划分四个区域
  378. mask_q1 = (y_coords < new_cy) & (x_coords < new_cx) # 左上
  379. mask_q2 = (y_coords < new_cy) & (x_coords >= new_cx) # 右上
  380. mask_q3 = (y_coords >= new_cy) & (x_coords < new_cx) # 左下
  381. mask_q4 = (y_coords >= new_cy) & (x_coords >= new_cx) # 右下
  382. # def process_region(ins):
  383. # region = feature_map[:, :, ins].squeeze()
  384. # if len(region.shape) == 0: # 如果区域为空,则跳过
  385. # return None, None
  386. # # 找到最大热度值的点及其位置
  387. # (y, x), heat_val = non_maximum_suppression(region[0])
  388. # # 将相对坐标转换回全局坐标
  389. # y_global = y + torch.where(ins)[0].min().item()
  390. # x_global = x + torch.where(ins)[1].min().item()
  391. # return (y_global, x_global), heat_val
  392. #
  393. # results = []
  394. # for ins in [mask_q1, mask_q2, mask_q3, mask_q4]:
  395. # point, heat_val = process_region(ins)
  396. # if point is not None:
  397. # # results.append((point[0], point[1], heat_val))
  398. # results.append((point[0], point[1]))
  399. # else:
  400. # results.append(None)
  401. masks = [mask_q1, mask_q2, mask_q3, mask_q4]
  402. results = []
  403. # 假设使用第一个通道作为热力图
  404. heatmap = feature_map[0] # [H, W]
  405. def process_region(mask):
  406. # 应用 ins,只保留该区域
  407. masked_heatmap = heatmap.clone() # 复制以避免修改原数据
  408. masked_heatmap[~mask] = 0 # 非区域置0
  409. def non_maximum_suppression_2d(heatmap, kernel_size=3):
  410. """
  411. 对 2D 热力图做非极大值抑制,保留局部最大值点。
  412. Args:
  413. heatmap (torch.Tensor): [H, W],输入热力图
  414. kernel_size (int): 池化窗口大小,用于比较是否为局部最大值
  415. Returns:
  416. torch.Tensor: 与 heatmap 同形状的 ins,局部最大值位置为 True
  417. """
  418. pad = (kernel_size - 1) // 2
  419. max_pool = torch.nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)
  420. maxima = max_pool(heatmap.unsqueeze(0)).squeeze(0)
  421. # 局部最大值且值大于0
  422. peaks = (heatmap == maxima) & (heatmap > 0)
  423. return peaks
  424. # 1. 先做 NMS 得到候选局部极大值点
  425. nms_mask = non_maximum_suppression_2d(masked_heatmap, kernel_size=3) # [H, W] bool
  426. candidate_peaks = masked_heatmap * nms_mask.float() # 只保留 NMS 后的峰值
  427. # 2. 找出所有候选点中值最大的一个
  428. if candidate_peaks.max() <= 0:
  429. return None
  430. # 找到最大值的位置
  431. max_val, max_idx = torch.max(candidate_peaks.view(-1), dim=0)
  432. y, x = divmod(max_idx.item(), W)
  433. return (x, y) # 返回 (y, x)
  434. for mask in masks:
  435. point = process_region(mask)
  436. results.append(point)
  437. return results
  438. def non_maximum_suppression_2d(heatmap, kernel_size=3):
  439. pad = (kernel_size - 1) // 2
  440. max_pool = torch.nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)
  441. maxima = max_pool(heatmap.unsqueeze(0)).squeeze(0)
  442. peaks = (heatmap == maxima) & (heatmap > 0)
  443. return peaks
  444. def find_max_heat_point_in_edge_centers(feature_map, box):
  445. device = feature_map.device
  446. C, H, W = feature_map.shape
  447. # ¼ÆËã box ÖÐÐÄ
  448. cx = (box[0] + box[2]) / 2
  449. cy = (box[1] + box[3]) / 2
  450. # ¸ù¾Ý box ¿í¸ß¼ÆËã¾Å¹¬¸ñ·Ö½çÏß
  451. box_width = box[2] - box[0]
  452. box_height = box[3] - box[1]
  453. x_left = cx - box_width / 6
  454. x_right = cx + box_width / 6
  455. y_top = cy - box_height / 6
  456. y_bottom = cy + box_height / 6
  457. # ´´½¨Íø¸ñ
  458. y_coords, x_coords = torch.meshgrid(
  459. torch.arange(H, device=device),
  460. torch.arange(W, device=device),
  461. indexing='ij'
  462. )
  463. # ¶¨ÒåËĸö¡°±ßÖС±ÇøÓòµÄ ins
  464. mask1 = (x_coords < x_left) & (y_coords < y_top)
  465. mask_top_middle = (x_coords >= x_left) & (x_coords < x_right) & (y_coords < y_top)
  466. mask3 = (x_coords >= x_right) & (y_coords < y_top)
  467. mask_left_middle = (x_coords < x_left) & (y_coords >= y_top) & (y_coords < y_bottom)
  468. mask_right_middle = (x_coords >= x_right) & (y_coords >= y_top) & (y_coords < y_bottom)
  469. mask4 = (x_coords < x_left) & (y_coords >= y_bottom)
  470. mask_bottom_middle = (x_coords >= x_left) & (x_coords < x_right) & (y_coords >= y_bottom)
  471. mask_right_bottom = (x_coords >= x_right) & (y_coords >= y_bottom)
  472. # masks = [
  473. # # mask1,
  474. # mask_top_middle,
  475. # # mask3,
  476. # mask_left_middle,
  477. # mask_right_middle,
  478. # # mask4,
  479. # mask_bottom_middle,
  480. # mask_right_bottom
  481. # ]
  482. masks = [
  483. mask_top_middle,
  484. mask_right_middle,
  485. mask_bottom_middle,
  486. mask_left_middle
  487. ]
  488. # ʹÓõÚÒ»¸öͨµÀ×÷ΪÈÈÁ¦Í¼
  489. heatmap = feature_map[0] # [H, W]
  490. results = []
  491. for mask in masks:
  492. masked_heatmap = heatmap.clone()
  493. masked_heatmap[~mask] = 0 # ·ÇÄ¿±êÇøÓòÖà 0
  494. # # NMS ÒÖÖÆ
  495. # nms_mask = non_maximum_suppression_2d(masked_heatmap, kernel_size=3)
  496. # candidate_peaks = masked_heatmap * nms_mask.float()
  497. #
  498. # if candidate_peaks.max() <= 0:
  499. # results.append(None)
  500. # continue
  501. #
  502. # # ÕÒ×î´óֵλÖÃ
  503. # max_val, max_idx = torch.max(candidate_peaks.view(-1), dim=0)
  504. # y, x = divmod(max_idx.item(), W)
  505. flatten_point_roi_map = masked_heatmap.reshape(1, -1)
  506. point_score, point_index = torch.topk(flatten_point_roi_map, k=1)
  507. point_x =point_index % W
  508. point_y = torch.div(point_index - point_x, W, rounding_mode="floor")
  509. results.append((point_x, point_y))
  510. return results # [(y_top, x_top), (y_right, x_right), (y_bottom, x_bottom), (y_left, x_left)]
  511. def heatmaps_to_circle_points(maps, rois,num_points=2):
  512. point_preds = torch.zeros((len(rois), 4, 2), dtype=torch.float32, device=maps.device)
  513. point_end_scores = torch.zeros((len(rois),4, 1), dtype=torch.float32, device=maps.device)
  514. print(f'rois in heatmaps_to_circle_points:{type(rois), rois.shape}') # <class 'torch.Tensor'>
  515. print(f'heatmaps_to_lines:{maps.shape}')
  516. point_maps=maps[:,0]
  517. print(f'point_map:{point_maps.shape}')
  518. for i in range(len(rois)):
  519. point_roi_map = point_maps[i].unsqueeze(0)
  520. print(f'point_roi_map:{point_roi_map.shape}')
  521. # roi_map_probs = scores_to_probs(roi_map.copy())
  522. # w = point_roi_map.shape[2]
  523. # flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
  524. # print(f'non_maximum_suppression :{non_maximum_suppression(point_roi_map).shape}')
  525. # point_score, point_index = torch.topk(flatten_point_roi_map, k=num_points)
  526. # print(f'point index:{point_index}')
  527. # point_x =point_index % w
  528. # point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
  529. # print(f'point_x:{point_x}, point_y:{point_y}')
  530. # point_preds[i, :, 0] = point_x
  531. # point_preds[i, :, 1] = point_y
  532. roi1=rois[i]
  533. result_points = find_max_heat_point_in_edge_centers(non_maximum_suppression(point_roi_map), roi1)
  534. point_preds[i, :]=torch.tensor(result_points)
  535. point_x = [point[0] for point in result_points]
  536. point_y = [point[1] for point in result_points]
  537. point_end_scores[i, :,0] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
  538. return point_preds,point_end_scores
  539. def heatmaps_to_lines(maps, rois):
  540. line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
  541. line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
  542. line_maps=maps[:,1]
  543. # line_maps = maps.squeeze(1)
  544. for i in range(len(rois)):
  545. line_roi_map = line_maps[i].unsqueeze(0)
  546. print(f'line_roi_map:{line_roi_map.shape}')
  547. # roi_map_probs = scores_to_probs(roi_map.copy())
  548. w = line_roi_map.shape[1]
  549. flatten_line_roi_map = non_maximum_suppression(line_roi_map).reshape(1, -1)
  550. line_score, line_index = torch.topk(flatten_line_roi_map, k=2)
  551. print(f'line index:{line_index}')
  552. # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  553. pos = line_index
  554. line_x = pos % w
  555. line_y = torch.div(pos - line_x, w, rounding_mode="floor")
  556. line_preds[i, 0, :] = line_x
  557. line_preds[i, 1, :] = line_y
  558. line_preds[i, 2, :] = 1
  559. line_end_scores[i, :] = line_roi_map[torch.arange(1, device=line_roi_map.device), line_y, line_x]
  560. return line_preds.permute(0, 2, 1), line_end_scores
  561. # 显示热图的函数
  562. def show_heatmap(heatmap, title="Heatmap"):
  563. """
  564. 使用 matplotlib 显示热图。
  565. Args:
  566. heatmap (Tensor): 要显示的热图张量
  567. title (str): 图表标题
  568. """
  569. # 如果在 GPU 上,首先将其移动到 CPU 并转换为 numpy 数组
  570. if heatmap.is_cuda:
  571. heatmap = heatmap.cpu().numpy()
  572. else:
  573. heatmap = heatmap.numpy()
  574. plt.imshow(heatmap, cmap='hot', interpolation='nearest')
  575. plt.colorbar()
  576. plt.title(title)
  577. plt.show()
  578. def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
  579. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  580. N, K, H, W = line_logits.shape
  581. len_proposals = len(proposals)
  582. print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals},line_matched_idxs:{line_matched_idxs}')
  583. if H != W:
  584. raise ValueError(
  585. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  586. )
  587. discretization_size = H
  588. heatmaps = []
  589. gs_heatmaps = []
  590. valid = []
  591. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
  592. print(f'line_proposals_per_image:{proposals_per_image.shape}')
  593. print(f'gt_lines:{gt_lines}')
  594. if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
  595. kp = gt_kp_in_image[midx]
  596. gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
  597. gs_heatmaps.append(gs_heatmaps_per_img)
  598. # print(f'heatmaps_per_image:{heatmaps_per_image.shape}')
  599. # heatmaps.append(heatmaps_per_image.view(-1))
  600. # valid.append(valid_per_image.view(-1))
  601. # line_targets = torch.cat(heatmaps, dim=0)
  602. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  603. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
  604. # print(f'line_targets:{line_targets.shape},{line_targets}')
  605. # valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
  606. # valid = torch.where(valid)[0]
  607. # print(f' line_targets[valid]:{line_targets[valid]}')
  608. # torch.mean (in binary_cross_entropy_with_logits) doesn't
  609. # accept empty tensors, so handle it sepaartely
  610. # if line_targets.numel() == 0 or len(valid) == 0:
  611. # return line_logits.sum() * 0
  612. # line_logits = line_logits.view(N * K, H * W)
  613. # print(f'line_logits[valid]:{line_logits[valid].shape}')
  614. print(f'loss1 line_logits:{line_logits.shape}')
  615. line_logits = line_logits[:,1,:,:]
  616. # line_logits = line_logits.squeeze(1)
  617. print(f'loss2 line_logits:{line_logits.shape}')
  618. # line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
  619. line_loss = F.cross_entropy(line_logits, gs_heatmaps)
  620. return line_loss
  621. def compute_ins_loss(feature_logits, proposals, gt_, pos_matched_idxs):
  622. print(f'compute_arc_loss:{feature_logits.shape}')
  623. N, K, H, W = feature_logits.shape
  624. len_proposals = len(proposals)
  625. empty_count = 0
  626. non_empty_count = 0
  627. for prop in proposals:
  628. if prop.shape[0] == 0:
  629. empty_count += 1
  630. else:
  631. non_empty_count += 1
  632. print(f"Empty proposals count: {empty_count}")
  633. print(f"Non-empty proposals count: {non_empty_count}")
  634. print(f'starte to compute_point_loss')
  635. print(f'compute_point_loss line_logits.shape:{feature_logits.shape},len_proposals:{len_proposals}')
  636. if H != W:
  637. raise ValueError(
  638. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  639. )
  640. discretization_size = H
  641. gs_heatmaps = []
  642. # print(f'point_matched_idxs:{point_matched_idxs}')
  643. print(f'gt_masks:{gt_[0].shape}')
  644. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
  645. # [
  646. # (Tensor(38, 4), Tensor(1, 57, 2), Tensor(38, 1)),
  647. # (Tensor(65, 4), Tensor(1, 74, 2), Tensor(65, 1))
  648. # ]
  649. print(f'proposals_per_image:{proposals_per_image.shape}')
  650. kp = gt_kp_in_image[midx]
  651. t_h, t_w = kp.shape[-2:]
  652. print(f't_h:{t_h}, t_w:{t_w}')
  653. print(f'gt_kp_in_image:{gt_kp_in_image.shape}')
  654. if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
  655. gs_heatmaps_per_img = align_masks(kp, proposals_per_image, discretization_size)
  656. gs_heatmaps.append(gs_heatmaps_per_img)
  657. if len(gs_heatmaps)>0:
  658. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  659. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.shape}')
  660. line_logits = feature_logits.squeeze(1)
  661. print(f'ins shape:{line_logits.shape}')
  662. # line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
  663. # line_loss = F.cross_entropy(line_logits, gs_heatmaps)
  664. line_loss=combined_loss(line_logits, gs_heatmaps)
  665. else:
  666. line_loss=100
  667. print("d")
  668. return line_loss
  669. def align_masks(keypoints, rois, heatmap_size):
  670. print(f'rois:{rois.shape}')
  671. print(f'heatmap_size:{heatmap_size}')
  672. print(f'keypoints.shape:{keypoints.shape}')
  673. # batch_size, num_keypoints, _ = keypoints.shape
  674. t_h, t_w = keypoints.shape[-2:]
  675. scale=heatmap_size/t_w
  676. print(f'scale:{scale}')
  677. x = keypoints[..., 0]*scale
  678. y = keypoints[..., 1]*scale
  679. x = x.unsqueeze(1)
  680. y = y.unsqueeze(1)
  681. num_points=x.shape[2]
  682. print(f'num_points:{num_points}')
  683. mask_4d = keypoints.unsqueeze(1).float()
  684. resized_mask = F.interpolate(
  685. mask_4d,
  686. size = (heatmap_size, heatmap_size),
  687. mode = 'bilinear',
  688. align_corners = False
  689. ).squeeze(1) # [B,heatmap_size,heatmap_size]
  690. # plt.imshow(resized_mask[0].cpu())
  691. # plt.show()
  692. print(f'resized_mask:{resized_mask.shape}')
  693. return resized_mask
  694. def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
  695. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  696. N, K, H, W = line_logits.shape
  697. len_proposals = len(proposals)
  698. empty_count = 0
  699. non_empty_count = 0
  700. for prop in proposals:
  701. if prop.shape[0] == 0:
  702. empty_count += 1
  703. else:
  704. non_empty_count += 1
  705. print(f"Empty proposals count: {empty_count}")
  706. print(f"Non-empty proposals count: {non_empty_count}")
  707. print(f'starte to compute_point_loss')
  708. print(f'compute_point_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
  709. if H != W:
  710. raise ValueError(
  711. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  712. )
  713. discretization_size = H
  714. gs_heatmaps = []
  715. # print(f'point_matched_idxs:{point_matched_idxs}')
  716. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_points, point_matched_idxs):
  717. print(f'proposals_per_image:{proposals_per_image.shape}')
  718. kp = gt_kp_in_image[midx]
  719. # print(f'gt_kp_in_image:{gt_kp_in_image}')
  720. gs_heatmaps_per_img = single_point_to_heatmap(kp, proposals_per_image, discretization_size)
  721. gs_heatmaps.append(gs_heatmaps_per_img)
  722. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  723. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
  724. line_logits = line_logits[:,0]
  725. print(f'single_point_logits:{line_logits.shape}')
  726. line_loss = F.cross_entropy(line_logits, gs_heatmaps)
  727. return line_loss
  728. def compute_circle_loss(circle_logits, proposals, gt_circles, circle_matched_idxs):
  729. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  730. N, K, H, W = circle_logits.shape
  731. len_proposals = len(proposals)
  732. empty_count = 0
  733. non_empty_count = 0
  734. for prop in proposals:
  735. if prop.shape[0] == 0:
  736. empty_count += 1
  737. else:
  738. non_empty_count += 1
  739. print(f"Empty proposals count: {empty_count}")
  740. print(f"Non-empty proposals count: {non_empty_count}")
  741. print(f'starte to compute_circle_loss')
  742. print(f'compute_circle_loss circle_logits.shape:{circle_logits.shape},len_proposals:{len_proposals}')
  743. if H != W:
  744. raise ValueError(
  745. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  746. )
  747. discretization_size = H
  748. gs_heatmaps = []
  749. # print(f'point_matched_idxs:{point_matched_idxs}')
  750. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_circles, circle_matched_idxs):
  751. print(f'proposals_per_image:{proposals_per_image.shape}')
  752. kp = gt_kp_in_image[midx]
  753. # print(f'gt_kp_in_image:{gt_kp_in_image}')
  754. gs_heatmaps_per_img = points_to_heatmap(kp, proposals_per_image,num_points=4, heatmap_size=discretization_size)
  755. gs_heatmaps.append(gs_heatmaps_per_img)
  756. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  757. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{circle_logits.squeeze(1).shape}')
  758. circle_logits = circle_logits[:, 0]
  759. print(f'circle_logits:{circle_logits.shape}')
  760. circle_loss = F.cross_entropy(circle_logits, gs_heatmaps)
  761. return circle_loss
  762. def lines_to_boxes(lines, img_size=511):
  763. """
  764. 输入:
  765. lines: Tensor of shape (N, 2, 2),表示 N 条线段,每个线段有两个端点 (x, y)
  766. img_size: int,图像尺寸,用于 clamp 边界
  767. 输出:
  768. boxes: Tensor of shape (N, 4),表示 N 个包围盒 [x_min, y_min, x_max, y_max]
  769. """
  770. # 提取所有线段的两个端点
  771. p1 = lines[:, 0] # (N, 2)
  772. p2 = lines[:, 1] # (N, 2)
  773. # 每条线段的 x 和 y 坐标
  774. x_coords = torch.stack([p1[:, 0], p2[:, 0]], dim=1) # (N, 2)
  775. y_coords = torch.stack([p1[:, 1], p2[:, 1]], dim=1) # (N, 2)
  776. # 计算包围盒边界
  777. x_min = x_coords.min(dim=1).values
  778. y_min = y_coords.min(dim=1).values
  779. x_max = x_coords.max(dim=1).values
  780. y_max = y_coords.max(dim=1).values
  781. # 扩展边界并限制在图像范围内
  782. x_min = (x_min - 1).clamp(min=0, max=img_size)
  783. y_min = (y_min - 1).clamp(min=0, max=img_size)
  784. x_max = (x_max + 1).clamp(min=0, max=img_size)
  785. y_max = (y_max + 1).clamp(min=0, max=img_size)
  786. # 合成包围盒
  787. boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1) # (N, 4)
  788. return boxes
  789. def box_iou_pairwise(box1, box2):
  790. """
  791. 输入:
  792. box1: shape (N, 4)
  793. box2: shape (M, 4)
  794. 输出:
  795. ious: shape (min(N, M), ), 只计算 i = j 的配对
  796. """
  797. N = min(len(box1), len(box2))
  798. lt = torch.max(box1[:N, :2], box2[:N, :2]) # 左上角
  799. rb = torch.min(box1[:N, 2:], box2[:N, 2:]) # 右下角
  800. wh = (rb - lt).clamp(min=0) # 宽高
  801. inter_area = wh[:, 0] * wh[:, 1] # 交集面积
  802. area1 = (box1[:N, 2] - box1[:N, 0]) * (box1[:N, 3] - box1[:N, 1])
  803. area2 = (box2[:N, 2] - box2[:N, 0]) * (box2[:N, 3] - box2[:N, 1])
  804. union_area = area1 + area2 - inter_area
  805. ious = inter_area / (union_area + 1e-6)
  806. return ious
  807. def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta=1.0, gamma=1.0):
  808. """
  809. Args:
  810. x: [N,1,H,W] 热力图
  811. boxes: [N,4] 框坐标
  812. gt_lines: [N,2,3] GT线段(含可见性)
  813. matched_idx: 匹配 index
  814. img_size: 图像尺寸
  815. alpha: IoU 损失权重
  816. beta: 长度损失权重
  817. gamma: 方向角度损失权重
  818. """
  819. losses = []
  820. boxes_per_image = [box.size(0) for box in boxes]
  821. x2 = x.split(boxes_per_image, dim=0)
  822. for xx, bb, gt_line, mid in zip(x2, boxes, gt_lines, matched_idx):
  823. p_prob, _ = heatmaps_to_lines(xx, bb)
  824. pred_lines = p_prob
  825. gt_line_points = gt_line[mid]
  826. if len(pred_lines) == 0 or len(gt_line_points) == 0:
  827. continue
  828. # IoU 损失
  829. pred_boxes = lines_to_boxes(pred_lines, img_size)
  830. gt_boxes = lines_to_boxes(gt_line_points, img_size)
  831. ious = box_iou_pairwise(pred_boxes, gt_boxes)
  832. iou_loss = 1.0 - ious # [N]
  833. # 长度损失
  834. pred_len = line_length(pred_lines)
  835. gt_len = line_length(gt_line_points)
  836. length_diff = F.l1_loss(pred_len, gt_len, reduction='none') # [N]
  837. # 方向角度损失
  838. pred_dir = line_direction(pred_lines)
  839. gt_dir = line_direction(gt_line_points)
  840. ang_loss = angle_loss_cosine(pred_dir, gt_dir) # [N]
  841. # 归一化每一项损失
  842. norm_iou = normalize_tensor(iou_loss)
  843. norm_len = normalize_tensor(length_diff)
  844. norm_ang = normalize_tensor(ang_loss)
  845. total = alpha * norm_iou + beta * norm_len + gamma * norm_ang
  846. losses.append(total)
  847. if not losses:
  848. return None
  849. return torch.mean(torch.cat(losses))
  850. def point_inference(x, point_boxes):
  851. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  852. points_probs = []
  853. points_scores = []
  854. boxes_per_image = [box.size(0) for box in point_boxes]
  855. x2 = x.split(boxes_per_image, dim=0)
  856. for xx, bb in zip(x2, point_boxes):
  857. point_prob,point_scores = heatmaps_to_points(xx, bb,num_points=1)
  858. points_probs.append(point_prob.unsqueeze(1))
  859. points_scores.append(point_scores)
  860. return points_probs,points_scores
  861. def circle_inference(x, point_boxes):
  862. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  863. points_probs = []
  864. points_scores = []
  865. boxes_per_image = [box.size(0) for box in point_boxes]
  866. x2 = x.split(boxes_per_image, dim=0)
  867. for xx, bb in zip(x2, point_boxes):
  868. point_prob,point_scores = heatmaps_to_circle_points(xx, bb,num_points=4)
  869. points_probs.append(point_prob.unsqueeze(1))
  870. points_scores.append(point_scores)
  871. return points_probs,points_scores
  872. def line_inference(x, line_boxes):
  873. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  874. lines_probs = []
  875. lines_scores = []
  876. boxes_per_image = [box.size(0) for box in line_boxes]
  877. x2 = x.split(boxes_per_image, dim=0)
  878. # x2:tuple 2 x2[0]:[1,3,1024,1024]
  879. # line_box: list:2 [1,4] [1.4] fasterrcnn kuang
  880. for xx, bb in zip(x2, line_boxes):
  881. line_prob, line_scores, = heatmaps_to_lines(xx, bb)
  882. lines_probs.append(line_prob)
  883. lines_scores.append(line_scores)
  884. return lines_probs, lines_scores
def ins_inference(x, arc_boxes, th):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    # NOTE(review): the type comment above lists 2 parameters and 2 return
    # values, but the function takes 3 (x, arc_boxes, th) and returns 3 values.
    """Decode instance masks per ROI and collect above-threshold pixel coordinates.

    Step 1: ``x`` is split along dim 0 by the per-image box counts. Each slice
    goes through ``heatmaps_to_arc``, which returns (masks, scores). Masks get an
    extra dim via ``unsqueeze(1)``.

    Step 2: for every box in ``arc_boxes[0]`` and every index ``b`` along dim 0
    of ``x``, it takes all channels of ``x[b]`` inside that box and finds the
    elements ``> th``. Their (x, y) map coordinates are stored in
    ``results[b][proposal_idx]``.

    Returns:
        (points_probs, points_scores, results), where ``results`` is a nested
        list indexed [dim-0 index of x][box index in arc_boxes[0]]. Entries with
        no hit stay ``torch.empty(0, 2)``.

    NOTE(review):
      * ``th`` is compared against the raw values of ``x``. ``heatmaps_to_arc``
        instead thresholds after a sigmoid, so the two thresholds are on
        different scales.
      * ``batch_size`` is ``x.shape[0]``, i.e. the total number of ROIs across
        images (``x`` is split by box counts above), not the number of images.
      * Only ``arc_boxes[0]`` (the first image's boxes) is used in step 2.
      * ``torch.cat(points_probs)`` and ``arc_boxes[0]`` raise when
        ``arc_boxes`` is empty.
    """
    points_probs = []
    points_scores = []
    print(f'arc_boxes:{len(arc_boxes)}')
    boxes_per_image = [box.size(0) for box in arc_boxes]
    print(f'arc boxes_per_image:{boxes_per_image}')
    # Per-image chunks of x. The name x2 is later reused for a box coordinate;
    # the chunks are not needed after this loop.
    x2 = x.split(boxes_per_image, dim=0)
    for xx, bb in zip(x2, arc_boxes):
        point_prob,point_scores = heatmaps_to_arc(xx, bb)
        points_probs.append(point_prob.unsqueeze(1))
        points_scores.append(point_scores)
    # Only used for the debug print below.
    points_probs_tensor=torch.cat(points_probs)
    print(f'points_probs shape:{points_probs_tensor.shape}')
    feature_logits = x
    batch_size = feature_logits.shape[0]
    num_proposals = len(arc_boxes[0])
    # One empty (0, 2) placeholder per (dim-0 index, box) pair.
    results = [[torch.empty(0, 2) for _ in range(num_proposals)] for _ in range(batch_size)]
    proposals_list = arc_boxes[0] # boxes of the first image only
    for proposal_idx, proposal in enumerate(proposals_list):
        # Box coordinates truncated to int and clipped to the map extent;
        # x2/y2 act as exclusive slice ends.
        coords = proposal.tolist()
        x1, y1, x2, y2 = map(int, coords)
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(feature_logits.shape[3], x2)
        y2 = min(feature_logits.shape[2], y2)
        for batch_idx in range(batch_size):
            region = feature_logits[batch_idx, :, y1:y2, x1:x2]
            mask = region > th
            coords = torch.nonzero(mask)
            if coords.numel() > 0:
                # nonzero() rows are (channel, y, x); columns [2, 1] reorder
                # them to (x, y). The offsets below convert region-local
                # coordinates to map coordinates.
                local_coords = coords[:, [2, 1]] # (x, y)
                local_coords[:, 0] += x1
                local_coords[:, 1] += y1
                results[batch_idx][proposal_idx] = local_coords
    print(f're:{results}')
    return points_probs,points_scores,results
  923. import torch.nn.functional as F
  924. def heatmaps_to_arc(maps, rois, threshold=0.5, output_size=(128, 128)):
  925. """
  926. Args:
  927. maps: [N, 3, H, W] - full heatmaps
  928. rois: [N, 4] - bounding boxes
  929. threshold: float - binarization threshold
  930. output_size: resized size for uniform NMS
  931. Returns:
  932. masks: [N, 1, H, W] - binary ins aligned with input map
  933. scores: [N, 1] - count of non-zero pixels in each ins
  934. """
  935. N, _, H, W = maps.shape
  936. masks = torch.zeros((N, 1, H, W), dtype=torch.float32, device=maps.device)
  937. scores = torch.zeros((N, 1), dtype=torch.float32, device=maps.device)
  938. point_maps = maps[:, 0] # È¡µÚÒ»¸öͨµÀ [N, H, W]
  939. print(f"==> heatmaps_to_arc: maps.shape = {maps.shape}, rois.shape = {rois.shape}")
  940. for i in range(N):
  941. x1, y1, x2, y2 = rois[i].long()
  942. x1 = x1.clamp(0, W - 1)
  943. x2 = x2.clamp(0, W - 1)
  944. y1 = y1.clamp(0, H - 1)
  945. y2 = y2.clamp(0, H - 1)
  946. print(f"[{i}] roi: ({x1.item()}, {y1.item()}, {x2.item()}, {y2.item()})")
  947. if x2 <= x1 or y2 <= y1:
  948. print(f" Skipped invalid ROI at index {i}")
  949. continue
  950. roi_map = point_maps[i, y1:y2, x1:x2] # [h, w]
  951. print(f" roi_map.shape: {roi_map.shape}")
  952. if roi_map.numel() == 0:
  953. print(f" Skipped empty ROI at index {i}")
  954. continue
  955. # resize to uniform size
  956. roi_map_resized = F.interpolate(
  957. roi_map.unsqueeze(0).unsqueeze(0),
  958. size=output_size,
  959. mode='bilinear',
  960. align_corners=False
  961. ) # [1, 1, H, W]
  962. print(f" roi_map_resized.shape: {roi_map_resized.shape}")
  963. # NMS + threshold
  964. # nms_roi = non_maximum_suppression(roi_map_resized) # shape: [1, H, W]
  965. nms_roi = torch.sigmoid(roi_map_resized)
  966. bin_mask = (nms_roi >= threshold).float() # shape: [1, H, W]
  967. print(f" bin_mask.sum(): {bin_mask.sum().item()}")
  968. # resize back to original roi size
  969. h = int((y2 - y1).item())
  970. w = int((x2 - x1).item())
  971. # È·±£ bin_mask ÊÇ [1, 128, 128]
  972. assert bin_mask.dim() == 4, f"Expected 3D tensor [1, H, W], got {bin_mask.shape}"
  973. bin_mask_original_size = F.interpolate(
  974. bin_mask, # [1, 1, 128, 128]
  975. size=(h, w),
  976. mode='bilinear',
  977. align_corners=False
  978. )[0] # [1, h, w]
  979. bin_mask_original_size = bin_mask_original_size[0] # [h, w]
  980. masks[i, 0, y1:y2, x1:x2] = bin_mask_original_size
  981. scores[i] = bin_mask_original_size.sum()
  982. #
  983. # bin_mask_original_size = F.interpolate(
  984. # # bin_mask.unsqueeze(0), # ? [1, 1, 128, 128]
  985. # bin_mask, # ? [1, 1, 128, 128]
  986. # size=(h, w),
  987. # mode='bilinear',
  988. # align_corners=False
  989. # )[0] # ? [1, h, w]
  990. #
  991. # masks[i, 0, y1:y2, x1:x2] = bin_mask_original_size.squeeze()
  992. # scores[i] = bin_mask_original_size.sum()
  993. # plt.figure(figsize=(6, 6))
  994. # plt.imshow(masks[i, 0].cpu().numpy(), cmap='gray')
  995. # plt.title(f"Mask {i}, score={scores[i].item():.1f}")
  996. # plt.axis('off')
  997. # plt.show()
  998. print(f" bin_mask_original_size.shape: {bin_mask_original_size.shape}, sum: {scores[i].item()}")
  999. print(f"==> Done. Total valid masks: {(scores > 0).sum().item()} / {N}")
  1000. return masks, scores