head_losses.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322
  1. import torch
  2. from matplotlib import pyplot as plt
  3. import torch.nn.functional as F
  4. from torch import nn
  5. from torch.cuda import device
  6. from utils.data_process.mask.show_mask import save_full_mask
  class DiceLoss(nn.Module):
      """Soft Dice loss computed on sigmoid probabilities.

      Args:
          smooth: Constant added to both the numerator and the denominator of
              the Dice ratio so it stays defined when both inputs sum to zero.
      """

      def __init__(self, smooth=1.):
          super().__init__()
          self.smooth = smooth

      def forward(self, logits, targets):
          """Return ``1 - dice`` computed over every element of the inputs.

          Args:
              logits: Raw (pre-sigmoid) predictions.
              targets: Ground truth with the same number of elements as
                  ``logits``; it is cast to float before use.

          Returns:
              Scalar tensor with the Dice loss.
          """
          # reshape (rather than view) also flattens non-contiguous tensors,
          # e.g. transposed or sliced masks, instead of raising RuntimeError.
          probs = torch.sigmoid(logits).reshape(-1)
          targets = targets.reshape(-1).float()
          intersection = (probs * targets).sum()
          dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
          return 1. - dice
  # Loss modules are created once at import time and shared by combined_loss.
  bce_loss = nn.BCEWithLogitsLoss()
  dice_loss = DiceLoss()


  def combined_loss(preds, targets, alpha=0.5):
      """Blend binary cross-entropy and Dice loss on the same logits.

      Args:
          preds: Raw (pre-sigmoid) predictions.
          targets: Ground truth of the same shape as ``preds``.
          alpha: Weight of the BCE term; the Dice term is weighted ``1 - alpha``.

      Returns:
          ``alpha * bce + (1 - alpha) * dice``.
      """
      # BCEWithLogitsLoss rejects integer/bool targets; DiceLoss already casts
      # its own copy to float, so cast here to keep both terms consistent.
      bce = bce_loss(preds, targets.to(preds.dtype))
      d = dice_loss(preds, targets)
      return alpha * bce + (1 - alpha) * d
  def features_align(features, proposals, img_size):
      """Create one masked copy of an image's feature map per proposal box.

      For every proposal, a zero tensor shaped like the image's feature map
      (with a leading batch axis of 1) is filled only inside the inclusive
      window ``[y1, y2] x [x1, x2]``; everything outside stays zero.

      Args:
          features: Iterable of per-image feature maps, paired with ``proposals``.
          proposals: Per-image tensors whose rows are ``x1, y1, x2, y2``.
          img_size: Not read by this function.

      Returns:
          All masked maps concatenated along dim 0, or ``None`` when no image
          has any proposal.
      """
      print(f'features_align features:{features.shape},proposals:{len(proposals)}')
      masked_maps = []
      for image_feat, image_boxes in zip(features, proposals):
          print(f'feature_align feat:{image_feat.shape}, proposals_per_img:{image_boxes.shape}')
          if not image_boxes.shape[0] > 0:
              continue
          batched = image_feat.unsqueeze(0)
          for box in image_boxes:
              canvas = torch.zeros_like(batched)
              left, top, right, bottom = (int(coord.item()) for coord in box)
              rows = slice(top, bottom + 1)
              cols = slice(left, right + 1)
              # Copy only the region covered by the proposal box.
              canvas[:, :, rows, cols] = batched[:, :, rows, cols]
              masked_maps.append(canvas)

      if len(masked_maps) == 0:
          return None
      feats_tensor = torch.cat(masked_maps)
      print(f'align features :{feats_tensor.shape}')
      return feats_tensor
  def normalize_tensor(t):
      """Min-max rescale ``t`` using its global minimum and maximum.

      A small constant (1e-6) is added to the range so a constant tensor does
      not divide by zero.
      """
      lowest = t.min()
      span = t.max() - lowest + 1e-6
      return (t - lowest) / span
  def line_length(lines):
      """Return the Euclidean length of every segment.

      Args:
          lines: Tensor indexed as ``lines[:, 0]`` (start point) and
              ``lines[:, 1]`` (end point); documented by the author as [N, 2, 2].

      Returns:
          Norm of ``end - start`` taken over the last axis.
      """
      start, end = lines[:, 0], lines[:, 1]
      return torch.norm(end - start, dim=-1)
  def line_direction(lines):
      """Return the unit direction vector of every segment.

      Args:
          lines: Tensor indexed as ``lines[:, 0]`` (start point) and
              ``lines[:, 1]`` (end point); documented by the author as [N, 2, 2].

      Returns:
          ``end - start`` normalised along the last axis with ``F.normalize``.
      """
      start, end = lines[:, 0], lines[:, 1]
      return F.normalize(end - start, dim=-1)
  def angle_loss_cosine(pred_dir, gt_dir):
      """Return ``1 - cos`` between paired direction vectors.

      The dot product over the last axis is clamped to [-1, 1] before being
      subtracted from one.

      Args:
          pred_dir: Predicted directions (documented by the author as [N, 2]).
          gt_dir: Ground-truth directions with the same shape.

      Returns:
          One value per pair; ``torch.acos(cos_sim) / pi`` would be an
          alternative measure, as the original author noted.
      """
      cos_sim = (pred_dir * gt_dir).sum(dim=-1).clamp(-1.0, 1.0)
      return 1.0 - cos_sim
  # NOTE(review): line_length is already defined earlier in this file with the
  # same computation; this later definition is the one that stays bound.
  def line_length(lines):
      """Compute the length of each line segment.

      Args:
          lines: Segments given by two points each, read as ``lines[:, 0]``
              and ``lines[:, 1]`` (documented by the author as [N, 2, 2]).

      Returns:
          Euclidean norm, over the last axis, of the endpoint difference.
      """
      delta = lines[:, 1] - lines[:, 0]
      return torch.norm(delta, dim=-1)
  # NOTE(review): line_direction is already defined earlier in this file with
  # the same computation; this later definition is the one that stays bound.
  def line_direction(lines):
      """Compute the unit direction vector of each line segment.

      Args:
          lines: Segments given by two points each, read as ``lines[:, 0]``
              and ``lines[:, 1]`` (documented by the author as [N, 2, 2]).

      Returns:
          The endpoint difference normalised along the last axis.
      """
      delta = lines[:, 1] - lines[:, 0]
      return F.normalize(delta, dim=-1)
  # NOTE(review): angle_loss_cosine is already defined earlier in this file with
  # the same computation; this later definition is the one that stays bound.
  def angle_loss_cosine(pred_dir, gt_dir):
      """Measure direction disagreement via cosine similarity.

      Args:
          pred_dir: Predicted directions (documented by the author as [N, 2]).
          gt_dir: Ground-truth directions with the same shape.

      Returns:
          ``1 - cos_sim`` where ``cos_sim`` is the last-axis dot product
          clamped to [-1, 1]. ``torch.acos(cos_sim) / pi`` would also work,
          as the original author noted.
      """
      dot = torch.sum(pred_dir * gt_dir, dim=-1)
      cos_sim = torch.clamp(dot, -1.0, 1.0)
      return 1.0 - cos_sim
  def single_point_to_heatmap(keypoints, rois, heatmap_size):
      # type: (Tensor, Tensor, int) -> Tensor
      """Render one Gaussian per keypoint entry and keep only the part inside its ROI.

      The x/y coordinates are read from the last axis of ``keypoints`` and
      passed to ``generate_gaussian_heatmaps`` with ``num_points=1`` and
      ``sigma=1.0``. Each resulting heatmap is then zeroed outside the
      inclusive window ``[y1, y2] x [x1, x2]`` of the matching ROI.

      Args:
          keypoints: Tensor whose last axis starts with (x, y).
          rois: Iterable of boxes ``x1, y1, x2, y2``, paired with the heatmaps.
          heatmap_size: Side length forwarded to ``generate_gaussian_heatmaps``.

      Returns:
          The clipped heatmaps concatenated along dim 0.
      """
      print(f'rois:{rois.shape}')
      print(f'heatmap_size:{heatmap_size}')
      print(f'keypoints.shape:{keypoints.shape}')

      xs = keypoints[..., 0].unsqueeze(1)
      ys = keypoints[..., 1].unsqueeze(1)
      gaussians = generate_gaussian_heatmaps(xs, ys, num_points=1, heatmap_size=heatmap_size, sigma=1.0)

      clipped = []
      for box, gaussian in zip(rois, gaussians):
          gaussian = gaussian.unsqueeze(0)
          left, top, right, bottom = (int(v) for v in box)
          window = (..., slice(top, bottom + 1), slice(left, right + 1))
          inside = torch.zeros_like(gaussian)
          inside[window] = gaussian[window]
          clipped.append(inside)

      all_roi_heatmap = torch.cat(clipped)
      print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
      return all_roi_heatmap
  def arc_inference1(arc_equation, x, arc_boxes, th):
      # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
      """Split batched arc outputs per image and collect per-image point scores.

      ``x`` and ``arc_equation`` are split along dim 0 by the number of boxes
      of each image. ``heatmaps_to_arc`` is run on every (heatmap chunk, boxes)
      pair and its second output is collected.

      NOTE(review): the type comment above lists two argument types for four
      parameters, and ``th`` is never read here - confirm both with the caller.

      Args:
          arc_equation: Batched tensor split per image and returned as-is.
          x: Batched heatmaps, split per image.
          arc_boxes: Per-image box tensors; ``size(0)`` gives the chunk sizes.
          th: Unused.

      Returns:
          Tuple ``(arc7, points_scores)``: the per-image chunks of
          ``arc_equation`` and the list of per-image scores.
      """
      print(f'arc_boxes:{len(arc_boxes)}')
      boxes_per_image = [boxes.size(0) for boxes in arc_boxes]
      print(f'arc boxes_per_image:{boxes_per_image}')

      heatmaps_by_image = x.split(boxes_per_image, dim=0)
      arc7 = arc_equation.split(boxes_per_image, dim=0)

      points_probs = []
      points_scores = []
      for image_heatmaps, image_boxes in zip(heatmaps_by_image, arc_boxes):
          prob, scores = heatmaps_to_arc(image_heatmaps, image_boxes)
          # The probabilities are gathered but not part of the return value.
          points_probs.append(prob.unsqueeze(1))
          points_scores.append(scores)
      return arc7, points_scores
  def points_to_heatmap(keypoints, rois,num_points=2, heatmap_size=(512,512)):
      """Render Gaussians for each keypoint group and keep only the part inside its ROI.

      The x/y coordinates are read from the last axis of ``keypoints`` and
      passed to ``generate_gaussian_heatmaps`` with ``sigma=2.0``. Each
      resulting heatmap is then zeroed outside the inclusive window
      ``[y1, y2] x [x1, x2]`` of the matching ROI.

      Args:
          keypoints: Tensor whose last axis starts with (x, y).
          rois: Iterable of boxes ``x1, y1, x2, y2``, paired with the heatmaps.
          num_points: Number of points rendered into each heatmap.
          heatmap_size: Side length of the square heatmap, either an int or a
              sequence whose entries are all equal (such as the default).

      Returns:
          The clipped heatmaps concatenated along dim 0.

      Raises:
          ValueError: If ``heatmap_size`` is a sequence that does not describe
              a square.
      """
      print(f'rois:{rois.shape}')
      print(f'heatmap_size:{heatmap_size}')
      print(f'keypoints.shape:{keypoints.shape}')

      # generate_gaussian_heatmaps builds its grid from a single integer side
      # length, so a (H, W) sequence such as the default must be reduced first.
      if isinstance(heatmap_size, (tuple, list)):
          if len(set(heatmap_size)) != 1:
              raise ValueError(
                  f"heatmap_size must describe a square heatmap, got {heatmap_size}"
              )
          side = heatmap_size[0]
      else:
          side = heatmap_size

      x = keypoints[..., 0].unsqueeze(1)
      y = keypoints[..., 1].unsqueeze(1)
      gs = generate_gaussian_heatmaps(x, y, num_points=num_points, heatmap_size=side, sigma=2.0)

      all_roi_heatmap = []
      for roi, heatmap in zip(rois, gs):
          heatmap = heatmap.unsqueeze(0)
          x1, y1, x2, y2 = map(int, roi)
          roi_heatmap = torch.zeros_like(heatmap)
          roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
          all_roi_heatmap.append(roi_heatmap)

      all_roi_heatmap = torch.cat(all_roi_heatmap)
      print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
      return all_roi_heatmap
  def line_points_to_heatmap(keypoints, rois, heatmap_size):
      # type: (Tensor, Tensor, int) -> Tensor
      """Render two-point Gaussian heatmaps and keep only the part inside each ROI.

      The x/y coordinates are read from the last axis of ``keypoints`` and
      passed to ``generate_gaussian_heatmaps`` with ``num_points=2`` and
      ``sigma=1.0``. Each resulting heatmap is then zeroed outside the
      inclusive window ``[y1, y2] x [x1, x2]`` of the matching ROI.

      Args:
          keypoints: Tensor whose last axis starts with (x, y).
          rois: Iterable of boxes ``x1, y1, x2, y2``, paired with the heatmaps.
          heatmap_size: Side length forwarded to ``generate_gaussian_heatmaps``.

      Returns:
          The clipped heatmaps concatenated along dim 0, or ``None`` when no
          (ROI, heatmap) pair was produced.
      """
      print(f'rois:{rois.shape}')
      print(f'heatmap_size:{heatmap_size}')
      print(f'keypoints.shape:{keypoints.shape}')

      xs = keypoints[..., 0]
      ys = keypoints[..., 1]
      gaussians = generate_gaussian_heatmaps(xs, ys, heatmap_size, num_points=2, sigma=1.0)

      clipped = []
      for box, gaussian in zip(rois, gaussians):
          gaussian = gaussian.unsqueeze(0)
          left, top, right, bottom = (int(v) for v in box)
          window = (..., slice(top, bottom + 1), slice(left, right + 1))
          inside = torch.zeros_like(gaussian)
          inside[window] = gaussian[window]
          clipped.append(inside)

      if len(clipped) == 0:
          return None
      all_roi_heatmap = torch.cat(clipped)
      print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
      return all_roi_heatmap
  # Point-to-heatmap conversion adapted from the original structure; meant for
  # the variant of the model that uses roi_pool.
  def line_points_to_heatmap_(keypoints, rois, heatmap_size):
      # type: (Tensor, Tensor, int) -> Tensor
      """Map keypoints into ROI-local heatmap coordinates and render Gaussians.

      Each coordinate is shifted by the ROI origin, scaled by
      ``heatmap_size / roi_extent``, floored and converted to long. Points
      lying exactly on the right/bottom ROI edge are assigned to the last
      heatmap cell. The resulting coordinates are rendered with
      ``generate_gaussian_heatmaps`` using ``sigma=1.0``.

      Args:
          keypoints: Tensor whose last axis starts with (x, y); the leading
              axis is paired with the rows of ``rois``.
          rois: Tensor with columns ``x1, y1, x2, y2``.
          heatmap_size: Side length of the square heatmap.

      Returns:
          The tensor produced by ``generate_gaussian_heatmaps``.
      """
      print(f'rois:{rois.shape}')
      print(f'heatmap_size:{heatmap_size}')

      offset_x = rois[:, 0][:, None]
      offset_y = rois[:, 1][:, None]
      scale_x = (heatmap_size / (rois[:, 2] - rois[:, 0]))[:, None]
      scale_y = (heatmap_size / (rois[:, 3] - rois[:, 1]))[:, None]

      print(f'keypoints.shape:{keypoints.shape}')
      x = keypoints[..., 0]
      y = keypoints[..., 1]

      # Remember which points sit exactly on the far ROI edge: after scaling
      # they would land on index heatmap_size, one past the last cell.
      x_boundary_inds = x == rois[:, 2][:, None]
      y_boundary_inds = y == rois[:, 3][:, None]

      x = ((x - offset_x) * scale_x).floor().long()
      y = ((y - offset_y) * scale_y).floor().long()
      x[x_boundary_inds] = heatmap_size - 1
      y[y_boundary_inds] = heatmap_size - 1

      # sigma must be passed by keyword: the fourth positional parameter of
      # generate_gaussian_heatmaps is num_points, so a positional 1.0 ended up
      # in range(num_points) and raised TypeError.
      gs_heatmap = generate_gaussian_heatmaps(x, y, heatmap_size, sigma=1.0)
      return gs_heatmap
  def generate_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
      """Render, for every row of points, the sum of their Gaussian bumps.

      ``squeeze(1)`` is applied to both coordinate tensors first, so a
      singleton axis 1 is dropped before points are read as ``xs[i, j]``.
      Every centre is clamped to ``[0, heatmap_size - 1]``.

      Args:
          xs (Tensor): x coordinates of all points.
          ys (Tensor): y coordinates, same shape as ``xs``.
          heatmap_size (int): Side length of the square heatmap (H = W).
          num_points (int): How many points per row are rendered.
          sigma (float): Standard deviation of the Gaussian.
          device (str): Device on which the grid and the output are created.

      Returns:
          Tensor: Shape ``(N, heatmap_size, heatmap_size)`` where N is the
          leading size of ``xs`` after the squeeze.
      """
      assert xs.shape == ys.shape, "x and y must have the same shape"
      print(f'xs:{xs.shape}')
      xs = xs.squeeze(1)
      ys = ys.squeeze(1)
      print(f'xs1:{xs.shape}')
      N = xs.shape[0]
      print(f'N:{N},num_points:{num_points}')

      # Pixel coordinate grids; 'ij' indexing makes the first output the rows.
      axis = torch.arange(heatmap_size, device=device)
      grid_y, grid_x = torch.meshgrid(axis, axis, indexing='ij')

      upper = heatmap_size - 1
      two_var = 2 * sigma ** 2
      combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
      for row in range(N):
          total = torch.zeros((heatmap_size, heatmap_size), device=device)
          for col in range(num_points):
              center_x = xs[row, col].clamp(0, upper).item()
              center_y = ys[row, col].clamp(0, upper).item()
              sq_dist = (grid_x - center_x) ** 2 + (grid_y - center_y) ** 2
              total += torch.exp(-sq_dist / two_var)
          combined_heatmap[row] = total
      return combined_heatmap
  def generate_mask_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
      """Build one heatmap per row of points by summing a Gaussian per point.

      Both coordinate tensors go through ``squeeze(1)`` before points are
      read as ``xs[i, j]``; centres are clamped to ``[0, heatmap_size - 1]``.

      Args:
          xs (Tensor): x coordinates of all points.
          ys (Tensor): y coordinates, same shape as ``xs``.
          heatmap_size (int): Side length of the square heatmap (H = W).
          num_points (int): How many points per row are rendered.
          sigma (float): Standard deviation of the Gaussian.
          device (str): Device on which the grid and the output are created.

      Returns:
          Tensor: Shape ``(N, heatmap_size, heatmap_size)`` where N is the
          leading size of ``xs`` after the squeeze.
      """
      assert xs.shape == ys.shape, "x and y must have the same shape"
      print(f'xs:{xs.shape}')
      xs, ys = xs.squeeze(1), ys.squeeze(1)
      print(f'xs1:{xs.shape}')
      N = xs.shape[0]
      print(f'N:{N},num_points:{num_points}')

      # Coordinate grids: grid_y varies along rows, grid_x along columns.
      grid_y, grid_x = torch.meshgrid(
          torch.arange(heatmap_size, device=device),
          torch.arange(heatmap_size, device=device),
          indexing='ij',
      )

      last_cell = heatmap_size - 1
      combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
      for i in range(N):
          accumulated = torch.zeros((heatmap_size, heatmap_size), device=device)
          for j in range(num_points):
              mu_x = xs[i, j].clamp(0, last_cell).item()
              mu_y = ys[i, j].clamp(0, last_cell).item()
              squared = (grid_x - mu_x) ** 2 + (grid_y - mu_y) ** 2
              bump = torch.exp(-squared / (2 * sigma ** 2))
              accumulated += bump
          combined_heatmap[i] = accumulated
      return combined_heatmap
  def non_maximum_suppression(a):
      """Zero every value that is not the maximum of its 3x3 neighbourhood.

      A stride-1, padding-1 max pool gives the neighbourhood maximum for each
      position; positions equal to that maximum keep their value, all others
      become zero.
      """
      neighbourhood_max = F.max_pool2d(a, 3, stride=1, padding=1)
      keep = torch.eq(a, neighbourhood_max).float().clamp(min=0.0)
      return a * keep
  def heatmaps_to_points(maps, rois,num_points=2):
      """Pick the strongest NMS peak(s) of channel 0 for every ROI.

      For each ROI index the channel-0 map is suppressed with
      ``non_maximum_suppression``, flattened, and its top ``num_points``
      entries are converted back to (x, y) using the map width.

      NOTE(review): the outputs hold one (x, y) pair and one score per ROI,
      while ``topk`` returns ``num_points`` indices; with ``num_points`` > 1
      the assignments below presumably fail - verify against the caller.

      Args:
          maps: Heatmaps indexed as ``maps[:, 0]`` then by ROI index.
          rois: Only its length is used.
          num_points: ``k`` passed to ``torch.topk``.

      Returns:
          Tuple ``(point_preds, point_end_scores)`` of shapes
          ``(len(rois), 2)`` and ``(len(rois), 1)``.
      """
      roi_count = len(rois)
      point_preds = torch.zeros((roi_count, 2), dtype=torch.float32, device=maps.device)
      point_end_scores = torch.zeros((roi_count, 1), dtype=torch.float32, device=maps.device)
      print(f'heatmaps_to_lines:{maps.shape}')
      point_maps = maps[:, 0]
      print(f'point_map:{point_maps.shape}')

      for i in range(roi_count):
          point_roi_map = point_maps[i].unsqueeze(0)
          print(f'point_roi_map:{point_roi_map.shape}')
          w = point_roi_map.shape[2]
          suppressed = non_maximum_suppression(point_roi_map)
          _, point_index = torch.topk(suppressed.reshape(1, -1), k=num_points)
          print(f'point index:{point_index}')
          # Flat index -> column and row on a map of width w.
          point_x = point_index % w
          point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
          point_preds[i, 0] = point_x
          point_preds[i, 1] = point_y
          leading = torch.arange(1, device=point_roi_map.device)
          point_end_scores[i, :] = point_roi_map[leading, point_y, point_x]
      return point_preds, point_end_scores
  # Split the map into four quadrants.
  def find_max_heat_point_in_each_part(feature_map, box):
      """Find the hottest local peak of channel 0 in each of four quadrants.

      The quadrants meet at the box centre shifted by +3 in x and -3 in y
      (clamped to the map). Within each quadrant the heat outside is zeroed,
      3x3 local maxima with positive value are kept, and the largest one wins.

      Args:
          feature_map (torch.Tensor): Feature map of shape [C, H, W]; only
              channel 0 is used.
          box: Indexable as ``[x_min, y_min, x_max, y_max]``.

      Returns:
          list: Four entries in the order top-left, top-right, bottom-left,
          bottom-right. Each entry is an ``(x, y)`` tuple of ints, or ``None``
          when the quadrant has no positive peak.
      """
      device = feature_map.device
      C, H, W = feature_map.shape

      # Box centre, then the shifted split point kept inside the map.
      center_x = (box[0] + box[2]) // 2
      center_y = (box[1] + box[3]) // 2
      split_x = min(max(center_x + 3, 0), W - 1)
      split_y = min(max(center_y - 3, 0), H - 1)

      rows, cols = torch.meshgrid(
          torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij'
      )
      quadrants = [
          (rows < split_y) & (cols < split_x),    # top-left
          (rows < split_y) & (cols >= split_x),   # top-right
          (rows >= split_y) & (cols < split_x),   # bottom-left
          (rows >= split_y) & (cols >= split_x),  # bottom-right
      ]

      # The first channel is treated as the heatmap.
      heat = feature_map[0]

      def _strongest_peak(region):
          # Work on a copy so the input map is left untouched.
          masked = heat.clone()
          masked[~region] = 0
          pooled = F.max_pool2d(masked.unsqueeze(0), kernel_size=3, stride=1, padding=1).squeeze(0)
          is_peak = (masked == pooled) & (masked > 0)
          candidates = masked * is_peak.float()
          if candidates.max() <= 0:
              return None
          flat_index = torch.max(candidates.view(-1), dim=0)[1]
          y, x = divmod(flat_index.item(), W)
          return (x, y)

      return [_strongest_peak(region) for region in quadrants]
  def non_maximum_suppression_2d(heatmap, kernel_size=3):
      """Return a boolean mask of the positive local maxima of a 2-D heatmap.

      Args:
          heatmap: Tensor of shape [H, W].
          kernel_size: Size of the pooling window used to compare neighbours.

      Returns:
          Boolean tensor shaped like ``heatmap``; True where the value equals
          the window maximum and is greater than zero.
      """
      pad = (kernel_size - 1) // 2
      pooled = F.max_pool2d(heatmap.unsqueeze(0), kernel_size=kernel_size, stride=1, padding=pad)
      window_max = pooled.squeeze(0)
      return (heatmap == window_max) & (heatmap > 0)
  def find_max_heat_point_in_edge_centers(feature_map, box):
      """Find the hottest point of channel 0 in the four edge-centre cells of a box.

      The box width and height are cut into thirds around the box centre,
      giving a 3x3 grid whose outer cells extend to the map borders. Only the
      four edge-centre cells are searched; the corner cells are ignored.

      Args:
          feature_map: Tensor of shape [C, H, W]; only channel 0 is used.
          box: Indexable as ``[x_min, y_min, x_max, y_max]``.

      Returns:
          list: Four ``(point_x, point_y)`` tuples in the order top, right,
          bottom, left. Each coordinate is a tensor of shape (1, 1) as
          produced by ``torch.topk`` on the flattened map.
      """
      device = feature_map.device
      C, H, W = feature_map.shape

      # Box centre.
      cx = (box[0] + box[2]) / 2
      cy = (box[1] + box[3]) / 2

      # Grid boundaries: one sixth of the extent on each side of the centre.
      box_width = box[2] - box[0]
      box_height = box[3] - box[1]
      x_left = cx - box_width / 6
      x_right = cx + box_width / 6
      y_top = cy - box_height / 6
      y_bottom = cy + box_height / 6

      y_coords, x_coords = torch.meshgrid(
          torch.arange(H, device=device),
          torch.arange(W, device=device),
          indexing='ij'
      )

      in_middle_cols = (x_coords >= x_left) & (x_coords < x_right)
      in_middle_rows = (y_coords >= y_top) & (y_coords < y_bottom)

      # Only the edge-centre cells are searched, so the four corner masks that
      # used to be built here (and never read) are no longer computed.
      masks = [
          in_middle_cols & (y_coords < y_top),      # top middle
          (x_coords >= x_right) & in_middle_rows,   # right middle
          in_middle_cols & (y_coords >= y_bottom),  # bottom middle
          (x_coords < x_left) & in_middle_rows,     # left middle
      ]

      # The first channel is treated as the heatmap.
      heatmap = feature_map[0]

      results = []
      for mask in masks:
          # Zero everything outside the cell without touching the input.
          masked_heatmap = heatmap.masked_fill(~mask, 0)
          _, point_index = torch.topk(masked_heatmap.reshape(1, -1), k=1)
          point_x = point_index % W
          point_y = torch.div(point_index - point_x, W, rounding_mode="floor")
          results.append((point_x, point_y))
      return results  # [(x_top, y_top), (x_right, y_right), (x_bottom, y_bottom), (x_left, y_left)]
  def heatmaps_to_circle_points(maps, rois,num_points=2):
      """Locate four edge-centre peaks of channel 0 for every ROI.

      For each ROI the channel-0 map is suppressed with
      ``non_maximum_suppression`` and handed, together with the ROI, to
      ``find_max_heat_point_in_edge_centers``; the four returned points and
      the heat read at those points are stored.

      Args:
          maps: Heatmaps indexed as ``maps[:, 0]`` then by ROI index.
          rois: Tensor of boxes; its length sets the number of outputs.
          num_points: Not read by this function.

      Returns:
          Tuple ``(point_preds, point_end_scores)`` of shapes
          ``(len(rois), 4, 2)`` and ``(len(rois), 4, 1)``.
      """
      n_rois = len(rois)
      point_preds = torch.zeros((n_rois, 4, 2), dtype=torch.float32, device=maps.device)
      point_end_scores = torch.zeros((n_rois, 4, 1), dtype=torch.float32, device=maps.device)
      print(f'rois in heatmaps_to_circle_points:{type(rois), rois.shape}')
      print(f'heatmaps_to_lines:{maps.shape}')
      point_maps = maps[:, 0]
      print(f'point_map:{point_maps.shape}')

      for i in range(n_rois):
          point_roi_map = point_maps[i].unsqueeze(0)
          print(f'point_roi_map:{point_roi_map.shape}')
          suppressed = non_maximum_suppression(point_roi_map)
          result_points = find_max_heat_point_in_edge_centers(suppressed, rois[i])
          point_preds[i, :] = torch.tensor(result_points)
          point_x = [pt[0] for pt in result_points]
          point_y = [pt[1] for pt in result_points]
          leading = torch.arange(1, device=point_roi_map.device)
          # Scores come from the unsuppressed map at the selected positions.
          point_end_scores[i, :, 0] = point_roi_map[leading, point_y, point_x]
      return point_preds, point_end_scores
  def heatmaps_to_lines(maps, rois):
      """Pick the two strongest NMS peaks of channel 1 for every ROI as line endpoints.

      Args:
          maps: Heatmaps indexed as ``maps[:, 1]`` then by ROI index.
          rois: Only its length is used.

      Returns:
          Tuple ``(line_preds, line_end_scores)``. ``line_preds`` has shape
          ``(len(rois), 2, 3)`` with each endpoint stored as ``(x, y, 1)``;
          ``line_end_scores`` has shape ``(len(rois), 2)`` and holds the heat
          read from the unsuppressed map at each endpoint.
      """
      line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
      line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
      line_maps = maps[:, 1]

      for i in range(len(rois)):
          line_roi_map = line_maps[i].unsqueeze(0)
          print(f'line_roi_map:{line_roi_map.shape}')
          # line_roi_map is (1, H, W): the row stride of the flattened map is
          # the width at index 2, as heatmaps_to_points already uses. Index 1
          # is the height and only gave the same result for square maps.
          w = line_roi_map.shape[2]
          flatten_line_roi_map = non_maximum_suppression(line_roi_map).reshape(1, -1)
          _, line_index = torch.topk(flatten_line_roi_map, k=2)
          print(f'line index:{line_index}')

          line_x = line_index % w
          line_y = torch.div(line_index - line_x, w, rounding_mode="floor")
          line_preds[i, 0, :] = line_x
          line_preds[i, 1, :] = line_y
          line_preds[i, 2, :] = 1
          line_end_scores[i, :] = line_roi_map[torch.arange(1, device=line_roi_map.device), line_y, line_x]

      return line_preds.permute(0, 2, 1), line_end_scores
  # Helper for displaying a heatmap.
  def show_heatmap(heatmap, title="Heatmap"):
      """Display a heatmap tensor with matplotlib.

      Args:
          heatmap (Tensor): Heatmap to display.
          title (str): Title of the figure.
      """
      # detach() lets tensors that require grad be converted; cpu() is a
      # no-op for tensors already on the CPU, so one path covers both devices.
      heatmap = heatmap.detach().cpu().numpy()

      plt.imshow(heatmap, cmap='hot', interpolation='nearest')
      plt.colorbar()
      plt.title(title)
      plt.show()
  def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
      # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
      """Cross-entropy between channel 1 of the logits and Gaussian endpoint heatmaps.

      For every image with at least one proposal and one ground-truth line,
      the matched ground-truth lines are rendered with
      ``line_points_to_heatmap`` at the logits' resolution. All rendered
      targets are concatenated and compared with ``line_logits[:, 1]``.

      Args:
          line_logits: Tensor of shape (N, K, H, W) with H == W.
          proposals: Per-image proposal boxes.
          gt_lines: Per-image ground-truth lines.
          line_matched_idxs: Per-image indices selecting the ground-truth line
              matched to each proposal.

      Returns:
          Scalar loss tensor; a zero that stays attached to ``line_logits``
          when no image contributes a target.

      Raises:
          ValueError: If the last two dimensions of ``line_logits`` differ.
      """
      N, K, H, W = line_logits.shape
      len_proposals = len(proposals)
      print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals},line_matched_idxs:{line_matched_idxs}')

      if H != W:
          raise ValueError(
              f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
          )
      discretization_size = H

      gs_heatmaps = []
      for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
          print(f'line_proposals_per_image:{proposals_per_image.shape}')
          print(f'gt_lines:{gt_lines}')
          if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
              kp = gt_kp_in_image[midx]
              gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
              gs_heatmaps.append(gs_heatmaps_per_img)

      # torch.cat rejects an empty list; return a zero that keeps the graph
      # connected so a backward pass over the summed losses still works.
      if not gs_heatmaps:
          return line_logits.sum() * 0

      gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
      print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')

      print(f'loss1 line_logits:{line_logits.shape}')
      line_logits = line_logits[:, 1, :, :]
      print(f'loss2 line_logits:{line_logits.shape}')

      line_loss = F.cross_entropy(line_logits, gs_heatmaps)
      return line_loss
  def compute_ins_loss(feature_logits, proposals, gt_, pos_matched_idxs):
      """Combined BCE + Dice loss between predicted logits and resized targets.

      For every image with at least one proposal and one ground-truth entry,
      the matched ground truth is resized with ``align_masks`` to the logits'
      resolution. The concatenated targets are compared with
      ``feature_logits.squeeze(1)`` through ``combined_loss``.

      Args:
          feature_logits: Tensor of shape (N, K, H, W) with H == W.
          proposals: Per-image proposal boxes.
          gt_: Per-image ground truth (presumably instance masks, as the debug
              output labels them - confirm with the caller).
          pos_matched_idxs: Per-image indices selecting the ground truth
              matched to each proposal.

      Returns:
          Loss tensor. When no image contributes a target the value is 100,
          returned as a tensor attached to ``feature_logits``.

      Raises:
          ValueError: If the last two dimensions of ``feature_logits`` differ.
      """
      print(f'compute_arc_loss:{feature_logits.shape}')
      N, K, H, W = feature_logits.shape
      len_proposals = len(proposals)

      empty_count = sum(1 for prop in proposals if prop.shape[0] == 0)
      non_empty_count = len_proposals - empty_count
      print(f"Empty proposals count: {empty_count}")
      print(f"Non-empty proposals count: {non_empty_count}")

      print(f'starte to compute_point_loss')
      print(f'compute_point_loss line_logits.shape:{feature_logits.shape},len_proposals:{len_proposals}')
      if H != W:
          raise ValueError(
              f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
          )
      discretization_size = H

      gs_heatmaps = []
      print(f'gt_masks:{gt_[0].shape}')
      for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
          print(f'proposals_per_image:{proposals_per_image.shape}')
          kp = gt_kp_in_image[midx]
          t_h, t_w = kp.shape[-2:]
          print(f't_h:{t_h}, t_w:{t_w}')
          print(f'gt_kp_in_image:{gt_kp_in_image.shape}')
          if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
              gs_heatmaps_per_img = align_masks(kp, proposals_per_image, discretization_size)
              gs_heatmaps.append(gs_heatmaps_per_img)

      if len(gs_heatmaps) > 0:
          gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
          print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.shape}')
          line_logits = feature_logits.squeeze(1)
          print(f'ins shape:{line_logits.shape}')
          line_loss = combined_loss(line_logits, gs_heatmaps)
      else:
          # Keep the historical fallback value of 100, but return it as a
          # tensor tied to the logits: a bare Python int has no .backward()
          # or .item() and breaks code that treats every loss as a tensor.
          line_loss = feature_logits.sum() * 0 + 100.0
          print("d")
      return line_loss
  def align_masks(keypoints, rois, heatmap_size):
      """Resize a stack of masks to a square ``heatmap_size`` grid.

      Args:
          keypoints: Tensor that receives a singleton dimension at position 1,
              is cast to float and is then bilinearly interpolated.
          rois: Only its shape is printed; it does not influence the result.
          heatmap_size: Side length of the square output grid.

      Returns:
          The interpolated float tensor with the singleton dimension removed
          again, i.e. ``[B, heatmap_size, heatmap_size]``.
      """
      print(f'rois:{rois.shape}')
      print(f'heatmap_size:{heatmap_size}')
      print(f'keypoints.shape:{keypoints.shape}')

      target_hw = (heatmap_size, heatmap_size)
      # F.interpolate needs an explicit channel axis for 2-D resizing.
      with_channel = keypoints.unsqueeze(1).float()
      resized = F.interpolate(with_channel, size=target_hw, mode='bilinear', align_corners=False)
      resized_mask = resized.squeeze(1)

      print(f'resized_mask:{resized_mask.shape}')
      return resized_mask
  def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
      # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
      """Cross-entropy between point logits and per-proposal target heatmaps.

      The matched ground-truth entry of every proposal is converted to a
      target with ``single_point_to_heatmap``; the targets of all images are
      concatenated and compared with channel 0 of ``line_logits`` through
      ``F.cross_entropy``.

      Args:
          line_logits: Tensor unpacked as ``(N, K, H, W)``; ``H`` must equal
              ``W``.
          proposals: Per-image proposal tensors.
          gt_points: Per-image ground-truth tensors, indexed with the matching
              entry of ``point_matched_idxs``.
          point_matched_idxs: Per-image index tensors.

      Returns:
          The cross-entropy loss tensor.

      Raises:
          ValueError: If ``H != W``.
      """
      _, _, H, W = line_logits.shape
      len_proposals = len(proposals)

      non_empty_count = sum(1 for prop in proposals if prop.shape[0] != 0)
      empty_count = len_proposals - non_empty_count
      print(f"Empty proposals count: {empty_count}")
      print(f"Non-empty proposals count: {non_empty_count}")
      print(f'starte to compute_point_loss')
      print(f'compute_point_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')

      if H != W:
          raise ValueError(
              f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
          )
      discretization_size = H

      per_image_targets = []
      for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_points, point_matched_idxs):
          print(f'proposals_per_image:{proposals_per_image.shape}')
          matched_gt = gt_kp_in_image[midx]
          per_image_targets.append(
              single_point_to_heatmap(matched_gt, proposals_per_image, discretization_size)
          )

      gs_heatmaps = torch.cat(per_image_targets, dim=0)
      print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')

      # Only the first channel carries the point prediction.
      line_logits = line_logits[:, 0]
      print(f'single_point_logits:{line_logits.shape}')
      return F.cross_entropy(line_logits, gs_heatmaps)
  def compute_circle_loss(circle_logits, proposals, gt_circles, circle_matched_idxs):
      # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
      """Cross-entropy between circle logits and per-proposal target heatmaps.

      The matched ground-truth entry of every proposal is converted to a
      target with ``points_to_heatmap`` (``num_points=4``); the targets of all
      images are concatenated and compared with channel 0 of
      ``circle_logits`` through ``F.cross_entropy``.

      Args:
          circle_logits: Tensor unpacked as ``(N, K, H, W)``; ``H`` must equal
              ``W``.
          proposals: Per-image proposal tensors.
          gt_circles: Per-image ground-truth tensors, indexed with the
              matching entry of ``circle_matched_idxs``.
          circle_matched_idxs: Per-image index tensors.

      Returns:
          The cross-entropy loss tensor.

      Raises:
          ValueError: If ``H != W``.
      """
      _, _, H, W = circle_logits.shape
      len_proposals = len(proposals)

      empty_count = sum(1 for prop in proposals if prop.shape[0] == 0)
      non_empty_count = len_proposals - empty_count
      print(f"Empty proposals count: {empty_count}")
      print(f"Non-empty proposals count: {non_empty_count}")
      print('starte to compute_circle_loss')
      print(f'compute_circle_loss circle_logits.shape:{circle_logits.shape},len_proposals:{len_proposals}')

      if H != W:
          # The message names the argument of *this* function; it previously
          # referred to ``line_logits``.
          raise ValueError(
              f"circle_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
          )
      discretization_size = H

      gs_heatmaps = []
      for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_circles, circle_matched_idxs):
          print(f'proposals_per_image:{proposals_per_image.shape}')
          kp = gt_kp_in_image[midx]
          gs_heatmaps.append(
              points_to_heatmap(kp, proposals_per_image, num_points=4, heatmap_size=discretization_size)
          )

      gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
      print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{circle_logits.squeeze(1).shape}')

      # Only the first channel carries the circle prediction.
      circle_logits = circle_logits[:, 0]
      print(f'circle_logits:{circle_logits.shape}')
      return F.cross_entropy(circle_logits, gs_heatmaps)
  def lines_to_boxes(lines, img_size=511):
      """Convert line segments to axis-aligned boxes padded by one pixel.

      Args:
          lines: Tensor of shape (N, 2, 2): N segments, each with two end
              points (x, y). Only the first two points and the first two
              coordinates of each point are read.
          img_size: Upper bound used when clamping the box coordinates; the
              lower bound is 0.

      Returns:
          Tensor of shape (N, 4) holding ``[x_min, y_min, x_max, y_max]``.
      """
      # (N, 2, C): both end points of every segment side by side.
      endpoints = torch.stack((lines[:, 0], lines[:, 1]), dim=1)
      xs = endpoints[:, :, 0]  # (N, 2)
      ys = endpoints[:, :, 1]  # (N, 2)

      # Grow the tight box by one pixel on each side.
      lower = [coords.min(dim=1).values - 1 for coords in (xs, ys)]
      upper = [coords.max(dim=1).values + 1 for coords in (xs, ys)]

      # Keep every edge inside the image.
      edges = [torch.clamp(edge, min=0, max=img_size) for edge in lower + upper]
      return torch.stack(edges, dim=1)
  def box_iou_pairwise(box1, box2):
      """IoU of boxes paired by index (``box1[i]`` with ``box2[i]``).

      Args:
          box1: Tensor of shape (N, 4) as ``[x_min, y_min, x_max, y_max]``.
          box2: Tensor of shape (M, 4) in the same layout.

      Returns:
          Tensor of shape ``(min(N, M),)``; surplus boxes of the longer input
          are ignored.
      """
      count = min(len(box1), len(box2))
      first, second = box1[:count], box2[:count]

      def _area(boxes):
          # width * height of every box
          return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

      top_left = torch.max(first[:, :2], second[:, :2])
      bottom_right = torch.min(first[:, 2:], second[:, 2:])
      # Disjoint boxes give a negative extent; clamp it to an empty overlap.
      extent = (bottom_right - top_left).clamp(min=0)
      inter_area = extent[:, 0] * extent[:, 1]

      union_area = _area(first) + _area(second) - inter_area
      # The epsilon guards against division by zero for degenerate boxes.
      return inter_area / (union_area + 1e-6)
  def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta=1.0, gamma=1.0):
      """Combine box-IoU, length and direction losses between decoded and GT lines.

      Args:
          x: [N, 1, H, W] heatmaps, split per image by the number of boxes.
          boxes: Per-image box tensors of shape [n_i, 4].
          gt_lines: Per-image ground-truth segments, [n, 2, 3] (with
              visibility), indexed with the matching entry of ``matched_idx``.
          matched_idx: Per-image matching indices.
          img_size: Image size forwarded to ``lines_to_boxes``.
          alpha: Weight of the IoU term.
          beta: Weight of the length term.
          gamma: Weight of the direction (angle) term.

      Returns:
          The mean of the weighted, normalised per-line losses over all
          images, or ``None`` when no image has both predictions and targets.
      """
      per_image_counts = [box.size(0) for box in boxes]
      heatmap_chunks = x.split(per_image_counts, dim=0)

      losses = []
      for chunk, image_boxes, image_gt, image_match in zip(heatmap_chunks, boxes, gt_lines, matched_idx):
          pred_lines, _ = heatmaps_to_lines(chunk, image_boxes)
          gt_line_points = image_gt[image_match]
          if not (len(pred_lines) != 0 and len(gt_line_points) != 0):
              continue

          # IoU term on the padded bounding boxes of the segments.
          iou_loss = 1.0 - box_iou_pairwise(
              lines_to_boxes(pred_lines, img_size),
              lines_to_boxes(gt_line_points, img_size),
          )
          # Length term.
          length_diff = F.l1_loss(
              line_length(pred_lines), line_length(gt_line_points), reduction='none'
          )
          # Direction term.
          ang_loss = angle_loss_cosine(
              line_direction(pred_lines), line_direction(gt_line_points)
          )

          # Each term is normalised before the weights are applied.
          weighted_terms = (
              alpha * normalize_tensor(iou_loss)
              + beta * normalize_tensor(length_diff)
              + gamma * normalize_tensor(ang_loss)
          )
          losses.append(weighted_terms)

      if len(losses) == 0:
          return None
      return torch.cat(losses).mean()
  def point_inference(x, point_boxes):
      # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
      """Decode one point per box from the heatmaps of every image.

      ``x`` is split along dim 0 into per-image chunks sized by the number of
      boxes of each image, and every chunk is decoded with
      ``heatmaps_to_points(..., num_points=1)``.

      Returns:
          Two lists with one entry per image: the decoded points with an
          extra dimension inserted at position 1, and the scores.
      """
      counts = [box.size(0) for box in point_boxes]
      decoded = [
          heatmaps_to_points(chunk, image_boxes, num_points=1)
          for chunk, image_boxes in zip(x.split(counts, dim=0), point_boxes)
      ]
      points_probs = [prob.unsqueeze(1) for prob, _ in decoded]
      points_scores = [score for _, score in decoded]
      return points_probs, points_scores
  def circle_inference(x, point_boxes):
      # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
      """Decode four circle points per box from the heatmaps of every image.

      ``x`` is split along dim 0 into per-image chunks sized by the number of
      boxes of each image, and every chunk is decoded with
      ``heatmaps_to_circle_points(..., num_points=4)``.

      Returns:
          Two lists with one entry per image: the decoded points with an
          extra dimension inserted at position 1, and the scores.
      """
      counts = [box.size(0) for box in point_boxes]
      decoded = [
          heatmaps_to_circle_points(chunk, image_boxes, num_points=4)
          for chunk, image_boxes in zip(x.split(counts, dim=0), point_boxes)
      ]
      points_probs = [prob.unsqueeze(1) for prob, _ in decoded]
      points_scores = [score for _, score in decoded]
      return points_probs, points_scores
  def line_inference(x, line_boxes):
      # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
      """Decode line segments from the heatmaps of every image.

      ``x`` is split along dim 0 into per-image chunks sized by the number of
      boxes of each image (the boxes come from the detector), and every chunk
      is decoded with ``heatmaps_to_lines``.

      Returns:
          Two lists with one entry per image: the decoded lines and their
          scores, exactly as returned by ``heatmaps_to_lines``.
      """
      counts = [box.size(0) for box in line_boxes]
      decoded = [
          heatmaps_to_lines(chunk, image_boxes)
          for chunk, image_boxes in zip(x.split(counts, dim=0), line_boxes)
      ]
      lines_probs = [prob for prob, _ in decoded]
      lines_scores = [score for _, score in decoded]
      return lines_probs, lines_scores
  def ins_inference(x, arc_boxes, th):
      # type: (Tensor, List[Tensor], float) -> Tuple[List[Tensor], List[Tensor], List[List[Tensor]]]
      """Decode instance masks and collect above-threshold pixel coordinates.

      Two independent results are produced:

      1. ``x`` is split along dim 0 into per-image chunks sized by the number
         of boxes of each image and each chunk is decoded with
         ``heatmaps_to_arc``.
      2. For every box of the *first* entry of ``arc_boxes`` and every row of
         ``x``, the pixels inside the box whose value exceeds ``th`` are
         gathered as global ``(x, y)`` coordinates.
         NOTE(review): boxes of the remaining entries of ``arc_boxes`` are not
         used in this step -- confirm that this is intended.

      Args:
          x: Tensor indexed as ``[row, channel, y, x]``.
          arc_boxes: Per-image box tensors with rows ``(x1, y1, x2, y2)``.
          th: Threshold compared element-wise against ``x``.

      Returns:
          ``(points_probs, points_scores, results)``: the per-image masks
          (with a dimension inserted at position 1) and scores from
          ``heatmaps_to_arc``, and ``results[row][box]``, an ``(n, 2)`` tensor
          of coordinates (empty when nothing exceeds ``th``).
      """
      points_probs = []
      points_scores = []
      print(f'arc_boxes:{len(arc_boxes)}')
      boxes_per_image = [box.size(0) for box in arc_boxes]
      print(f'arc boxes_per_image:{boxes_per_image}')

      for chunk, image_boxes in zip(x.split(boxes_per_image, dim=0), arc_boxes):
          point_prob, point_scores = heatmaps_to_arc(chunk, image_boxes)
          points_probs.append(point_prob.unsqueeze(1))
          points_scores.append(point_scores)

      # torch.cat rejects an empty list; the concatenation only feeds the
      # debug print, so skip it when there is nothing to report.
      if points_probs:
          points_probs_tensor = torch.cat(points_probs)
          print(f'points_probs shape:{points_probs_tensor.shape}')

      feature_logits = x
      batch_size = feature_logits.shape[0]
      map_h, map_w = feature_logits.shape[2], feature_logits.shape[3]

      proposals_list = arc_boxes[0] if len(arc_boxes) > 0 else []
      num_proposals = len(proposals_list)
      results = [[torch.empty(0, 2) for _ in range(num_proposals)] for _ in range(batch_size)]

      for proposal_idx, proposal in enumerate(proposals_list):
          left, top, right, bottom = map(int, proposal.tolist())
          # Clip the box to the map; slice ends are exclusive, so the size
          # itself is a valid upper bound.
          left = max(0, left)
          top = max(0, top)
          right = min(map_w, right)
          bottom = min(map_h, bottom)

          for batch_idx in range(batch_size):
              region = feature_logits[batch_idx, :, top:bottom, left:right]
              hits = torch.nonzero(region > th)  # rows of (channel, y, x)
              if hits.numel() > 0:
                  # Reorder to (x, y) and shift from box-local to global
                  # coordinates. Advanced indexing copies, so the in-place
                  # additions do not touch ``hits``.
                  local_coords = hits[:, [2, 1]]
                  local_coords[:, 0] += left
                  local_coords[:, 1] += top
                  results[batch_idx][proposal_idx] = local_coords

      print(f're:{results}')
      return points_probs, points_scores, results
  929. import torch.nn.functional as F
  def heatmaps_to_arc(maps, rois, threshold=0.5, output_size=(128, 128)):
      """Turn per-ROI heatmaps into masks placed on the full-map canvas.

      For every map ``i`` the first channel is cropped to ``rois[i]``, resized
      to ``output_size``, passed through a sigmoid and thresholded; the
      resulting binary mask is resized back to the crop size and written into
      the output at the ROI location.

      Args:
          maps: Tensor of shape [N, C, H, W]; only channel 0 is used.
          rois: N boxes ``(x1, y1, x2, y2)``; values are truncated to
              integers and clipped to the map. Boxes that are empty after
              clipping are skipped and leave a zero mask and a zero score.
          threshold: Cut-off applied to the sigmoid of the resized crop.
          output_size: Intermediate (height, width) used for thresholding.

      Returns:
          masks: float32 tensor [N, 1, H, W], zero outside the ROI. The
              binary mask is resized back bilinearly, so values along mask
              borders can be fractional.
          scores: float32 tensor [N, 1], the sum of the mask values written
              for each ROI.
      """
      N, _, H, W = maps.shape
      masks = torch.zeros((N, 1, H, W), dtype=torch.float32, device=maps.device)
      scores = torch.zeros((N, 1), dtype=torch.float32, device=maps.device)
      point_maps = maps[:, 0]  # take the first channel: [N, H, W]
      print(f"==> heatmaps_to_arc: maps.shape = {maps.shape}, rois.shape = {rois.shape}")

      for i in range(N):
          x1, y1, x2, y2 = rois[i].long().tolist()
          # Slices are half-open: the start has to be a valid index, while
          # the end may equal the map size. Clamping the end to size - 1 (as
          # before) made the last row/column unreachable.
          x1 = min(max(x1, 0), W - 1)
          y1 = min(max(y1, 0), H - 1)
          x2 = min(max(x2, 0), W)
          y2 = min(max(y2, 0), H)
          print(f"[{i}] roi: ({x1}, {y1}, {x2}, {y2})")

          if x2 <= x1 or y2 <= y1:
              print(f" Skipped invalid ROI at index {i}")
              continue

          # After the check above the crop always has at least one pixel.
          roi_map = point_maps[i, y1:y2, x1:x2]  # [h, w]
          print(f" roi_map.shape: {roi_map.shape}")

          # Resize to a uniform size before thresholding.
          roi_map_resized = F.interpolate(
              roi_map.unsqueeze(0).unsqueeze(0),
              size=output_size,
              mode='bilinear',
              align_corners=False,
          )  # [1, 1, out_h, out_w]
          print(f" roi_map_resized.shape: {roi_map_resized.shape}")

          bin_mask = (torch.sigmoid(roi_map_resized) >= threshold).float()
          print(f" bin_mask.sum(): {bin_mask.sum().item()}")

          # Resize back to the original crop size.
          h = y2 - y1
          w = x2 - x1
          bin_mask_original_size = F.interpolate(
              bin_mask,  # [1, 1, out_h, out_w]
              size=(h, w),
              mode='bilinear',
              align_corners=False,
          )[0, 0]  # [h, w]

          masks[i, 0, y1:y2, x1:x2] = bin_mask_original_size
          scores[i] = bin_mask_original_size.sum()
          print(f" bin_mask_original_size.shape: {bin_mask_original_size.shape}, sum: {scores[i].item()}")

      print(f"==> Done. Total valid masks: {(scores > 0).sum().item()} / {N}")
      return masks, scores