# loi_heads.py
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torchvision

# from scipy.optimize import linear_sum_assignment
from torch import nn, Tensor

from libs.vision_libs.ops import boxes as box_ops, roi_align
import libs.vision_libs.models.detection._utils as det_utils
from collections import OrderedDict
from models.line_detect.heads.head_losses import (
    point_inference,
    compute_point_loss,
    line_iou_loss,
    lines_point_pair_loss,
    features_align,
    line_inference,
    compute_arc_loss,
    arc_inference,
)

def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Computes the loss for Faster R-CNN.

    Args:
        class_logits (Tensor)
        box_regression (Tensor)
        labels (list[BoxList])
        regression_targets (Tensor)

    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
    """
    # print(f'compute fastrcnn_loss:{labels}')
    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)

    classification_loss = F.cross_entropy(class_logits, labels)

    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
    labels_pos = labels[sampled_pos_inds_subset]
    N, num_classes = class_logits.shape
    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

    box_loss = F.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        reduction="sum",
    )
    box_loss = box_loss / labels.numel()
    return classification_loss, box_loss
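
# A minimal sketch of exercising fastrcnn_loss with synthetic tensors; the
# sizes below (one image, 4 sampled proposals, 3 classes including background)
# are illustrative assumptions, not values taken from this codebase.
def _demo_fastrcnn_loss():
    class_logits = torch.randn(4, 3)  # (num_proposals, num_classes)
    box_regression = torch.randn(4, 3 * 4)  # per-class box deltas, flattened
    labels = [torch.tensor([0, 1, 2, 0])]  # one label tensor per image
    regression_targets = [torch.randn(4, 4)]  # one target tensor per image
    return fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
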
def maskrcnn_inference(x, labels):
    # type: (Tensor, List[Tensor]) -> List[Tensor]
    """
    From the results of the CNN, post process the masks
    by taking the mask corresponding to the class with max
    probability (which are of fixed size and directly output
    by the CNN) and return the masks in the mask field of the BoxList.

    Args:
        x (Tensor): the mask logits
        labels (list[BoxList]): bounding boxes that are used as
            reference, one for each image

    Returns:
        results (list[BoxList]): one BoxList for each image, containing
            the extra field mask
    """
    mask_prob = x.sigmoid()

    # select masks corresponding to the predicted classes
    num_masks = x.shape[0]
    boxes_per_image = [label.shape[0] for label in labels]
    labels = torch.cat(labels)
    index = torch.arange(num_masks, device=labels.device)
    mask_prob = mask_prob[index, labels][:, None]
    mask_prob = mask_prob.split(boxes_per_image, dim=0)
    return mask_prob
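
# A small sketch of maskrcnn_inference on synthetic logits: 3 detections split
# 2/1 across two images, 5 classes, 28x28 masks (sizes assumed for illustration).
def _demo_maskrcnn_inference():
    x = torch.randn(3, 5, 28, 28)  # mask logits for all detections
    labels = [torch.tensor([1, 2]), torch.tensor([4])]  # predicted classes per image
    return maskrcnn_inference(x, labels)  # list of (n_i, 1, 28, 28) probability maps
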
def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
    # type: (Tensor, Tensor, Tensor, int) -> Tensor
    """
    Given segmentation masks and the bounding boxes corresponding
    to the location of the masks in the image, this function
    crops and resizes the masks in the position defined by the
    boxes. This prepares the masks for them to be fed to the
    loss computation as the targets.
    """
    matched_idxs = matched_idxs.to(boxes)
    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
    gt_masks = gt_masks[:, None].to(rois)
    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    """
    Args:
        proposals (list[BoxList])
        mask_logits (Tensor)
        targets (list[BoxList])

    Return:
        mask_loss (Tensor): scalar tensor containing the loss
    """
    discretization_size = mask_logits.shape[-1]
    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
    mask_targets = [
        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
    ]

    labels = torch.cat(labels, dim=0)
    mask_targets = torch.cat(mask_targets, dim=0)

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    mask_loss = F.binary_cross_entropy_with_logits(
        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
    )
    return mask_loss
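
# A hedged sketch wiring maskrcnn_loss end to end with random data; the mask
# resolution (28) and image-space mask size (64x64) are assumptions, and
# roi_align is assumed to follow the usual torchvision.ops signature.
def _demo_maskrcnn_loss():
    mask_logits = torch.randn(3, 5, 28, 28)
    proposals = [
        torch.tensor([[0.0, 0.0, 32.0, 32.0], [10.0, 10.0, 40.0, 40.0]]),
        torch.tensor([[5.0, 5.0, 20.0, 30.0]]),
    ]
    gt_masks = [torch.rand(2, 64, 64), torch.rand(1, 64, 64)]
    gt_labels = [torch.tensor([1, 2]), torch.tensor([3])]
    mask_matched_idxs = [torch.tensor([0, 1]), torch.tensor([0])]
    return maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs)
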
def keypoints_to_heatmap(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]
    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])

    offset_x = offset_x[:, None]
    offset_y = offset_y[:, None]
    scale_x = scale_x[:, None]
    scale_y = scale_y[:, None]

    x = keypoints[..., 0]
    y = keypoints[..., 1]

    x_boundary_inds = x == rois[:, 2][:, None]
    y_boundary_inds = y == rois[:, 3][:, None]

    x = (x - offset_x) * scale_x
    x = x.floor().long()
    y = (y - offset_y) * scale_y
    y = y.floor().long()

    x[x_boundary_inds] = heatmap_size - 1
    y[y_boundary_inds] = heatmap_size - 1

    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
    vis = keypoints[..., 2] > 0
    valid = (valid_loc & vis).long()

    lin_ind = y * heatmap_size + x
    heatmaps = lin_ind * valid

    return heatmaps, valid
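
# Sketch of the heatmap encoding: one RoI, two keypoints, and a 56x56 heatmap
# (56 is the usual Keypoint R-CNN resolution, assumed here). The second
# keypoint falls outside the RoI, so its `valid` flag comes back 0.
def _demo_keypoints_to_heatmap():
    rois = torch.tensor([[0.0, 0.0, 32.0, 32.0]])
    keypoints = torch.tensor([[[8.0, 8.0, 1.0], [40.0, 40.0, 1.0]]])  # (x, y, vis)
    heatmaps, valid = keypoints_to_heatmap(keypoints, rois, heatmap_size=56)
    return heatmaps, valid  # linearized heatmap indices and validity flags
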
def _onnx_heatmaps_to_keypoints(
    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
):
    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)

    width_correction = widths_i / roi_map_width
    height_correction = heights_i / roi_map_height

    roi_map = F.interpolate(
        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
    )[:, 0]

    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

    x_int = pos % w
    y_int = (pos - x_int) // w

    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
        dtype=torch.float32
    )
    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
        dtype=torch.float32
    )

    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
    xy_preds_i = torch.stack(
        [
            xy_preds_i_0.to(dtype=torch.float32),
            xy_preds_i_1.to(dtype=torch.float32),
            xy_preds_i_2.to(dtype=torch.float32),
        ],
        0,
    )

    # TODO: simplify when indexing without rank will be supported by ONNX
    base = num_keypoints * num_keypoints + num_keypoints + 1
    ind = torch.arange(num_keypoints)
    ind = ind.to(dtype=torch.int64) * base
    end_scores_i = (
        roi_map.index_select(1, y_int.to(dtype=torch.int64))
        .index_select(2, x_int.to(dtype=torch.int64))
        .view(-1)
        .index_select(0, ind.to(dtype=torch.int64))
    )
    return xy_preds_i, end_scores_i
@torch.jit._script_if_tracing
def _onnx_heatmaps_to_keypoints_loop(
    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
):
    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)

    for i in range(int(rois.size(0))):
        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
        )
        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
        end_scores = torch.cat(
            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
        )
    return xy_preds, end_scores
def heatmaps_to_keypoints(maps, rois):
    """Extract predicted keypoint locations from heatmaps.

    Returns xy_preds of shape (#rois, #keypoints, 3) holding (x, y, 1) in
    image coordinates for each keypoint, and end_scores of shape
    (#rois, #keypoints) with the heatmap activation at each predicted location.
    """
    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
    # consistency with keypoints_to_heatmap_labels by using the conversion from
    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
    # continuous coordinate.
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    widths = rois[:, 2] - rois[:, 0]
    heights = rois[:, 3] - rois[:, 1]
    widths = widths.clamp(min=1)
    heights = heights.clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_keypoints = maps.shape[1]

    if torchvision._is_tracing():
        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
            maps,
            rois,
            widths_ceil,
            heights_ceil,
            widths,
            heights,
            offset_x,
            offset_y,
            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
        )
        return xy_preds.permute(0, 2, 1), end_scores

    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
    for i in range(len(rois)):
        roi_map_width = int(widths_ceil[i].item())
        roi_map_height = int(heights_ceil[i].item())
        width_correction = widths[i] / roi_map_width
        height_correction = heights[i] / roi_map_height
        roi_map = F.interpolate(
            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
        )[:, 0]
        # roi_map_probs = scores_to_probs(roi_map.copy())
        w = roi_map.shape[2]
        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

        x_int = pos % w
        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
        # assert (roi_map_probs[k, y_int, x_int] ==
        #         roi_map_probs[k, :, :].max())
        x = (x_int.float() + 0.5) * width_correction
        y = (y_int.float() + 0.5) * height_correction
        xy_preds[i, 0, :] = x + offset_x[i]
        xy_preds[i, 1, :] = y + offset_y[i]
        xy_preds[i, 2, :] = 1
        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]

    return xy_preds.permute(0, 2, 1), end_scores
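
# Decoding sketch: random heatmaps for 2 RoIs and 4 keypoints (both counts
# assumed). Returns per-keypoint (x, y, 1) image coordinates plus the peak
# activation scores.
def _demo_heatmaps_to_keypoints():
    maps = torch.randn(2, 4, 56, 56)
    rois = torch.tensor([[0.0, 0.0, 32.0, 32.0], [10.0, 5.0, 50.0, 45.0]])
    xy_preds, end_scores = heatmaps_to_keypoints(maps, rois)
    return xy_preds.shape, end_scores.shape  # (2, 4, 3), (2, 4)
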
def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    N, K, H, W = keypoint_logits.shape
    if H != W:
        raise ValueError(
            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
        )
    discretization_size = H
    heatmaps = []
    valid = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
        kp = gt_kp_in_image[midx]
        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
        heatmaps.append(heatmaps_per_image.view(-1))
        valid.append(valid_per_image.view(-1))

    keypoint_targets = torch.cat(heatmaps, dim=0)
    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
    valid = torch.where(valid)[0]

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if keypoint_targets.numel() == 0 or len(valid) == 0:
        return keypoint_logits.sum() * 0

    keypoint_logits = keypoint_logits.view(N * K, H * W)

    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
    return keypoint_loss
def keypointrcnn_inference(x, boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    kp_probs = []
    kp_scores = []

    boxes_per_image = [box.size(0) for box in boxes]
    x2 = x.split(boxes_per_image, dim=0)

    for xx, bb in zip(x2, boxes):
        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
        kp_probs.append(kp_prob)
        kp_scores.append(scores)

    return kp_probs, kp_scores
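
# Inference sketch: 3 detections split 2/1 across two images, with 17
# COCO-style keypoints and 56x56 heatmaps (both figures are assumptions).
def _demo_keypointrcnn_inference():
    x = torch.randn(3, 17, 56, 56)
    boxes = [
        torch.tensor([[0.0, 0.0, 32.0, 32.0], [5.0, 5.0, 40.0, 40.0]]),
        torch.tensor([[0.0, 0.0, 20.0, 20.0]]),
    ]
    return keypointrcnn_inference(x, boxes)  # per-image keypoints and scores
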
def _onnx_expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half = w_half.to(dtype=torch.float32) * scale
    h_half = h_half.to(dtype=torch.float32) * scale

    boxes_exp0 = x_c - w_half
    boxes_exp1 = y_c - h_half
    boxes_exp2 = x_c + w_half
    boxes_exp3 = y_c + h_half
    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
    return boxes_exp
# the next two functions should be merged inside Masker
# but are kept here for the moment while we need them
# temporarily for paste_mask_in_image
def expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    if torchvision._is_tracing():
        return _onnx_expand_boxes(boxes, scale)
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half *= scale
    h_half *= scale

    boxes_exp = torch.zeros_like(boxes)
    boxes_exp[:, 0] = x_c - w_half
    boxes_exp[:, 2] = x_c + w_half
    boxes_exp[:, 1] = y_c - h_half
    boxes_exp[:, 3] = y_c + h_half
    return boxes_exp
@torch.jit.unused
def expand_masks_tracing_scale(M, padding):
    # type: (int, int) -> float
    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)


def expand_masks(mask, padding):
    # type: (Tensor, int) -> Tuple[Tensor, float]
    M = mask.shape[-1]
    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
        scale = expand_masks_tracing_scale(M, padding)
    else:
        scale = float(M + 2 * padding) / M
    padded_mask = F.pad(mask, (padding,) * 4)
    return padded_mask, scale
def paste_mask_in_image(mask, box, im_h, im_w):
    # type: (Tensor, Tensor, int, int) -> Tensor
    TO_REMOVE = 1
    w = int(box[2] - box[0] + TO_REMOVE)
    h = int(box[3] - box[1] + TO_REMOVE)
    w = max(w, 1)
    h = max(h, 1)

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, -1, -1))

    # Resize mask
    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
    x_0 = max(box[0], 0)
    x_1 = min(box[2] + 1, im_w)
    y_0 = max(box[1], 0)
    y_1 = min(box[3] + 1, im_h)

    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
    return im_mask
def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
    one = torch.ones(1, dtype=torch.int64)
    zero = torch.zeros(1, dtype=torch.int64)

    w = box[2] - box[0] + one
    h = box[3] - box[1] + one
    w = torch.max(torch.cat((w, one)))
    h = torch.max(torch.cat((h, one)))

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))

    # Resize mask
    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))

    unpadded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]

    # TODO: replace below with a dynamic padding when support is added in ONNX

    # pad y
    zeros_y0 = torch.zeros(y_0, unpadded_im_mask.size(1))
    zeros_y1 = torch.zeros(im_h - y_1, unpadded_im_mask.size(1))
    concat_0 = torch.cat((zeros_y0, unpadded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]

    # pad x
    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
    return im_mask
@torch.jit._script_if_tracing
def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
    res_append = torch.zeros(0, im_h, im_w)
    for i in range(masks.size(0)):
        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
        mask_res = mask_res.unsqueeze(0)
        res_append = torch.cat((res_append, mask_res))
    return res_append
def paste_masks_in_image(masks, boxes, img_shape, padding=1):
    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
    masks, scale = expand_masks(masks, padding=padding)
    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
    im_h, im_w = img_shape

    if torchvision._is_tracing():
        return _onnx_paste_masks_in_image_loop(
            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
        )[:, None]
    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
    if len(res) > 0:
        ret = torch.stack(res, dim=0)[:, None]
    else:
        ret = masks.new_empty((0, 1, im_h, im_w))
    return ret
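
# Pasting sketch: two 28x28 soft masks pasted back onto a (100, 100) canvas;
# all sizes here are illustrative.
def _demo_paste_masks_in_image():
    masks = torch.rand(2, 1, 28, 28)
    boxes = torch.tensor([[10.0, 10.0, 50.0, 60.0], [0.0, 0.0, 20.0, 20.0]])
    return paste_masks_in_image(masks, boxes, (100, 100))  # (2, 1, 100, 100)
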
class RoIHeads(nn.Module):
    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        box_roi_pool,
        box_head,
        box_predictor,
        # Faster R-CNN training
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        bbox_reg_weights,
        # Faster R-CNN inference
        score_thresh,
        nms_thresh,
        detections_per_img,
        # Line
        line_roi_pool=None,
        line_head=None,
        line_predictor=None,
        # point parameters
        point_roi_pool=None,
        point_head=None,
        point_predictor=None,
        # arc parameters
        arc_roi_pool=None,
        arc_head=None,
        arc_predictor=None,
        # Mask
        mask_roi_pool=None,
        mask_head=None,
        mask_predictor=None,
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
        detect_point=True,
        detect_line=True,
        detect_arc=False,
    ):
        super().__init__()

        self.box_similarity = box_ops.box_iou
        # assign ground-truth boxes for each proposal
        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

        if bbox_reg_weights is None:
            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)

        self.box_roi_pool = box_roi_pool
        self.box_head = box_head
        self.box_predictor = box_predictor

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

        self.line_roi_pool = line_roi_pool
        self.line_head = line_head
        self.line_predictor = line_predictor

        self.point_roi_pool = point_roi_pool
        self.point_head = point_head
        self.point_predictor = point_predictor

        self.arc_roi_pool = arc_roi_pool
        self.arc_head = arc_head
        self.arc_predictor = arc_predictor

        self.mask_roi_pool = mask_roi_pool
        self.mask_head = mask_head
        self.mask_predictor = mask_predictor

        self.keypoint_roi_pool = keypoint_roi_pool
        self.keypoint_head = keypoint_head
        self.keypoint_predictor = keypoint_predictor

        self.detect_point = detect_point
        self.detect_line = detect_line
        self.detect_arc = detect_arc

        self.channel_compress = nn.Sequential(
            nn.Conv2d(256, 8, kernel_size=1),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True)
        )
    def has_mask(self):
        if self.mask_roi_pool is None:
            return False
        if self.mask_head is None:
            return False
        if self.mask_predictor is None:
            return False
        return True

    def has_keypoint(self):
        if self.keypoint_roi_pool is None:
            return False
        if self.keypoint_head is None:
            return False
        if self.keypoint_predictor is None:
            return False
        return True

    def has_line(self):
        # if self.line_roi_pool is None:
        #     return False
        if self.line_head is None:
            return False
        # if self.line_predictor is None:
        #     return False
        return True

    def has_point(self):
        # if self.point_roi_pool is None:
        #     return False
        if self.point_head is None:
            return False
        # if self.point_predictor is None:
        #     return False
        return True

    def has_arc(self):
        # if self.arc_roi_pool is None:
        #     return False
        if self.arc_head is None:
            return False
        # if self.arc_predictor is None:
        #     return False
        return True
    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
        matched_idxs = []
        labels = []
        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
            if gt_boxes_in_image.numel() == 0:
                # Background image
                device = proposals_in_image.device
                clamped_matched_idxs_in_image = torch.zeros(
                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
                )
                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
            else:
                # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)

                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)

                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
                labels_in_image = labels_in_image.to(dtype=torch.int64)

                # Label background (below the low threshold)
                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_in_image[bg_inds] = 0

                # Label ignore proposals (between low and high thresholds)
                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler

            matched_idxs.append(clamped_matched_idxs_in_image)
            labels.append(labels_in_image)
        return matched_idxs, labels
    def subsample(self, labels):
        # type: (List[Tensor]) -> List[Tensor]
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_inds = []
        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
            sampled_inds.append(img_sampled_inds)
        return sampled_inds

    def add_gt_proposals(self, proposals, gt_boxes):
        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
        return proposals
    def check_targets(self, targets):
        # type: (Optional[List[Dict[str, Tensor]]]) -> None
        if targets is None:
            raise ValueError("targets should not be None")
        if not all(["boxes" in t for t in targets]):
            raise ValueError("Every element of targets should have a boxes key")
        if not all(["labels" in t for t in targets]):
            raise ValueError("Every element of targets should have a labels key")
        if self.has_mask():
            if not all(["masks" in t for t in targets]):
                raise ValueError("Every element of targets should have a masks key")
    def select_training_samples(
        self,
        proposals,  # type: List[Tensor]
        targets,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
        self.check_targets(targets)
        if targets is None:
            raise ValueError("targets should not be None")
        dtype = proposals[0].dtype
        device = proposals[0].device

        gt_boxes = [t["boxes"].to(dtype) for t in targets]
        gt_labels = [t["labels"] for t in targets]

        # append ground-truth bboxes to proposals
        proposals = self.add_gt_proposals(proposals, gt_boxes)

        # get matching gt indices for each proposal
        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
        # sample a fixed proportion of positive-negative proposals
        sampled_inds = self.subsample(labels)
        matched_gt_boxes = []
        num_images = len(proposals)
        for img_id in range(num_images):
            img_sampled_inds = sampled_inds[img_id]
            proposals[img_id] = proposals[img_id][img_sampled_inds]
            labels[img_id] = labels[img_id][img_sampled_inds]
            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]

            gt_boxes_in_image = gt_boxes[img_id]
            if gt_boxes_in_image.numel() == 0:
                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])

        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
        return proposals, matched_idxs, labels, regression_targets
    def postprocess_detections(
        self,
        class_logits,  # type: Tensor
        box_regression,  # type: Tensor
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
        device = class_logits.device
        num_classes = class_logits.shape[-1]

        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
        pred_boxes = self.box_coder.decode(box_regression, proposals)

        pred_scores = F.softmax(class_logits, -1)

        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
        pred_scores_list = pred_scores.split(boxes_per_image, 0)

        all_boxes = []
        all_scores = []
        all_labels = []
        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            # create labels for each prediction
            labels = torch.arange(num_classes, device=device)
            labels = labels.view(1, -1).expand_as(scores)

            # remove predictions with the background label
            boxes = boxes[:, 1:]
            scores = scores[:, 1:]
            labels = labels[:, 1:]

            # batch everything, by making every class prediction be a separate instance
            boxes = boxes.reshape(-1, 4)
            scores = scores.reshape(-1)
            labels = labels.reshape(-1)

            # remove low scoring boxes
            inds = torch.where(scores > self.score_thresh)[0]
            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]

            # remove empty boxes
            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            # non-maximum suppression, independently done per class
            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
            # keep only topk scoring predictions
            keep = keep[: self.detections_per_img]
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            all_boxes.append(boxes)
            all_scores.append(scores)
            all_labels.append(labels)

        return all_boxes, all_scores, all_labels
    def forward(
        self,
        features,  # type: Dict[str, Tensor]
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
        """
        Args:
            features (List[Tensor])
            proposals (List[Tensor[N, 4]])
            image_shapes (List[Tuple[H, W]])
            targets (List[Dict])
        """
        print(f'roihead forward!!!')
        if targets is not None:
            for t in targets:
                # TODO: https://github.com/pytorch/pytorch/issues/26731
                floating_point_types = (torch.float, torch.double, torch.half)
                if not t["boxes"].dtype in floating_point_types:
                    raise TypeError(f"target boxes must be of float type, instead got {t['boxes'].dtype}")
                if not t["labels"].dtype == torch.int64:
                    raise TypeError(f"target labels must be of int64 type, instead got {t['labels'].dtype}")
                if self.has_keypoint():
                    if not t["keypoints"].dtype == torch.float32:
                        raise TypeError(f"target keypoints must be of float type, instead got {t['keypoints'].dtype}")

        if self.training:
            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
        else:
            if targets is not None:
                proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
            else:
                labels = None
                regression_targets = None
                matched_idxs = None

        device = features['0'].device
        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression = self.box_predictor(box_features)

        result: List[Dict[str, torch.Tensor]] = []
        losses = {}
        # _, C, H, W = features['0'].shape  # ignore batch_size; we only care about C, H, W
        if self.training:
            if labels is None:
                raise ValueError("labels cannot be None")
            if regression_targets is None:
                raise ValueError("regression_targets cannot be None")
            print(f'boxes compute losses')
            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
        else:
            if targets is not None:
                loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
                losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals,
                                                                image_shapes)
            num_images = len(boxes)
            for i in range(num_images):
                result.append(
                    {
                        "boxes": boxes[i],
                        "labels": labels[i],
                        "scores": scores[i],
                    }
                )
        if self.has_line() and self.detect_line:
            print(f'roi_heads forward has_line()!!!!')
            # print(f'labels:{labels}')
            line_proposals = [p["boxes"] for p in result]
            point_proposals = [p["boxes"] for p in result]
            print(f'boxes_proposals:{len(line_proposals)}')
            # if line_proposals is None or len(line_proposals) == 0:
            #     # return empty features or skip this branch entirely
            #     return torch.empty(0, C, H, W).to(features['0'].device)
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                print(f'num_images:{num_images}')
                line_proposals = []
                point_proposals = []
                arc_proposals = []
                pos_matched_idxs = []
                line_pos_matched_idxs = []
                point_pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    line_pos = torch.where(labels[img_id] == 2)[0]
                    # point_pos = torch.where(labels[img_id] == 1)[0]
                    line_proposals.append(proposals[img_id][line_pos])
                    # point_proposals.append(proposals[img_id][point_pos])
                    line_pos_matched_idxs.append(matched_idxs[img_id][line_pos])
                    # point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
                    # pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                if targets is not None:
                    pos_matched_idxs = []
                    num_images = len(proposals)
                    line_proposals = []
                    line_pos_matched_idxs = []
                    print(f'val num_images:{num_images}')
                    if matched_idxs is None:
                        raise ValueError("if in training, matched_idxs should not be None")
                    for img_id in range(num_images):
                        # pos = torch.where(labels[img_id] > 0)[0]
                        line_pos = torch.where(labels[img_id] == 2)[0]
                        line_proposals.append(proposals[img_id][line_pos])
                        line_pos_matched_idxs.append(matched_idxs[img_id][line_pos])
                else:
                    pos_matched_idxs = None

            feature_logits = self.line_forward3(features, image_shapes, line_proposals)

            loss_line = None
            loss_line_iou = None
            if self.training:
                if targets is None or pos_matched_idxs is None:
                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
                gt_lines = [t["lines"] for t in targets if "lines" in t]
                # print(f'gt_lines:{gt_lines[0].shape}')
                h, w = targets[0]["img_size"]
                img_size = h
                gt_lines_tensor = torch.zeros(0, 0)
                if len(gt_lines) > 0:
                    gt_lines_tensor = torch.cat(gt_lines)
                    print(f'gt_lines_tensor:{gt_lines_tensor.shape}')
                if gt_lines_tensor.shape[0] > 0:
                    print(f'start to lines_point_pair_loss')
                    loss_line = lines_point_pair_loss(
                        feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
                    )
                    loss_line_iou = line_iou_loss(feature_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
                if loss_line is None:
                    print(f'loss_line is None111')
                    loss_line = torch.tensor(0.0, device=device)
                if loss_line_iou is None:
                    print(f'loss_line_iou is None111')
                    loss_line_iou = torch.tensor(0.0, device=device)
                loss_line = {"loss_line": loss_line}
                loss_line_iou = {'loss_line_iou': loss_line_iou}
            else:
                if targets is not None:
                    h, w = targets[0]["img_size"]
                    img_size = h
                    gt_lines = [t["lines"] for t in targets if "lines" in t]
                    gt_lines_tensor = torch.zeros(0, 0)
                    if len(gt_lines) > 0:
                        gt_lines_tensor = torch.cat(gt_lines)
                    if gt_lines_tensor.shape[0] > 0 and feature_logits is not None:
                        loss_line = lines_point_pair_loss(
                            feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
                        )
                        print(f'compute_line_loss:{loss_line}')
                        loss_line_iou = line_iou_loss(feature_logits, line_proposals, gt_lines, line_pos_matched_idxs,
                                                      img_size)
                    if loss_line is None:
                        print(f'loss_line is None')
                        loss_line = torch.tensor(0.0, device=device)
                    if loss_line_iou is None:
                        print(f'loss_line_iou is None')
                        loss_line_iou = torch.tensor(0.0, device=device)
                    loss_line = {"loss_line": loss_line}
                    loss_line_iou = {'loss_line_iou': loss_line_iou}
                else:
                    loss_line = {}
                    loss_line_iou = {}
                if feature_logits is None or line_proposals is None:
                    raise ValueError(
                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                    )
                if feature_logits is not None:
                    lines_probs, lines_scores = line_inference(feature_logits, line_proposals)
                    for keypoint_prob, kps, r in zip(lines_probs, lines_scores, result):
                        r["lines"] = keypoint_prob
                        r["lines_scores"] = kps
            print(f'loss_line11111:{loss_line}')
            losses.update(loss_line)
            losses.update(loss_line_iou)
            print(f'losses:{losses}')
        if self.has_point() and self.detect_point:
            print(f'roi_heads forward has_point()!!!!')
            # print(f'labels:{labels}')
            point_proposals = [p["boxes"] for p in result]
            print(f'boxes_proposals:{len(point_proposals)}')
            # if point_proposals is None or len(point_proposals) == 0:
            #     # return empty features or skip this branch entirely
            #     return torch.empty(0, C, H, W).to(features['0'].device)
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                print(f'num_images:{num_images}')
                point_proposals = []
                point_pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")
                for img_id in range(num_images):
                    point_pos = torch.where(labels[img_id] == 1)[0]
                    point_proposals.append(proposals[img_id][point_pos])
                    point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
            else:
                if targets is not None:
                    num_images = len(proposals)
                    point_proposals = []
                    point_pos_matched_idxs = []
                    print(f'val num_images:{num_images}')
                    if matched_idxs is None:
                        raise ValueError("if in training, matched_idxs should not be None")
                    for img_id in range(num_images):
                        point_pos = torch.where(labels[img_id] == 1)[0]
                        point_proposals.append(proposals[img_id][point_pos])
                        point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
                else:
                    pos_matched_idxs = None

            feature_logits = self.point_forward1(features, image_shapes, point_proposals)

            loss_point = None
            if self.training:
                if targets is None or point_pos_matched_idxs is None:
                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
                gt_points = [t["points"] for t in targets if "points" in t]
                if len(gt_points) > 0:
                    print(f'gt_points:{gt_points[0].shape}')
                h, w = targets[0]["img_size"]
                img_size = h
                gt_points_tensor = torch.zeros(0, 0)
                if len(gt_points) > 0:
                    gt_points_tensor = torch.cat(gt_points)
                    print(f'gt_points_tensor:{gt_points_tensor.shape}')
                if gt_points_tensor.shape[0] > 0:
                    print(f'start to compute point_loss')
                    loss_point = compute_point_loss(feature_logits, point_proposals, gt_points, point_pos_matched_idxs)
                if loss_point is None:
                    print(f'loss_point is None111')
                    loss_point = torch.tensor(0.0, device=device)
                loss_point = {"loss_point": loss_point}
            else:
                if targets is not None:
                    h, w = targets[0]["img_size"]
                    img_size = h
                    gt_points = [t["points"] for t in targets if "points" in t]
                    gt_points_tensor = torch.zeros(0, 0)
                    if len(gt_points) > 0:
                        gt_points_tensor = torch.cat(gt_points)
                        print(f'gt_points_tensor:{gt_points_tensor.shape}')
                    if gt_points_tensor.shape[0] > 0:
                        print(f'start to compute point_loss')
                        loss_point = compute_point_loss(feature_logits, point_proposals, gt_points,
                                                        point_pos_matched_idxs)
                    if loss_point is None:
                        print(f'loss_point is None111')
                        loss_point = torch.tensor(0.0, device=device)
                    loss_point = {"loss_point": loss_point}
                else:
                    loss_point = {}
                if feature_logits is None or point_proposals is None:
                    raise ValueError(
                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                    )
                if feature_logits is not None:
                    points_probs, points_scores = point_inference(feature_logits, point_proposals)
                    for keypoint_prob, kps, r in zip(points_probs, points_scores, result):
                        r["points"] = keypoint_prob
                        r["points_scores"] = kps
            print(f'loss_point:{loss_point}')
            losses.update(loss_point)
            print(f'losses:{losses}')
        if self.has_arc() and self.detect_arc:
            print(f'roi_heads forward has_arc()!!!!')
            # print(f'labels:{labels}')
            arc_proposals = [p["boxes"] for p in result]
            print(f'boxes_proposals:{len(arc_proposals)}')
            # if arc_proposals is None or len(arc_proposals) == 0:
            #     # return empty features or skip this branch entirely
            #     return torch.empty(0, C, H, W).to(features['0'].device)
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                print(f'num_images:{num_images}')
                arc_proposals = []
                arc_pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")
                for img_id in range(num_images):
                    arc_pos = torch.where(labels[img_id] == 3)[0]
                    arc_proposals.append(proposals[img_id][arc_pos])
                    arc_pos_matched_idxs.append(matched_idxs[img_id][arc_pos])
            else:
                if targets is not None:
                    num_images = len(proposals)
                    arc_proposals = []
                    arc_pos_matched_idxs = []
                    print(f'val num_images:{num_images}')
                    if matched_idxs is None:
                        raise ValueError("if in training, matched_idxs should not be None")
                    for img_id in range(num_images):
                        arc_pos = torch.where(labels[img_id] == 3)[0]
                        arc_proposals.append(proposals[img_id][arc_pos])
                        arc_pos_matched_idxs.append(matched_idxs[img_id][arc_pos])
                else:
                    arc_pos_matched_idxs = None

            feature_logits = self.arc_forward1(features, image_shapes, arc_proposals)

            loss_arc = None
            if self.training:
                if targets is None or arc_pos_matched_idxs is None:
                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
                gt_arcs = [t["arc_mask"] for t in targets if "arc_mask" in t]
                if len(gt_arcs) > 0:
                    print(f'gt_arcs:{gt_arcs[0].shape}')
                h, w = targets[0]["img_size"]
                img_size = h
                # gt_arcs_tensor = torch.zeros(0, 0)
                # if len(gt_arcs) > 0:
                #     gt_arcs_tensor = torch.cat(gt_arcs)
                #     print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
                #
                # if gt_arcs_tensor.shape[0] > 0:
                #     print(f'start to compute arc_loss')
                if len(gt_arcs) > 0 and feature_logits is not None:
                    loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
                if loss_arc is None:
                    print(f'loss_arc is None111')
                    loss_arc = torch.tensor(0.0, device=device)
                loss_arc = {"loss_arc": loss_arc}
            else:
                if targets is not None:
                    h, w = targets[0]["img_size"]
                    img_size = h
                    gt_arcs = [t["arc_mask"] for t in targets if "arc_mask" in t]
                    if len(gt_arcs) > 0:
                        print(f'gt_arcs:{gt_arcs[0].shape}')
                    # gt_arcs_tensor = torch.zeros(0, 0)
                    # if len(gt_arcs) > 0:
                    #     gt_arcs_tensor = torch.cat(gt_arcs)
                    #     print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
                    # if gt_arcs_tensor.shape[0] > 0 and feature_logits is not None:
                    #     print(f'start to compute arc_loss')
                    if len(gt_arcs) > 0 and feature_logits is not None:
                        print(f'start to compute arc_loss')
                        loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
                    if loss_arc is None:
                        print(f'loss_arc is None111')
                        loss_arc = torch.tensor(0.0, device=device)
                    loss_arc = {"loss_arc": loss_arc}
                else:
                    loss_arc = {}
                if feature_logits is None or arc_proposals is None:
                    raise ValueError(
                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                    )
                if feature_logits is not None:
                    arcs_probs, arcs_scores = arc_inference(feature_logits, arc_proposals)
                    for keypoint_prob, kps, r in zip(arcs_probs, arcs_scores, result):
                        r["arcs"] = keypoint_prob
                        r["arcs_scores"] = kps
            # print(f'loss_arc:{loss_arc}')
            losses.update(loss_arc)
            print(f'losses:{losses}')
        if self.has_mask():
            mask_proposals = [p["boxes"] for p in result]
            if self.training:
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")
                # during training, only focus on positive boxes
                num_images = len(proposals)
                mask_proposals = []
                pos_matched_idxs = []
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    mask_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None

            if self.mask_roi_pool is not None:
                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
                mask_features = self.mask_head(mask_features)
                mask_logits = self.mask_predictor(mask_features)
            else:
                raise Exception("Expected mask_roi_pool to be not None")

            loss_mask = {}
            if self.training:
                if targets is None or pos_matched_idxs is None or mask_logits is None:
                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")

                gt_masks = [t["masks"] for t in targets]
                gt_labels = [t["labels"] for t in targets]
                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
                loss_mask = {"loss_mask": rcnn_loss_mask}
            else:
                labels = [r["labels"] for r in result]
                masks_probs = maskrcnn_inference(mask_logits, labels)
                for mask_prob, r in zip(masks_probs, result):
                    r["masks"] = mask_prob

            losses.update(loss_mask)
        # keep none checks in if conditional so torchscript will conditionally
        # compile each branch
        if self.has_keypoint():
            keypoint_proposals = [p["boxes"] for p in result]
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                keypoint_proposals = []
                pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    keypoint_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None

            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
            keypoint_features = self.keypoint_head(keypoint_features)
            keypoint_logits = self.keypoint_predictor(keypoint_features)

            loss_keypoint = {}
            if self.training:
                if targets is None or pos_matched_idxs is None:
                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")

                gt_keypoints = [t["keypoints"] for t in targets]
                rcnn_loss_keypoint = keypointrcnn_loss(
                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
                )
                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
            else:
                if keypoint_logits is None or keypoint_proposals is None:
                    raise ValueError(
                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                    )

                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
                    r["keypoints"] = keypoint_prob
                    r["keypoints_scores"] = kps
            losses.update(loss_keypoint)

        return result, losses
    def line_forward1(self, features, image_shapes, line_proposals):
        print(f'line_proposals:{len(line_proposals)}')
        # cs_features = features['0']
        # print(f"features-0:{features['0'].shape}")
        cs_features = self.channel_compress(features['0'])
        filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
        if len(filtered_proposals) > 0:
            filtered_proposals_tensor = torch.cat(filtered_proposals)
            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
        line_proposals_tensor = torch.cat(line_proposals)
        print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
        roi_features = features_align(cs_features, line_proposals, image_shapes)
        if roi_features is not None:
            print(f'line_features from align:{roi_features.shape}')
            feature_logits = self.line_head(roi_features)
            print(f'feature_logits from line_head:{feature_logits.shape}')
            return feature_logits
    def line_forward2(self, features, image_shapes, line_proposals):
        print(f'line_proposals:{len(line_proposals)}')
        # cs_features = features['0']
        # print(f"features-0:{features['0'].shape}")
        # cs_features = self.channel_compress(features['0'])
        cs_features = features['0']
        filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
        if len(filtered_proposals) > 0:
            filtered_proposals_tensor = torch.cat(filtered_proposals)
            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
            line_proposals = filtered_proposals
        line_proposals_tensor = torch.cat(line_proposals)
        print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
        feature_logits = self.line_head(cs_features)
        print(f'feature_logits from line_head:{feature_logits.shape}')
        roi_features = features_align(feature_logits, line_proposals, image_shapes)
        if roi_features is not None:
            print(f'roi_features from align:{roi_features.shape}')
        return roi_features
    def line_forward3(self, features, image_shapes, line_proposals):
        print(f'line_proposals:{len(line_proposals)}')
        # cs_features = features['0']
        # print(f"features-0:{features['0'].shape}")
        # cs_features = self.channel_compress(features['0'])
        cs_features = features['0']
        # filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
        #
        # if len(filtered_proposals) > 0:
        #     filtered_proposals_tensor = torch.cat(filtered_proposals)
        #     print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
        #     line_proposals = filtered_proposals
        # line_proposals_tensor = torch.cat(line_proposals)
        # print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
        feature_logits = self.line_predictor(cs_features)
        print(f'feature_logits from line_predictor:{feature_logits.shape}')
        roi_features = features_align(feature_logits, line_proposals, image_shapes)
        if roi_features is not None:
            print(f'roi_features from align:{roi_features.shape}')
        return roi_features
    def point_forward1(self, features, image_shapes, proposals):
        print(f'point_proposals:{len(proposals)}')
        # cs_features = features['0']
        # print(f"features-0:{features['0'].shape}")
        # cs_features = self.channel_compress(features['0'])
        cs_features = features['0']
        # filtered_proposals = [proposal for proposal in proposals if proposal.shape[0] > 0]
        #
        # if len(filtered_proposals) > 0:
        #     filtered_proposals_tensor = torch.cat(filtered_proposals)
        #     print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
        #     proposals = filtered_proposals
        # point_proposals_tensor = torch.cat(proposals)
        # print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
        feature_logits = self.point_predictor(cs_features)
        print(f'feature_logits from point_predictor:{feature_logits.shape}')
        roi_features = features_align(feature_logits, proposals, image_shapes)
        if roi_features is not None:
            print(f'roi_features from align:{roi_features.shape}')
        return roi_features
    def arc_forward1(self, features, image_shapes, proposals):
        print(f'arc_proposals:{len(proposals)}')
        # cs_features = features['0']
        # print(f"features-0:{features['0'].shape}")
        # cs_features = self.channel_compress(features['0'])
        cs_features = features['0']
        # filtered_proposals = [proposal for proposal in proposals if proposal.shape[0] > 0]
        #
        # if len(filtered_proposals) > 0:
        #     filtered_proposals_tensor = torch.cat(filtered_proposals)
        #     print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
        #     proposals = filtered_proposals
        # arc_proposals_tensor = torch.cat(proposals)
        # print(f'arc_proposals_tensor:{arc_proposals_tensor.shape}')
        feature_logits = self.arc_predictor(cs_features)
        print(f'feature_logits from arc_predictor:{feature_logits.shape}')
        roi_features = features_align(feature_logits, proposals, image_shapes)
        if roi_features is not None:
            print(f'roi_features from align:{roi_features.shape}')
        return roi_features