# loi_heads.py
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torchvision
# from scipy.optimize import linear_sum_assignment
from torch import nn, Tensor

from libs.vision_libs.ops import boxes as box_ops, roi_align
import libs.vision_libs.models.detection._utils as det_utils
from collections import OrderedDict
# NOTE: compute_arc_equation_loss is called in RoIHeads.forward below and was
# missing from this import list; it is assumed to live in head_losses as well.
from models.line_detect.heads.head_losses import point_inference, compute_point_loss, line_iou_loss, \
    lines_point_pair_loss, features_align, line_inference, compute_ins_loss, ins_inference, compute_circle_loss, \
    circle_inference, arc_inference1, compute_arc_equation_loss
from utils.data_process.show_prams import print_params


def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Computes the loss for Faster R-CNN.

    Args:
        class_logits (Tensor)
        box_regression (Tensor)
        labels (list[BoxList])
        regression_targets (Tensor)

    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
    """
    # print(f'compute fastrcnn_loss:{labels}')
    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)

    classification_loss = F.cross_entropy(class_logits, labels)

    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
    labels_pos = labels[sampled_pos_inds_subset]
    N, num_classes = class_logits.shape
    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

    box_loss = F.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        reduction="sum",
    )
    box_loss = box_loss / labels.numel()

    return classification_loss, box_loss
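

# Hedged sketch, not part of the original file: a minimal smoke test for
# fastrcnn_loss with fabricated shapes (one image, 2 proposals, 3 classes).
# The tensors below are illustrative assumptions, not pipeline outputs.
def _demo_fastrcnn_loss():
    class_logits = torch.randn(2, 3)            # [num_proposals, num_classes]
    box_regression = torch.randn(2, 3 * 4)      # [num_proposals, num_classes * 4]
    labels = [torch.tensor([1, 0])]             # one positive, one background proposal
    regression_targets = [torch.randn(2, 4)]    # per-image box regression targets
    # returns (classification_loss, box_loss), both scalar tensors
    return fastrcnn_loss(class_logits, box_regression, labels, regression_targets)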


def maskrcnn_inference(x, labels):
    # type: (Tensor, List[Tensor]) -> List[Tensor]
    """
    From the results of the CNN, post process the masks
    by taking the ins corresponding to the class with max
    probability (which are of fixed size and directly output
    by the CNN) and return the masks in the ins field of the BoxList.

    Args:
        x (Tensor): the ins logits
        labels (list[BoxList]): bounding boxes that are used as
            reference, one for each image

    Returns:
        results (list[BoxList]): one BoxList for each image, containing
            the extra field ins
    """
    mask_prob = x.sigmoid()

    # select masks corresponding to the predicted classes
    num_masks = x.shape[0]
    boxes_per_image = [label.shape[0] for label in labels]
    labels = torch.cat(labels)
    index = torch.arange(num_masks, device=labels.device)
    mask_prob = mask_prob[index, labels][:, None]
    mask_prob = mask_prob.split(boxes_per_image, dim=0)

    return mask_prob
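

# Hedged sketch, not part of the original file: maskrcnn_inference keeps, for
# each detection, only the mask channel of its predicted class, then splits the
# result back per image. Shapes below are illustrative assumptions.
def _demo_maskrcnn_inference():
    mask_logits = torch.randn(3, 5, 28, 28)             # 3 detections, 5 classes, 28x28 masks
    labels = [torch.tensor([1, 4]), torch.tensor([0])]  # two images: 2 + 1 detections
    probs = maskrcnn_inference(mask_logits, labels)
    return [p.shape for p in probs]                     # [(2, 1, 28, 28), (1, 1, 28, 28)]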


def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
    # type: (Tensor, Tensor, Tensor, int) -> Tensor
    """
    Given segmentation masks and the bounding boxes corresponding
    to the location of the masks in the image, this function
    crops and resizes the masks in the position defined by the
    boxes. This prepares the masks for them to be fed to the
    loss computation as the targets.
    """
    matched_idxs = matched_idxs.to(boxes)
    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
    gt_masks = gt_masks[:, None].to(rois)
    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]


def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    """
    Args:
        proposals (list[BoxList])
        mask_logits (Tensor)
        targets (list[BoxList])

    Return:
        mask_loss (Tensor): scalar tensor containing the loss
    """
    discretization_size = mask_logits.shape[-1]
    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
    mask_targets = [
        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
    ]

    labels = torch.cat(labels, dim=0)
    mask_targets = torch.cat(mask_targets, dim=0)

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    mask_loss = F.binary_cross_entropy_with_logits(
        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
    )
    return mask_loss
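

# Hedged sketch, not part of the original file: a minimal end-to-end call of
# maskrcnn_loss. One ground-truth mask, two proposals both matched to it; all
# values are fabricated assumptions purely to exercise the shapes.
def _demo_maskrcnn_loss():
    mask_logits = torch.randn(2, 3, 14, 14)                            # 2 proposals, 3 classes
    proposals = [torch.tensor([[0., 0., 10., 10.], [2., 2., 8., 8.]])]
    gt_masks = [torch.zeros(1, 32, 32)]                                # one 32x32 gt mask
    gt_labels = [torch.tensor([1])]
    matched_idxs = [torch.tensor([0, 0])]                              # both proposals -> gt 0
    return maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, matched_idxs)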


def keypoints_to_heatmap(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]
    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])

    offset_x = offset_x[:, None]
    offset_y = offset_y[:, None]
    scale_x = scale_x[:, None]
    scale_y = scale_y[:, None]

    x = keypoints[..., 0]
    y = keypoints[..., 1]

    x_boundary_inds = x == rois[:, 2][:, None]
    y_boundary_inds = y == rois[:, 3][:, None]

    x = (x - offset_x) * scale_x
    x = x.floor().long()
    y = (y - offset_y) * scale_y
    y = y.floor().long()
    x[x_boundary_inds] = heatmap_size - 1
    y[y_boundary_inds] = heatmap_size - 1

    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
    vis = keypoints[..., 2] > 0
    valid = (valid_loc & vis).long()

    lin_ind = y * heatmap_size + x
    heatmaps = lin_ind * valid

    return heatmaps, valid
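

# Hedged sketch, not part of the original file: keypoints_to_heatmap encodes
# each visible keypoint as a flattened index into a heatmap grid over its roi.
# The second keypoint below deliberately falls outside the roi (illustrative).
def _demo_keypoints_to_heatmap():
    rois = torch.tensor([[0., 0., 56., 56.]])
    keypoints = torch.tensor([[[28., 28., 1.], [100., 100., 1.]]])
    heatmaps, valid = keypoints_to_heatmap(keypoints, rois, 56)
    # heatmaps == tensor([[1596, 0]]) (28 * 56 + 28), valid == tensor([[1, 0]])
    return heatmaps, valid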


def _onnx_heatmaps_to_keypoints(
    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
):
    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)

    width_correction = widths_i / roi_map_width
    height_correction = heights_i / roi_map_height

    roi_map = F.interpolate(
        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
    )[:, 0]

    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

    x_int = pos % w
    y_int = (pos - x_int) // w

    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
        dtype=torch.float32
    )
    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
        dtype=torch.float32
    )

    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
    xy_preds_i = torch.stack(
        [
            xy_preds_i_0.to(dtype=torch.float32),
            xy_preds_i_1.to(dtype=torch.float32),
            xy_preds_i_2.to(dtype=torch.float32),
        ],
        0,
    )

    # TODO: simplify when indexing without rank will be supported by ONNX
    base = num_keypoints * num_keypoints + num_keypoints + 1
    ind = torch.arange(num_keypoints)
    ind = ind.to(dtype=torch.int64) * base
    end_scores_i = (
        roi_map.index_select(1, y_int.to(dtype=torch.int64))
        .index_select(2, x_int.to(dtype=torch.int64))
        .view(-1)
        .index_select(0, ind.to(dtype=torch.int64))
    )

    return xy_preds_i, end_scores_i


@torch.jit._script_if_tracing
def _onnx_heatmaps_to_keypoints_loop(
    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
):
    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)

    for i in range(int(rois.size(0))):
        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
        )
        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
        end_scores = torch.cat(
            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
        )
    return xy_preds, end_scores


def heatmaps_to_keypoints(maps, rois):
    """Extract predicted keypoint locations from heatmaps. Output has shape
    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
    for each keypoint.
    """
    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
    # consistency with keypoints_to_heatmap_labels by using the conversion from
    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
    # continuous coordinate.
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    widths = rois[:, 2] - rois[:, 0]
    heights = rois[:, 3] - rois[:, 1]
    widths = widths.clamp(min=1)
    heights = heights.clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_keypoints = maps.shape[1]

    if torchvision._is_tracing():
        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
            maps,
            rois,
            widths_ceil,
            heights_ceil,
            widths,
            heights,
            offset_x,
            offset_y,
            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
        )
        return xy_preds.permute(0, 2, 1), end_scores

    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
    for i in range(len(rois)):
        roi_map_width = int(widths_ceil[i].item())
        roi_map_height = int(heights_ceil[i].item())
        width_correction = widths[i] / roi_map_width
        height_correction = heights[i] / roi_map_height
        roi_map = F.interpolate(
            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
        )[:, 0]
        # roi_map_probs = scores_to_probs(roi_map.copy())
        w = roi_map.shape[2]
        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

        x_int = pos % w
        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
        # assert (roi_map_probs[k, y_int, x_int] ==
        #         roi_map_probs[k, :, :].max())
        x = (x_int.float() + 0.5) * width_correction
        y = (y_int.float() + 0.5) * height_correction
        xy_preds[i, 0, :] = x + offset_x[i]
        xy_preds[i, 1, :] = y + offset_y[i]
        xy_preds[i, 2, :] = 1
        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]

    return xy_preds.permute(0, 2, 1), end_scores
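

# Hedged sketch, not part of the original file: heatmaps_to_keypoints recovers
# continuous (d + 0.5) coordinates from heatmap argmax peaks. One roi, two
# keypoints with hand-placed peaks (illustrative assumptions).
def _demo_heatmaps_to_keypoints():
    maps = torch.zeros(1, 2, 56, 56)
    maps[0, 0, 10, 20] = 5.0          # keypoint 0 peak at (x=20, y=10)
    maps[0, 1, 30, 40] = 5.0          # keypoint 1 peak at (x=40, y=30)
    rois = torch.tensor([[0., 0., 56., 56.]])
    xy_preds, scores = heatmaps_to_keypoints(maps, rois)
    # xy_preds[0, 0, :2] is approximately (20.5, 10.5); scores[0, 0] ~= 5.0
    return xy_preds, scores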


def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    N, K, H, W = keypoint_logits.shape
    if H != W:
        raise ValueError(
            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
        )
    discretization_size = H
    heatmaps = []
    valid = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
        kp = gt_kp_in_image[midx]
        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
        heatmaps.append(heatmaps_per_image.view(-1))
        valid.append(valid_per_image.view(-1))

    keypoint_targets = torch.cat(heatmaps, dim=0)
    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
    valid = torch.where(valid)[0]

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if keypoint_targets.numel() == 0 or len(valid) == 0:
        return keypoint_logits.sum() * 0

    keypoint_logits = keypoint_logits.view(N * K, H * W)

    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
    return keypoint_loss


def keypointrcnn_inference(x, boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    kp_probs = []
    kp_scores = []

    boxes_per_image = [box.size(0) for box in boxes]
    x2 = x.split(boxes_per_image, dim=0)

    for xx, bb in zip(x2, boxes):
        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
        kp_probs.append(kp_prob)
        kp_scores.append(scores)

    return kp_probs, kp_scores
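

# Hedged sketch, not part of the original file: keypointrcnn_inference splits
# the batched heatmaps per image and decodes each chunk against its boxes.
# Shapes below are illustrative assumptions.
def _demo_keypointrcnn_inference():
    heatmaps = torch.zeros(2, 3, 56, 56)    # two detections, three keypoints each
    heatmaps[:, :, 5, 5] = 1.0
    boxes = [torch.tensor([[0., 0., 56., 56.], [0., 0., 28., 28.]])]  # one image
    kp_probs, kp_scores = keypointrcnn_inference(heatmaps, boxes)
    return kp_probs[0].shape, kp_scores[0].shape  # (2, 3, 3), (2, 3)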


def _onnx_expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half = w_half.to(dtype=torch.float32) * scale
    h_half = h_half.to(dtype=torch.float32) * scale

    boxes_exp0 = x_c - w_half
    boxes_exp1 = y_c - h_half
    boxes_exp2 = x_c + w_half
    boxes_exp3 = y_c + h_half
    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
    return boxes_exp


# the next two functions should be merged inside Masker
# but are kept here for the moment while we need them
# temporarily for paste_mask_in_image
def expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    if torchvision._is_tracing():
        return _onnx_expand_boxes(boxes, scale)
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half *= scale
    h_half *= scale

    boxes_exp = torch.zeros_like(boxes)
    boxes_exp[:, 0] = x_c - w_half
    boxes_exp[:, 2] = x_c + w_half
    boxes_exp[:, 1] = y_c - h_half
    boxes_exp[:, 3] = y_c + h_half
    return boxes_exp
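

# Hedged sketch, not part of the original file: expand_boxes scales a box
# about its center, so a 20x20 box at (10, 10) grown by 1.5x becomes 30x30.
def _demo_expand_boxes():
    boxes = torch.tensor([[10., 10., 30., 30.]])
    return expand_boxes(boxes, scale=1.5)  # tensor([[5., 5., 35., 35.]])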


@torch.jit.unused
def expand_masks_tracing_scale(M, padding):
    # type: (int, int) -> float
    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)


def expand_masks(mask, padding):
    # type: (Tensor, int) -> Tuple[Tensor, float]
    M = mask.shape[-1]
    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
        scale = expand_masks_tracing_scale(M, padding)
    else:
        scale = float(M + 2 * padding) / M
    padded_mask = F.pad(mask, (padding,) * 4)
    return padded_mask, scale


def paste_mask_in_image(mask, box, im_h, im_w):
    # type: (Tensor, Tensor, int, int) -> Tensor
    TO_REMOVE = 1
    w = int(box[2] - box[0] + TO_REMOVE)
    h = int(box[3] - box[1] + TO_REMOVE)
    w = max(w, 1)
    h = max(h, 1)

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, -1, -1))

    # Resize ins
    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
    x_0 = max(box[0], 0)
    x_1 = min(box[2] + 1, im_w)
    y_0 = max(box[1], 0)
    y_1 = min(box[3] + 1, im_h)

    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
    return im_mask


def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
    one = torch.ones(1, dtype=torch.int64)
    zero = torch.zeros(1, dtype=torch.int64)

    w = box[2] - box[0] + one
    h = box[3] - box[1] + one
    w = torch.max(torch.cat((w, one)))
    h = torch.max(torch.cat((h, one)))

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))

    # Resize ins
    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))

    unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]

    # TODO : replace below with a dynamic padding when support is added in ONNX
    # pad y
    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
    # pad x
    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
    return im_mask


@torch.jit._script_if_tracing
def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
    res_append = torch.zeros(0, im_h, im_w)
    for i in range(masks.size(0)):
        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
        mask_res = mask_res.unsqueeze(0)
        res_append = torch.cat((res_append, mask_res))
    return res_append


def paste_masks_in_image(masks, boxes, img_shape, padding=1):
    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
    masks, scale = expand_masks(masks, padding=padding)
    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
    im_h, im_w = img_shape

    if torchvision._is_tracing():
        return _onnx_paste_masks_in_image_loop(
            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
        )[:, None]
    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
    if len(res) > 0:
        ret = torch.stack(res, dim=0)[:, None]
    else:
        ret = masks.new_empty((0, 1, im_h, im_w))
    return ret
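

# Hedged sketch, not part of the original file: paste_masks_in_image resizes a
# fixed-size roi mask into its (padded, expanded) box within a full image
# canvas. Values below are illustrative assumptions.
def _demo_paste_masks_in_image():
    masks = torch.rand(1, 1, 28, 28)                 # one 28x28 roi mask
    boxes = torch.tensor([[4., 4., 20., 20.]])       # its box in image coordinates
    out = paste_masks_in_image(masks, boxes, img_shape=(64, 64))
    return out.shape                                 # torch.Size([1, 1, 64, 64])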


class RoIHeads(nn.Module):
    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        box_roi_pool,
        box_head,
        box_predictor,
        # Faster R-CNN training
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        bbox_reg_weights,
        # Faster R-CNN inference
        score_thresh,
        nms_thresh,
        detections_per_img,
        # Line
        line_roi_pool=None,
        line_head=None,
        line_predictor=None,
        # point parameters
        point_roi_pool=None,
        point_head=None,
        point_predictor=None,
        ins_head=None,
        ins_predictor=None,
        ins_roi_pool=None,
        # arc parameters
        arc_roi_pool=None,
        arc_head=None,
        arc_predictor=None,
        # Mask
        mask_roi_pool=None,
        mask_head=None,
        mask_predictor=None,
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
        detect_point=True,
        detect_line=False,
        detect_arc=False,
        detect_ins=False,
    ):
        super().__init__()

        self.box_similarity = box_ops.box_iou
        # assign ground-truth boxes for each proposal
        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

        if bbox_reg_weights is None:
            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)

        self.box_roi_pool = box_roi_pool
        self.box_head = box_head
        self.box_predictor = box_predictor

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

        self.line_roi_pool = line_roi_pool
        self.line_head = line_head
        self.line_predictor = line_predictor

        self.point_roi_pool = point_roi_pool
        self.point_head = point_head
        self.point_predictor = point_predictor

        self.arc_roi_pool = arc_roi_pool
        self.arc_head = arc_head
        self.arc_predictor = arc_predictor

        self.ins_roi_pool = ins_roi_pool
        self.ins_head = ins_head
        self.ins_predictor = ins_predictor

        self.mask_roi_pool = mask_roi_pool
        self.mask_head = mask_head
        self.mask_predictor = mask_predictor

        self.keypoint_roi_pool = keypoint_roi_pool
        self.keypoint_head = keypoint_head
        self.keypoint_predictor = keypoint_predictor

        self.detect_point = detect_point
        self.detect_line = detect_line
        self.detect_arc = detect_arc
        self.detect_ins = detect_ins

        self.channel_compress = nn.Sequential(
            nn.Conv2d(256, 8, kernel_size=1),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True)
        )

    def has_mask(self):
        if self.mask_roi_pool is None:
            return False
        if self.mask_head is None:
            return False
        if self.mask_predictor is None:
            return False
        return True

    def has_keypoint(self):
        if self.keypoint_roi_pool is None:
            return False
        if self.keypoint_head is None:
            return False
        if self.keypoint_predictor is None:
            return False
        return True

    def has_line(self):
        # if self.line_roi_pool is None:
        #     return False
        if self.line_head is None:
            return False
        # if self.line_predictor is None:
        #     return False
        return True

    def has_point(self):
        # if self.point_roi_pool is None:
        #     return False
        if self.point_head is None:
            return False
        # if self.point_predictor is None:
        #     return False
        return True

    def has_arc(self):
        # if self.arc_roi_pool is None:
        #     return False
        if self.arc_head is None:
            return False
        # if self.arc_predictor is None:
        #     return False
        return True

    def has_ins(self):
        # if self.ins_roi_pool is None:
        #     return False
        if self.ins_head is None:
            return False
        # if self.ins_predictor is None:
        #     return False
        return True

    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
        matched_idxs = []
        labels = []
        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
            if gt_boxes_in_image.numel() == 0:
                # Background image
                device = proposals_in_image.device
                clamped_matched_idxs_in_image = torch.zeros(
                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
                )
                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
            else:
                # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)

                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)

                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
                labels_in_image = labels_in_image.to(dtype=torch.int64)

                # Label background (below the low threshold)
                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_in_image[bg_inds] = 0

                # Label ignore proposals (between low and high thresholds)
                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler

            matched_idxs.append(clamped_matched_idxs_in_image)
            labels.append(labels_in_image)
        return matched_idxs, labels

    def subsample(self, labels):
        # type: (List[Tensor]) -> List[Tensor]
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_inds = []
        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
            sampled_inds.append(img_sampled_inds)
        return sampled_inds

    def add_gt_proposals(self, proposals, gt_boxes):
        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
        return proposals

    def check_targets(self, targets):
        # type: (Optional[List[Dict[str, Tensor]]]) -> None
        if targets is None:
            raise ValueError("targets should not be None")
        if not all(["boxes" in t for t in targets]):
            raise ValueError("Every element of targets should have a boxes key")
        if not all(["labels" in t for t in targets]):
            raise ValueError("Every element of targets should have a labels key")
        if self.has_mask():
            if not all(["masks" in t for t in targets]):
                raise ValueError("Every element of targets should have a masks key")

    def select_training_samples(
        self,
        proposals,  # type: List[Tensor]
        targets,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
        self.check_targets(targets)
        if targets is None:
            raise ValueError("targets should not be None")
        dtype = proposals[0].dtype
        device = proposals[0].device

        gt_boxes = [t["boxes"].to(dtype) for t in targets]
        gt_labels = [t["labels"] for t in targets]

        # append ground-truth bboxes to proposals
        proposals = self.add_gt_proposals(proposals, gt_boxes)

        # get matching gt indices for each proposal
        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
        # sample a fixed proportion of positive-negative proposals
        sampled_inds = self.subsample(labels)
        matched_gt_boxes = []
        num_images = len(proposals)
        for img_id in range(num_images):
            img_sampled_inds = sampled_inds[img_id]
            proposals[img_id] = proposals[img_id][img_sampled_inds]
            labels[img_id] = labels[img_id][img_sampled_inds]
            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]

            gt_boxes_in_image = gt_boxes[img_id]
            if gt_boxes_in_image.numel() == 0:
                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])

        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
        return proposals, matched_idxs, labels, regression_targets

    def postprocess_detections(
        self,
        class_logits,  # type: Tensor
        box_regression,  # type: Tensor
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
        device = class_logits.device
        num_classes = class_logits.shape[-1]

        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
        pred_boxes = self.box_coder.decode(box_regression, proposals)

        pred_scores = F.softmax(class_logits, -1)

        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
        pred_scores_list = pred_scores.split(boxes_per_image, 0)

        all_boxes = []
        all_scores = []
        all_labels = []
        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            # create labels for each prediction
            labels = torch.arange(num_classes, device=device)
            labels = labels.view(1, -1).expand_as(scores)

            # remove predictions with the background label
            boxes = boxes[:, 1:]
            scores = scores[:, 1:]
            labels = labels[:, 1:]

            # batch everything, by making every class prediction be a separate instance
            boxes = boxes.reshape(-1, 4)
            scores = scores.reshape(-1)
            labels = labels.reshape(-1)

            # remove low scoring boxes
            inds = torch.where(scores > self.score_thresh)[0]
            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]

            # remove empty boxes
            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            # non-maximum suppression, independently done per class
            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
            # keep only topk scoring predictions
            keep = keep[: self.detections_per_img]
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            all_boxes.append(boxes)
            all_scores.append(scores)
            all_labels.append(labels)

        return all_boxes, all_scores, all_labels

    def forward(
        self,
        features,  # type: Dict[str, Tensor]
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
        """
        Args:
            features (List[Tensor])
            proposals (List[Tensor[N, 4]])
            image_shapes (List[Tuple[H, W]])
            targets (List[Dict])
        """
        print(f'roihead forward!!!')
        if targets is not None:
            for t in targets:
                # TODO: https://github.com/pytorch/pytorch/issues/26731
                floating_point_types = (torch.float, torch.double, torch.half)
                if not t["boxes"].dtype in floating_point_types:
                    raise TypeError(f"target boxes must be of float type, instead got {t['boxes'].dtype}")
                if not t["labels"].dtype == torch.int64:
                    raise TypeError(f"target labels must be of int64 type, instead got {t['labels'].dtype}")
                if self.has_keypoint():
                    if not t["keypoints"].dtype == torch.float32:
                        raise TypeError(f"target keypoints must be of float type, instead got {t['keypoints'].dtype}")

        if self.training:
            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
        else:
            if targets is not None:
                proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
            else:
                labels = None
                regression_targets = None
                matched_idxs = None

        device = features['0'].device
        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression = self.box_predictor(box_features)

        result: List[Dict[str, torch.Tensor]] = []
        losses = {}
        # _, C, H, W = features['0'].shape  # ignore batch_size; we only care about C, H, W
        if self.training:
            if labels is None:
                raise ValueError("labels cannot be None")
            if regression_targets is None:
                raise ValueError("regression_targets cannot be None")
            print(f'boxes compute losses')
            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
        else:
            if targets is not None:
                loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
                losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals,
                                                                image_shapes)
            num_images = len(boxes)
            for i in range(num_images):
                result.append(
                    {
                        "boxes": boxes[i],
                        "labels": labels[i],
                        "scores": scores[i],
                    }
                )

        if self.has_line() and self.detect_line:
            print(f'roi_heads forward has_line()!!!!')
            # print(f'labels:{labels}')
            line_proposals = [p["boxes"] for p in result]
            point_proposals = [p["boxes"] for p in result]
            print(f'boxes_proposals:{len(line_proposals)}')
            # if line_proposals is None or len(line_proposals) == 0:
            #     # return empty features or skip this part of the computation
            #     return torch.empty(0, C, H, W).to(features['0'].device)
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                print(f'num_images:{num_images}')
                line_proposals = []
                point_proposals = []
                arc_proposals = []
                pos_matched_idxs = []
                line_pos_matched_idxs = []
                point_pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    line_pos = torch.where(labels[img_id] == 2)[0]
                    # point_pos = torch.where(labels[img_id] == 1)[0]
                    line_proposals.append(proposals[img_id][line_pos])
                    # point_proposals.append(proposals[img_id][point_pos])
                    line_pos_matched_idxs.append(matched_idxs[img_id][line_pos])
                    # point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
                    # pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                if targets is not None:
                    pos_matched_idxs = []
                    num_images = len(proposals)
                    line_proposals = []
                    line_pos_matched_idxs = []
                    print(f'val num_images:{num_images}')
                    if matched_idxs is None:
                        raise ValueError("if in training, matched_idxs should not be None")
                    for img_id in range(num_images):
                        # pos = torch.where(labels[img_id] > 0)[0]
                        line_pos = torch.where(labels[img_id] == 2)[0]
                        line_proposals.append(proposals[img_id][line_pos])
                        line_pos_matched_idxs.append(matched_idxs[img_id][line_pos])
                else:
                    pos_matched_idxs = None

            line_proposals_valid = self.check_proposals(line_proposals)
            if line_proposals_valid:
                feature_logits = self.line_forward3(features, image_shapes, line_proposals)

                loss_line = None
                loss_line_iou = None
                if self.training:
                    if targets is None or pos_matched_idxs is None:
                        raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
                    gt_lines = [t["lines"] for t in targets if "lines" in t]
                    # print(f'gt_lines:{gt_lines[0].shape}')
                    h, w = targets[0]["img_size"]
                    img_size = h
                    gt_lines_tensor = torch.zeros(0, 0)
                    if len(gt_lines) > 0:
                        gt_lines_tensor = torch.cat(gt_lines)
                        print(f'gt_lines_tensor:{gt_lines_tensor.shape}')
                    if gt_lines_tensor.shape[0] > 0:
                        print(f'start to lines_point_pair_loss')
                        loss_line = lines_point_pair_loss(
                            feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
                        )
                        loss_line_iou = line_iou_loss(feature_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
                    if loss_line is None:
                        print(f'loss_line is None111')
                        loss_line = torch.tensor(0.0, device=device)
                    if loss_line_iou is None:
                        print(f'loss_line_iou is None111')
                        loss_line_iou = torch.tensor(0.0, device=device)
                    loss_line = {"loss_line": loss_line}
                    loss_line_iou = {'loss_line_iou': loss_line_iou}
                else:
                    if targets is not None:
                        h, w = targets[0]["img_size"]
                        img_size = h
                        gt_lines = [t["lines"] for t in targets if "lines" in t]
                        gt_lines_tensor = torch.zeros(0, 0)
                        if len(gt_lines) > 0:
                            gt_lines_tensor = torch.cat(gt_lines)
                        if gt_lines_tensor.shape[0] > 0 and feature_logits is not None:
                            loss_line = lines_point_pair_loss(
                                feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
                            )
                            print(f'compute_line_loss:{loss_line}')
                            loss_line_iou = line_iou_loss(feature_logits, line_proposals, gt_lines, line_pos_matched_idxs,
                                                          img_size)
                        if loss_line is None:
                            print(f'loss_line is None')
                            loss_line = torch.tensor(0.0, device=device)
                        if loss_line_iou is None:
                            print(f'loss_line_iou is None')
                            loss_line_iou = torch.tensor(0.0, device=device)
                        loss_line = {"loss_line": loss_line}
                        loss_line_iou = {'loss_line_iou': loss_line_iou}
                    else:
                        loss_line = {}
                        loss_line_iou = {}
                        if feature_logits is None or line_proposals is None:
                            raise ValueError(
                                "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                            )
                    if feature_logits is not None:
                        lines_probs, lines_scores = line_inference(feature_logits, line_proposals)
                        for masks, kps, r in zip(lines_probs, lines_scores, result):
                            r["lines"] = masks
                            r["lines_scores"] = kps

                print(f'loss_line11111:{loss_line}')
                losses.update(loss_line)
                losses.update(loss_line_iou)
                print(f'losses:{losses}')

        if self.has_point() and self.detect_point:
            print(f'roi_heads forward has_point()!!!!')
            # print(f'labels:{labels}')
            point_proposals = [p["boxes"] for p in result]
            print(f'boxes_proposals:{len(point_proposals)}')
            # if point_proposals is None or len(point_proposals) == 0:
            #     # return empty features or skip this part of the computation
            #     return torch.empty(0, C, H, W).to(features['0'].device)
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                print(f'num_images:{num_images}')
                point_proposals = []
                point_pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")
                for img_id in range(num_images):
                    point_pos = torch.where(labels[img_id] == 1)[0]
                    point_proposals.append(proposals[img_id][point_pos])
                    point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
            else:
                if targets is not None:
                    num_images = len(proposals)
                    point_proposals = []
                    point_pos_matched_idxs = []
                    print(f'val num_images:{num_images}')
                    if matched_idxs is None:
                        raise ValueError("if in training, matched_idxs should not be None")
                    for img_id in range(num_images):
                        point_pos = torch.where(labels[img_id] == 1)[0]
                        point_proposals.append(proposals[img_id][point_pos])
                        point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
                else:
                    pos_matched_idxs = None

            point_proposals_valid = self.check_proposals(point_proposals)
            if point_proposals_valid:
                feature_logits = self.point_forward1(features, image_shapes, point_proposals)
                loss_point = None
                if self.training:
                    if targets is None or point_pos_matched_idxs is None:
                        raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
                    gt_points = [t["points"] for t in targets if "points" in t]
                    print(f'gt_points:{gt_points[0].shape}')
                    h, w = targets[0]["img_size"]
                    img_size = h
                    gt_points_tensor = torch.zeros(0, 0)
                    if len(gt_points) > 0:
                        gt_points_tensor = torch.cat(gt_points)
                        print(f'gt_points_tensor:{gt_points_tensor.shape}')
                    if gt_points_tensor.shape[0] > 0:
                        print(f'start to compute point_loss')
                        loss_point = compute_point_loss(feature_logits, point_proposals, gt_points, point_pos_matched_idxs)
                    if loss_point is None:
                        print(f'loss_point is None111')
                        loss_point = torch.tensor(0.0, device=device)
                    loss_point = {"loss_point": loss_point}
                else:
                    if targets is not None:
                        h, w = targets[0]["img_size"]
                        img_size = h
                        gt_points = [t["points"] for t in targets if "points" in t]
                        gt_points_tensor = torch.zeros(0, 0)
                        if len(gt_points) > 0:
                            gt_points_tensor = torch.cat(gt_points)
                            print(f'gt_points_tensor:{gt_points_tensor.shape}')
                        if gt_points_tensor.shape[0] > 0:
                            print(f'start to compute point_loss')
                            loss_point = compute_point_loss(feature_logits, point_proposals, gt_points,
                                                            point_pos_matched_idxs)
                        if loss_point is None:
                            print(f'loss_point is None111')
                            loss_point = torch.tensor(0.0, device=device)
                        loss_point = {"loss_point": loss_point}
                    else:
                        loss_point = {}
                        if feature_logits is None or point_proposals is None:
                            raise ValueError(
                                "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                            )
                    if feature_logits is not None:
                        points_probs, points_scores = point_inference(feature_logits, point_proposals)
                        for masks, kps, r in zip(points_probs, points_scores, result):
                            r["points"] = masks
                            r["points_scores"] = kps

                print(f'loss_point:{loss_point}')
                losses.update(loss_point)
                print(f'losses:{losses}')

        if self.has_arc() and self.detect_arc:
            print(f'roi_heads forward has_arc()!!!!')
            # print(f'labels:{labels}')
            arc_proposals = [p["boxes"] for p in result]
            print(f'boxes_proposals:{len(arc_proposals)}')
            # if arc_proposals is None or len(arc_proposals) == 0:
            #     # return empty features or skip this part of the computation
            #     return torch.empty(0, C, H, W).to(features['0'].device)
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                print(f'num_images:{num_images}')
                arc_proposals = []
                arc_pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")
                for img_id in range(num_images):
                    arc_pos = torch.where(labels[img_id] == 3)[0]
                    arc_proposals.append(proposals[img_id][arc_pos])
                    arc_pos_matched_idxs.append(matched_idxs[img_id][arc_pos])
            else:
                if targets is not None:
                    num_images = len(proposals)
                    arc_proposals = []
                    arc_pos_matched_idxs = []
                    print(f'val num_images:{num_images}')
                    if matched_idxs is None:
                        raise ValueError("if in training, matched_idxs should not be None")
                    for img_id in range(num_images):
                        arc_pos = torch.where(labels[img_id] == 3)[0]
                        arc_proposals.append(proposals[img_id][arc_pos])
                        arc_pos_matched_idxs.append(matched_idxs[img_id][arc_pos])
                else:
                    arc_pos_matched_idxs = None

            arc_proposals_valid = self.check_proposals(arc_proposals)
            if arc_proposals_valid:
                feature_logits = self.arc_forward1(features, image_shapes, arc_proposals)
                loss_arc = None
                if self.training:
                    if targets is None or arc_pos_matched_idxs is None:
                        raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
                    gt_arcs = [t["arc_mask"] for t in targets if "arc_mask" in t]
                    print(f'gt_arcs:{gt_arcs[0].shape}')
                    h, w = targets[0]["img_size"]
                    img_size = h
                    if len(gt_arcs) > 0 and feature_logits is not None:
                        loss_arc = compute_ins_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
                    if loss_arc is None:
                        print(f'loss_arc is None111')
                        loss_arc = torch.tensor(0.0, device=device)
                    loss_arc = {"loss_arc": loss_arc}
                else:
                    if targets is not None:
                        h, w = targets[0]["img_size"]
                        img_size = h
                        gt_arcs = [t["arc_mask"] for t in targets if "arc_mask" in t]
                        print(f'gt_arcs:{gt_arcs[0].shape}')
                        if len(gt_arcs) > 0 and feature_logits is not None:
                            print(f'start to compute arc_loss')
                            loss_arc = compute_ins_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
                        if loss_arc is None:
                            print(f'loss_arc is None111')
                            loss_arc = torch.tensor(0.0, device=device)
                        loss_arc = {"loss_arc": loss_arc}
                    else:
                        loss_arc = {}
                        if feature_logits is None or arc_proposals is None:
                            # raise ValueError(
                            #     "both arc_feature_logits and arc_proposals should not be None when not in training mode"
                            # )
                            print(f'error: both arc_feature_logits and arc_proposals should not be None when not in training mode')
                            pass
                    if feature_logits is not None and arc_proposals is not None:
                        arcs_probs, arcs_scores, arcs_point = ins_inference(feature_logits, arc_proposals, th=0)
                        for masks, kps, kp, r in zip(arcs_probs, arcs_scores, arcs_point, result):
                            # r["arcs"] = keypoint_prob
                            r["arcs"] = feature_logits
                            r["arcs_scores"] = kps
                            r["arcs_point"] = feature_logits

                # print(f'loss_point:{loss_point}')
                losses.update(loss_arc)
                print(f'losses:{losses}')
  1017. if self.has_ins and self.detect_ins:
  1018. print(f'roi_heads forward has_circle()!!!!')
  1019. # print(f'labels:{labels}')
  1020. ins_proposals = [p["boxes"] for p in result]
  1021. print(f'boxes_proposals:{len(ins_proposals)}')
  1022. # if line_proposals is None or len(line_proposals) == 0:
  1023. # # 返回空特征或者跳过该部分计算
  1024. # return torch.empty(0, C, H, W).to(features['0'].device)
  1025. if self.training:
  1026. # during training, only focus on positive boxes
  1027. num_images = len(proposals)
  1028. print(f'num_images:{num_images}')
  1029. ins_proposals = []
  1030. ins_pos_matched_idxs = []
  1031. if matched_idxs is None:
  1032. raise ValueError("if in trainning, matched_idxs should not be None")
  1033. for img_id in range(num_images):
  1034. # circle_pos = torch.where(labels[img_id] == 4)[0]
  1035. # ins_proposals.append(proposals[img_id][circle_pos])
  1036. # ins_pos_matched_idxs.append(matched_idxs[img_id][circle_pos])
  1037. circle_pos = torch.where(labels[img_id] == 4)[0]
  1038. circle_pos = circle_pos.flatten()
  1039. idxs = circle_pos.detach().cpu().tolist()
  1040. num_prop = len(proposals[img_id])
  1041. for idx in idxs:
  1042. if idx < 0 or idx >= num_prop:
  1043. raise RuntimeError(
  1044. f"Index out of bounds: circle_pos={idx}, but proposals len={num_prop}, "
  1045. f"img_id={img_id}"
  1046. )
  1047. ins_proposals.append(
  1048. proposals[img_id][idxs]
  1049. )
  1050. ins_pos_matched_idxs.append(
  1051. matched_idxs[img_id][idxs]
  1052. )
  1053. else:
  1054. if targets is not None:
  1055. num_images = len(proposals)
  1056. ins_proposals = []
  1057. ins_pos_matched_idxs = []
  1058. print(f'val num_images:{num_images}')
  1059. if matched_idxs is None:
  1060. raise ValueError("if in trainning, matched_idxs should not be None")
  1061. for img_id in range(num_images):
  1062. # circle_pos = torch.where(labels[img_id] == 4)[0]
  1063. # ins_proposals.append(proposals[img_id][circle_pos])
  1064. # ins_pos_matched_idxs.append(matched_idxs[img_id][circle_pos])
  1065. circle_pos = torch.where(labels[img_id] == 4)[0]
  1066. circle_pos = circle_pos.flatten()
  1067. idxs = circle_pos.detach().cpu().tolist()
  1068. num_prop = len(proposals[img_id])
  1069. for idx in idxs:
  1070. if idx < 0 or idx >= num_prop:
  1071. raise RuntimeError(
  1072. f"Index out of bounds: circle_pos={idx}, but proposals len={num_prop}, "
  1073. f"img_id={img_id}"
  1074. )
  1075. ins_proposals.append(
  1076. proposals[img_id][idxs]
  1077. )
  1078. ins_pos_matched_idxs.append(
  1079. matched_idxs[img_id][idxs]
  1080. )
  1081. else:
  1082. pos_matched_idxs = None
            # circle_proposals_tensor = torch.cat(circle_proposals)
            ins_proposals_valid = self.check_proposals(ins_proposals)
            print(f"self.training:{self.training}")
            print(f"ins_proposals_valid:{ins_proposals_valid}")
            if ins_proposals_valid:
                print(f"features from backbone:{features['0'].shape}")
                feature_logits = self.ins_forward1(features, image_shapes, ins_proposals)
                # ins_masks, ins_scores, circle_points = ins_inference(feature_logits, ins_proposals, th=0)
                # per-proposal arc parameters; noted as [sum of proposals, 7] originally,
                # though compute_arc_equation_loss slices columns [:5] and [5:9]
                arc_equation = self.arc_equation_head(feature_logits)
                loss_ins = None
                loss_ins_extra = None
                loss_arc_equation = None
                loss_arc_ends = None
                if self.training:
                    print("circle loss!!!!!!")
                    if targets is None or ins_pos_matched_idxs is None:
                        raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
                    gt_inses = [t["circle_masks"] for t in targets if "circle_masks" in t]
                    gt_labels = [t["labels"] for t in targets]
                    gt_arcs = [t["arc_mask"] for t in targets if "arc_mask" in t]
                    gt_mask_ends = [t["mask_ends"] for t in targets if "mask_ends" in t]
                    gt_mask_params = [t["mask_params"] for t in targets if "mask_params" in t]
                    h, w = targets[0]["img_size"]
                    img_size = h
                    gt_ins_tensor = torch.zeros(0, 0)
                    if len(gt_inses) > 0:
                        print_params(gt_inses)
                        gt_ins_tensor = torch.cat(gt_inses)
                        print(f'gt_ins_tensor:{gt_ins_tensor.shape}')
                    if gt_ins_tensor.shape[0] > 0:
                        print('start to compute circle_loss')
                        loss_ins = compute_ins_loss(feature_logits, ins_proposals, gt_inses, ins_pos_matched_idxs)
                        total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(
                            arc_equation, ins_proposals, gt_mask_ends, gt_mask_params, ins_pos_matched_idxs, labels
                        )
                    # fall back to zero losses when no circle GT was available
                    if loss_arc_equation is None:
                        print('loss_arc_equation is None')
                        loss_arc_equation = torch.tensor(0.0, device=device)
                    if loss_arc_ends is None:
                        print('loss_arc_ends is None')
                        loss_arc_ends = torch.tensor(0.0, device=device)
                    if loss_ins is None:
                        print('loss_ins is None')
                        loss_ins = torch.tensor(0.0, device=device)
                    if loss_ins_extra is None:
                        print('loss_ins_extra is None')
                        loss_ins_extra = torch.tensor(0.0, device=device)
                    loss_ins = {"loss_ins": loss_ins}
                    loss_ins_extra = {"loss_ins_extra": loss_ins_extra}
                    loss_arc_equation = {"loss_arc_equation": loss_arc_equation}
                    loss_arc_ends = {"loss_arc_ends": loss_arc_ends}
                else:
                    if targets is not None:
                        h, w = targets[0]["img_size"]
                        img_size = h
                        gt_inses = [t["circle_masks"] for t in targets if "circle_masks" in t]
                        gt_labels = [t["labels"] for t in targets]
                        gt_mask_ends = [t["mask_ends"] for t in targets if "mask_ends" in t]
                        gt_mask_params = [t["mask_params"] for t in targets if "mask_params" in t]
                        gt_ins_tensor = torch.zeros(0, 0)
                        if len(gt_inses) > 0:
                            gt_ins_tensor = torch.cat(gt_inses)
                            print(f'gt_ins_tensor:{gt_ins_tensor.shape}')
                        if gt_ins_tensor.shape[0] > 0:
                            print('start to compute circle_loss')
                            loss_ins = compute_ins_loss(feature_logits, ins_proposals, gt_inses,
                                                        ins_pos_matched_idxs)
                            total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(
                                arc_equation, ins_proposals, gt_mask_ends, gt_mask_params, ins_pos_matched_idxs, labels
                            )
                            # loss_ins_extra = compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
                        # fall back to zero losses when no circle GT was available
                        if loss_ins is None:
                            print('loss_ins is None')
                            loss_ins = torch.tensor(0.0, device=device)
                        if loss_ins_extra is None:
                            print('loss_ins_extra is None')
                            loss_ins_extra = torch.tensor(0.0, device=device)
                        if loss_arc_equation is None:
                            print('loss_arc_equation is None')
                            loss_arc_equation = torch.tensor(0.0, device=device)
                        if loss_arc_ends is None:
                            print('loss_arc_ends is None')
                            loss_arc_ends = torch.tensor(0.0, device=device)
                        loss_ins = {"loss_ins": loss_ins}
                        loss_ins_extra = {"loss_ins_extra": loss_ins_extra}
                        loss_arc_equation = {"loss_arc_equation": loss_arc_equation}
                        loss_arc_ends = {"loss_arc_ends": loss_arc_ends}
                    else:
                        loss_ins = {}
                        loss_ins_extra = {}
                        loss_arc_equation = {}
                        loss_arc_ends = {}
                        if feature_logits is None or ins_proposals is None:
                            raise ValueError(
                                "both feature_logits and ins_proposals should not be None when not in training mode"
                            )
                    if feature_logits is not None:
                        ins_masks, ins_scores, circle_points = ins_inference(feature_logits,
                                                                             ins_proposals, th=0)
                        arc7, arc_scores = arc_inference1(arc_equation, feature_logits, ins_proposals, 0.5)
                        for arc_, arc_score, r in zip(arc7, arc_scores, result):
                            r["arcs"] = arc_
                            r["arc_scores"] = arc_score
                        # merge each image's per-proposal feature maps into a single map
                        proposals_per_image = [box.size(0) for box in ins_proposals]
                        print(f'ins_proposals_per_image:{proposals_per_image}')
                        feature_logits_props = []
                        start_idx = 0
                        for num_p in proposals_per_image:
                            current_features = feature_logits[start_idx:start_idx + num_p]
                            merged_feature = torch.sum(current_features, dim=0, keepdim=True)
                            feature_logits_props.append(merged_feature)
                            start_idx += num_p
                        for masks, kps, r, f in zip(ins_masks, ins_scores, result, feature_logits_props):
                            r["ins_masks"] = masks
                            r["ins_scores"] = kps
                            print(f'ins feature map:{f.shape}')
                            r["features"] = f.squeeze(0)
                print(f'loss_ins:{loss_ins}')
                print(f'loss_ins_extra:{loss_ins_extra}')
                losses.update(loss_ins)
                losses.update(loss_ins_extra)
                losses.update(loss_arc_equation)
                losses.update(loss_arc_ends)
                print(f'losses:{losses}')
        if self.has_mask():
            mask_proposals = [p["boxes"] for p in result]
            if self.training:
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")
                # during training, only focus on positive boxes
                num_images = len(proposals)
                mask_proposals = []
                pos_matched_idxs = []
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    mask_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None

            if self.mask_roi_pool is not None:
                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
                mask_features = self.mask_head(mask_features)
                mask_logits = self.mask_predictor(mask_features)
            else:
                raise Exception("Expected mask_roi_pool to be not None")

            loss_mask = {}
            if self.training:
                if targets is None or pos_matched_idxs is None or mask_logits is None:
                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
                gt_masks = [t["masks"] for t in targets]
                gt_labels = [t["labels"] for t in targets]
                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
                loss_mask = {"loss_mask": rcnn_loss_mask}
            else:
                labels = [r["labels"] for r in result]
                masks_probs = maskrcnn_inference(mask_logits, labels)
                for mask_prob, r in zip(masks_probs, result):
                    r["masks"] = mask_prob
            losses.update(loss_mask)

        # keep none checks in if conditional so torchscript will conditionally
        # compile each branch
        if self.has_keypoint():
            keypoint_proposals = [p["boxes"] for p in result]
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                keypoint_proposals = []
                pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    keypoint_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None

            keypoint_features = self.line_roi_pool(features, keypoint_proposals, image_shapes)
            keypoint_features = self.line_head(keypoint_features)
            keypoint_logits = self.line_predictor(keypoint_features)

            loss_keypoint = {}
            if self.training:
                if targets is None or pos_matched_idxs is None:
                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
                gt_keypoints = [t["keypoints"] for t in targets]
                rcnn_loss_keypoint = keypointrcnn_loss(
                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
                )
                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
            else:
                if keypoint_logits is None or keypoint_proposals is None:
                    raise ValueError(
                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                    )
                keypoints_probs, lines_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
                for keypoint_prob, kps, r in zip(keypoints_probs, lines_scores, result):
                    r["keypoints"] = keypoint_prob
                    r["keypoints_scores"] = kps
            losses.update(loss_keypoint)

        return result, losses
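    # In eval mode each entry of `result` may additionally carry: "arcs"/"arc_scores"
    # and "ins_masks"/"ins_scores"/"features" from the instance branch, "masks" when
    # has_mask() is true, and "keypoints"/"keypoints_scores" when has_keypoint() is true.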
    def check_proposals(self, proposals):
        # a batch of proposals is only usable if every image contributes at least one
        valid = True
        for proposal in proposals:
            if proposal.shape[0] == 0:
                valid = False
        return valid
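    # Minimal sanity check for check_proposals (hypothetical tensors, not from the
    # real pipeline): it takes a list of per-image proposal tensors of shape (N_i, 4)
    # and returns False whenever any image has zero proposals, e.g.:
    #
    #   heads.check_proposals([torch.rand(3, 4), torch.rand(5, 4)])   # -> True
    #   heads.check_proposals([torch.rand(3, 4), torch.zeros(0, 4)])  # -> False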
    def line_forward1(self, features, image_shapes, line_proposals):
        # pool RoI features first, then run them through the line head
        print(f'line_proposals:{len(line_proposals)}')
        cs_features = self.channel_compress(features['0'])
        filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
        if len(filtered_proposals) > 0:
            filtered_proposals_tensor = torch.cat(filtered_proposals)
            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
        line_proposals_tensor = torch.cat(line_proposals)
        print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
        roi_features = features_align(cs_features, line_proposals, image_shapes)
        if roi_features is not None:
            print(f'line_features from align:{roi_features.shape}')
        feature_logits = self.line_head(roi_features)
        print(f'feature_logits from line_head:{feature_logits.shape}')
        return feature_logits
    def line_forward2(self, features, image_shapes, line_proposals):
        # run the line head on the full feature map first, then pool RoI features
        print(f'line_proposals:{len(line_proposals)}')
        cs_features = features['0']
        filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
        if len(filtered_proposals) > 0:
            filtered_proposals_tensor = torch.cat(filtered_proposals)
            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
            line_proposals = filtered_proposals
        line_proposals_tensor = torch.cat(line_proposals)
        print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
        feature_logits = self.line_head(cs_features)
        print(f'feature_logits from line_head:{feature_logits.shape}')
        roi_features = features_align(feature_logits, line_proposals, image_shapes)
        if roi_features is not None:
            print(f'roi_features from align:{roi_features.shape}')
        return roi_features
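    # Note on the two orderings: line_forward1 pools RoI features with features_align
    # and then applies self.line_head per RoI, while line_forward2 applies
    # self.line_head to the whole feature map and pools afterwards; the first trades
    # a full-map pass for per-RoI head compute, the second the reverse.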
    def line_forward3(self, features, image_shapes, line_proposals):
        # run the line predictor on the full feature map, then pool RoI features
        print(f'line_proposals:{len(line_proposals)}')
        cs_features = features['0']
        feature_logits = self.line_predictor(cs_features)
        print(f'feature_logits from line_predictor:{feature_logits.shape}')
        roi_features = features_align(feature_logits, line_proposals, image_shapes)
        if roi_features is not None:
            print(f'roi_features from align:{roi_features.shape}')
        return roi_features
    def point_forward1(self, features, image_shapes, proposals):
        # run the point predictor on the full feature map, then pool RoI features
        print(f'point_proposals:{len(proposals)}')
        cs_features = features['0']
        feature_logits = self.point_predictor(cs_features)
        print(f'feature_logits from point_predictor:{feature_logits.shape}')
        roi_features = features_align(feature_logits, proposals, image_shapes)
        if roi_features is not None:
            print(f'roi_features from align:{roi_features.shape}')
        return roi_features
    def ins_forward1(self, features, image_shapes, proposals):
        # run the instance head on the backbone features, then pool RoI features
        print(f'circle_proposals:{len(proposals)}')
        cs_features = features
        feature_logits = self.ins_head(cs_features)
        print(f'feature_logits from ins_head:{feature_logits.shape}')
        roi_features = features_align(feature_logits, proposals, image_shapes)
        if roi_features is not None:
            print(f'roi_features from align:{roi_features.shape}')
        return roi_features
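    # Unlike the line/point variants above, ins_forward1 and arc_forward1 hand the
    # full multi-level `features` dict to their heads rather than features['0'];
    # the heads are therefore assumed to accept that dict directly (inherited from
    # the original code, not verified here).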
    def arc_forward1(self, features, image_shapes, proposals):
        # run the arc predictor on the backbone features, then pool RoI features
        print(f'arc_proposals:{len(proposals)}')
        cs_features = features
        feature_logits = self.arc_predictor(cs_features)
        print(f'feature_logits from arc_predictor:{feature_logits.shape}')
        roi_features = features_align(feature_logits, proposals, image_shapes)
        if roi_features is not None:
            print(f'roi_features from align:{roi_features.shape}')
        return roi_features
import numpy as np
import torch
import torch.nn.functional as F


def compute_arc_equation_loss(arc_equation, proposals, gt_mask_ends, gt_mask_params, arc_pos_matched_idxs,
                              gt_labels_all):
    """
    Compute the loss between predicted arc equations and ground truth.

    Args:
        arc_equation: Tensor over all proposals; columns [:5] are the ellipse
            parameters (xc, yc, a, b, theta) and columns [5:9] the two arc end
            points, so at least 9 columns are expected by the slicing below
        proposals: list of length B, per-image proposal tensors
        gt_mask_ends: GT arc end points (used for the end-point loss)
        gt_mask_params: list of length B, each an array of shape (num_gt, 5)
        arc_pos_matched_idxs: list of length B, each a Tensor of indices matching predictions to GT
        gt_labels_all: list of length B, GT labels
    """
    len_proposals = len(proposals)  # batch size
    device = arc_equation[0].device
    print(f'compute_arc_equation_loss arc_equation.shape:{arc_equation.shape}, '
          f'len_proposals:{len_proposals}, arc_pos_matched_idxs:{arc_pos_matched_idxs}')
    print(f'gt_mask_ends:{gt_mask_ends}, gt_mask_params:{gt_mask_params}')
    print(f'gt_labels_all:{gt_labels_all}')
    # (compute_arc_angles below can convert GT end points to parametric angles;
    # the angle-based loss variant is kept as a comment after the return.)
    gt_sel_params = []
    gt_sel_angles = []
    for proposals_per_image, gt_ends, gt_params, gt_label, midx in zip(
            proposals, gt_mask_ends, gt_mask_params, gt_labels_all, arc_pos_matched_idxs):
        print(f'proposals_per_image:{proposals_per_image.shape}')
        gt_ends = torch.tensor(gt_ends)
        gt_params = torch.tensor(gt_params)
        if gt_ends.shape[0] > 0:
            # gather the GT entries matched to each positive proposal
            po = gt_ends[midx.cpu()]
            pa = gt_params[midx.cpu()]
            print(f'po:{po}, pa:{pa}')
            gt_sel_angles.append(po)
            gt_sel_params.append(pa)
    gt_sel_angles = torch.cat(gt_sel_angles, dim=0)
    gt_sel_params = torch.cat(gt_sel_params, dim=0)
    # columns 5:9 hold the two predicted end points (requires >= 9 output channels)
    pred_ends = arc_equation[:, 5:9]
    pred_params = arc_equation[:, :5]
    param_loss = F.mse_loss(pred_params, gt_sel_params) / 10000
    print("pred_ends before reshape:")
    print_params(pred_ends, gt_sel_angles)
    pred_ends = pred_ends.view(-1, 2, 2)  # (N, 2 end points, 2 coords)
    print("pred_ends after reshape:")
    print_params(pred_ends, gt_sel_angles)
    ends_loss = F.mse_loss(pred_ends, gt_sel_angles) / 10000
    count = sum(len(sublist) for sublist in proposals)
    total_loss = ((param_loss + ends_loss) / count) if count > 0 else torch.tensor(
        0.0, device=device, dtype=torch.float)
    total_loss = total_loss.to(device)
    ends_loss = ends_loss.to(device)
    param_loss = param_loss.to(device)
    return total_loss, param_loss, ends_loss
    # Legacy angle-based variant, kept for reference:
    # angle_loss = F.mse_loss(pred_angles, gt_sel_angles)
    # param_loss = F.mse_loss(pred_params.cpu(), gt_sel_params) / 10000
    # count = sum(len(sublist) for sublist in proposals)
    # total_loss = (param_loss + angle_loss) / count if count > 0 else torch.tensor(0.0)
    # # ensure dtype and device
    # total_loss = total_loss.float().to(device)
    # angle_loss = angle_loss.float().to(device)
    # param_loss = param_loss.float().to(device)
    # return total_loss, param_loss, angle_loss
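# A minimal smoke test for compute_arc_equation_loss with made-up tensors (shapes
# only; the values are not meaningful, and print_params must be importable):
#
#   B = 1 image, 2 positive proposals matched to 2 GT arcs
#   arc_equation = torch.randn(2, 9)                 # 5 ellipse params + 4 end coords
#   proposals = [torch.rand(2, 4)]
#   gt_mask_ends = [torch.rand(2, 2, 2)]             # two (x, y) end points per GT arc
#   gt_mask_params = [np.random.rand(2, 5)]          # (xc, yc, a, b, theta)
#   arc_pos_matched_idxs = [torch.tensor([0, 1])]
#   gt_labels = [torch.tensor([3, 3])]
#   total, p_loss, e_loss = compute_arc_equation_loss(
#       arc_equation, proposals, gt_mask_ends, gt_mask_params, arc_pos_matched_idxs, gt_labels)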
def compute_arc_angles(gt_mask_ends, gt_mask_params):
    """
    Given a point on an ellipse, compute its parametric angle phi (in radians).

    Parameters:
        gt_mask_ends: per-arc end points, each (x, y)
        gt_mask_params: per-arc ellipse parameters (xc, yc, a, b, theta)

    Returns:
        list of phi values, each in [0, 2*pi)
    """
    results = []
    gt_mask_params_tensor = torch.tensor(gt_mask_params,
                                         dtype=gt_mask_ends.dtype,
                                         device=gt_mask_ends.device)
    for ends_img, params_img in zip(gt_mask_ends, gt_mask_params_tensor):
        if torch.norm(params_img) < 1e-6:  # L2 norm near zero: padded/empty entry
            results.append(torch.zeros(2, device=params_img.device, dtype=params_img.dtype))
            continue
        x, y = ends_img
        xc, yc, a, b, theta = params_img
        # 1. translate to the ellipse center
        dx = x - xc
        dy = y - yc
        # 2. rotate back by -theta
        cos_t = torch.cos(theta)
        sin_t = torch.sin(theta)
        X = dx * cos_t + dy * sin_t
        Y = -dx * sin_t + dy * cos_t
        # 3. normalize to the unit circle (divide by a, b)
        cos_phi = X / a
        sin_phi = Y / b
        # 4. atan2 recovers the angle in the correct quadrant
        phi = torch.atan2(sin_phi, cos_phi)
        # 5. map to [0, 2*pi)
        phi = torch.where(phi < 0, phi + 2 * torch.pi, phi)
        results.append(phi)
    return results
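# Quick numeric check (hypothetical data): for an axis-aligned ellipse centered at
# the origin with a=4, b=2, theta=0, the point (4, 0) lies at phi = 0 and (0, 2)
# at phi = pi/2:
#
#   ends = torch.tensor([[4.0, 0.0], [0.0, 2.0]])
#   params = [[0.0, 0.0, 4.0, 2.0, 0.0], [0.0, 0.0, 4.0, 2.0, 0.0]]
#   compute_arc_angles(ends, params)   # -> [tensor(0.), tensor(1.5708)]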