# loi_heads.py

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torchvision
# from scipy.optimize import linear_sum_assignment
from torch import nn, Tensor

import libs.vision_libs.models.detection._utils as det_utils
from libs.vision_libs.ops import boxes as box_ops, roi_align
from models.line_detect.heads.head_losses import point_inference, compute_point_loss, line_iou_loss, \
    lines_point_pair_loss, features_align, line_inference


def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Computes the loss for Faster R-CNN.

    Args:
        class_logits (Tensor): predicted class scores, one row per proposal
        box_regression (Tensor): predicted box deltas, 4 values per class per proposal
        labels (list[Tensor]): ground-truth class label for each sampled proposal
        regression_targets (list[Tensor]): encoded box regression targets

    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
    """
    # print(f'compute fastrcnn_loss:{labels}')
    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)
    classification_loss = F.cross_entropy(class_logits, labels)

    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
    labels_pos = labels[sampled_pos_inds_subset]
    N, num_classes = class_logits.shape
    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

    box_loss = F.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        reduction="sum",
    )
    box_loss = box_loss / labels.numel()
    return classification_loss, box_loss
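

# A minimal shape sketch for fastrcnn_loss (illustrative only; the tensors
# below are random stand-ins, not real network outputs):
#
#     num_props, num_classes = 8, 3
#     class_logits = torch.randn(num_props, num_classes)
#     box_regression = torch.randn(num_props, num_classes * 4)
#     labels = [torch.randint(0, num_classes, (num_props,))]
#     regression_targets = [torch.randn(num_props, 4)]
#     cls_loss, box_loss = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
#
# Both returned values are scalar tensors; box_loss is averaged over all
# sampled proposals, not just the positives.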


def maskrcnn_inference(x, labels):
    # type: (Tensor, List[Tensor]) -> List[Tensor]
    """
    From the results of the CNN, post process the masks
    by taking the mask corresponding to the class with max
    probability (which are of fixed size and directly output
    by the CNN).

    Args:
        x (Tensor): the mask logits
        labels (list[Tensor]): predicted class labels, used as a
            reference, one tensor per image

    Returns:
        results (list[Tensor]): one tensor of per-instance mask
            probabilities for each image
    """
    mask_prob = x.sigmoid()
    # select masks corresponding to the predicted classes
    num_masks = x.shape[0]
    boxes_per_image = [label.shape[0] for label in labels]
    labels = torch.cat(labels)
    index = torch.arange(num_masks, device=labels.device)
    mask_prob = mask_prob[index, labels][:, None]
    mask_prob = mask_prob.split(boxes_per_image, dim=0)
    return mask_prob
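

# Sketch of the per-class selection above: with N predicted instances and C
# classes, x has shape [N, C, M, M]; indexing with (arange(N), labels) keeps
# exactly one [M, M] probability map per instance, e.g. (hypothetical shapes):
#
#     x = torch.randn(5, 3, 28, 28)          # 5 instances, 3 classes
#     labels = [torch.tensor([1, 2]), torch.tensor([0, 1, 2])]
#     probs = maskrcnn_inference(x, labels)
#     # probs[0].shape == (2, 1, 28, 28); probs[1].shape == (3, 1, 28, 28)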


def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
    # type: (Tensor, Tensor, Tensor, int) -> Tensor
    """
    Given segmentation masks and the bounding boxes corresponding
    to the location of the masks in the image, this function
    crops and resizes the masks in the position defined by the
    boxes. This prepares the masks to be fed to the loss
    computation as the targets.
    """
    matched_idxs = matched_idxs.to(boxes)
    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
    gt_masks = gt_masks[:, None].to(rois)
    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]


def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    """
    Args:
        mask_logits (Tensor)
        proposals (list[Tensor])
        gt_masks (list[Tensor])
        gt_labels (list[Tensor])
        mask_matched_idxs (list[Tensor])

    Returns:
        mask_loss (Tensor): scalar tensor containing the loss
    """
    discretization_size = mask_logits.shape[-1]
    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
    mask_targets = [
        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
    ]
    labels = torch.cat(labels, dim=0)
    mask_targets = torch.cat(mask_targets, dim=0)

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    mask_loss = F.binary_cross_entropy_with_logits(
        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
    )
    return mask_loss


def keypoints_to_heatmap(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]
    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])

    offset_x = offset_x[:, None]
    offset_y = offset_y[:, None]
    scale_x = scale_x[:, None]
    scale_y = scale_y[:, None]

    x = keypoints[..., 0]
    y = keypoints[..., 1]

    x_boundary_inds = x == rois[:, 2][:, None]
    y_boundary_inds = y == rois[:, 3][:, None]

    x = (x - offset_x) * scale_x
    x = x.floor().long()
    y = (y - offset_y) * scale_y
    y = y.floor().long()

    x[x_boundary_inds] = heatmap_size - 1
    y[y_boundary_inds] = heatmap_size - 1

    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
    vis = keypoints[..., 2] > 0
    valid = (valid_loc & vis).long()

    lin_ind = y * heatmap_size + x
    heatmaps = lin_ind * valid
    return heatmaps, valid
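

# Worked example of the discretization above (hypothetical numbers): for an
# roi (x0=10, y0=20, x1=60, y1=70) and heatmap_size=56, a keypoint at
# (35.0, 45.0) maps to x = floor((35 - 10) * 56 / 50) = 28 and
# y = floor((45 - 20) * 56 / 50) = 28, giving the flat target
# lin_ind = 28 * 56 + 28 = 1596 used by the cross-entropy in keypointrcnn_loss.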


def _onnx_heatmaps_to_keypoints(
    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
):
    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)

    width_correction = widths_i / roi_map_width
    height_correction = heights_i / roi_map_height

    roi_map = F.interpolate(
        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
    )[:, 0]

    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

    x_int = pos % w
    y_int = (pos - x_int) // w

    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
        dtype=torch.float32
    )
    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
        dtype=torch.float32
    )

    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
    xy_preds_i = torch.stack(
        [
            xy_preds_i_0.to(dtype=torch.float32),
            xy_preds_i_1.to(dtype=torch.float32),
            xy_preds_i_2.to(dtype=torch.float32),
        ],
        0,
    )

    # TODO: simplify when indexing without rank will be supported by ONNX
    base = num_keypoints * num_keypoints + num_keypoints + 1
    ind = torch.arange(num_keypoints)
    ind = ind.to(dtype=torch.int64) * base
    end_scores_i = (
        roi_map.index_select(1, y_int.to(dtype=torch.int64))
        .index_select(2, x_int.to(dtype=torch.int64))
        .view(-1)
        .index_select(0, ind.to(dtype=torch.int64))
    )
    return xy_preds_i, end_scores_i


@torch.jit._script_if_tracing
def _onnx_heatmaps_to_keypoints_loop(
    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
):
    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
    for i in range(int(rois.size(0))):
        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
        )
        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
        end_scores = torch.cat(
            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
        )
    return xy_preds, end_scores


def heatmaps_to_keypoints(maps, rois):
    """Extract predicted keypoint locations from heatmaps.

    Returns a tensor of shape (#rois, #keypoints, 3) holding (x, y, visibility)
    for each keypoint, plus a (#rois, #keypoints) tensor of heatmap scores.
    """
    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
    # consistency with keypoints_to_heatmap_labels by using the conversion from
    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
    # continuous coordinate.
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    widths = rois[:, 2] - rois[:, 0]
    heights = rois[:, 3] - rois[:, 1]
    widths = widths.clamp(min=1)
    heights = heights.clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_keypoints = maps.shape[1]

    if torchvision._is_tracing():
        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
            maps,
            rois,
            widths_ceil,
            heights_ceil,
            widths,
            heights,
            offset_x,
            offset_y,
            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
        )
        return xy_preds.permute(0, 2, 1), end_scores

    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
    for i in range(len(rois)):
        roi_map_width = int(widths_ceil[i].item())
        roi_map_height = int(heights_ceil[i].item())
        width_correction = widths[i] / roi_map_width
        height_correction = heights[i] / roi_map_height
        roi_map = F.interpolate(
            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
        )[:, 0]
        # roi_map_probs = scores_to_probs(roi_map.copy())
        w = roi_map.shape[2]
        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

        x_int = pos % w
        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
        # assert (roi_map_probs[k, y_int, x_int] ==
        #         roi_map_probs[k, :, :].max())
        x = (x_int.float() + 0.5) * width_correction
        y = (y_int.float() + 0.5) * height_correction
        xy_preds[i, 0, :] = x + offset_x[i]
        xy_preds[i, 1, :] = y + offset_y[i]
        xy_preds[i, 2, :] = 1
        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]

    return xy_preds.permute(0, 2, 1), end_scores
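

# heatmaps_to_keypoints inverts keypoints_to_heatmap: the per-keypoint heatmap
# argmax is mapped back to continuous image coordinates with the Heckbert +0.5
# offset. A minimal sketch (hypothetical tensors):
#
#     maps = torch.randn(2, 17, 56, 56)            # 2 rois, 17 keypoints
#     rois = torch.tensor([[10., 20., 60., 70.],
#                          [ 0.,  0., 30., 40.]])
#     xy, scores = heatmaps_to_keypoints(maps, rois)
#     # xy.shape == (2, 17, 3); scores.shape == (2, 17)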


def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    N, K, H, W = keypoint_logits.shape
    if H != W:
        raise ValueError(
            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
        )
    discretization_size = H
    heatmaps = []
    valid = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
        kp = gt_kp_in_image[midx]
        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
        heatmaps.append(heatmaps_per_image.view(-1))
        valid.append(valid_per_image.view(-1))
    keypoint_targets = torch.cat(heatmaps, dim=0)
    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
    valid = torch.where(valid)[0]

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if keypoint_targets.numel() == 0 or len(valid) == 0:
        return keypoint_logits.sum() * 0

    keypoint_logits = keypoint_logits.view(N * K, H * W)
    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
    return keypoint_loss


def keypointrcnn_inference(x, boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    kp_probs = []
    kp_scores = []
    boxes_per_image = [box.size(0) for box in boxes]
    x2 = x.split(boxes_per_image, dim=0)
    for xx, bb in zip(x2, boxes):
        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
        kp_probs.append(kp_prob)
        kp_scores.append(scores)
    return kp_probs, kp_scores


def _onnx_expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half = w_half.to(dtype=torch.float32) * scale
    h_half = h_half.to(dtype=torch.float32) * scale

    boxes_exp0 = x_c - w_half
    boxes_exp1 = y_c - h_half
    boxes_exp2 = x_c + w_half
    boxes_exp3 = y_c + h_half
    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
    return boxes_exp


# the next two functions should be merged inside Masker
# but are kept here for the moment while we need them
# temporarily for paste_mask_in_image
def expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    if torchvision._is_tracing():
        return _onnx_expand_boxes(boxes, scale)
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half *= scale
    h_half *= scale

    boxes_exp = torch.zeros_like(boxes)
    boxes_exp[:, 0] = x_c - w_half
    boxes_exp[:, 2] = x_c + w_half
    boxes_exp[:, 1] = y_c - h_half
    boxes_exp[:, 3] = y_c + h_half
    return boxes_exp
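

# expand_boxes scales a box about its center: a box (0, 0, 10, 10) with
# scale 1.2 becomes (-1, -1, 11, 11), since the half-width 5 grows to 6 on
# each side of the center (5, 5). The scale used by paste_masks_in_image
# comes from expand_masks below, so the pasted mask and the padded heatmap
# stay aligned.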


@torch.jit.unused
def expand_masks_tracing_scale(M, padding):
    # type: (int, int) -> float
    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)


def expand_masks(mask, padding):
    # type: (Tensor, int) -> Tuple[Tensor, float]
    M = mask.shape[-1]
    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
        scale = expand_masks_tracing_scale(M, padding)
    else:
        scale = float(M + 2 * padding) / M
    padded_mask = F.pad(mask, (padding,) * 4)
    return padded_mask, scale
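

# The scale returned by expand_masks is (M + 2 * padding) / M, so for the
# default padding=1 and a 28x28 mask it is 30 / 28 ≈ 1.071; expand_boxes then
# widens each box by the same factor, so the one-pixel border added to the
# padded mask maps outside the original box instead of shrinking it.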


def paste_mask_in_image(mask, box, im_h, im_w):
    # type: (Tensor, Tensor, int, int) -> Tensor
    TO_REMOVE = 1
    w = int(box[2] - box[0] + TO_REMOVE)
    h = int(box[3] - box[1] + TO_REMOVE)
    w = max(w, 1)
    h = max(h, 1)

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, -1, -1))

    # Resize mask
    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
    x_0 = max(box[0], 0)
    x_1 = min(box[2] + 1, im_w)
    y_0 = max(box[1], 0)
    y_1 = min(box[3] + 1, im_h)

    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
    return im_mask


def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
    one = torch.ones(1, dtype=torch.int64)
    zero = torch.zeros(1, dtype=torch.int64)

    w = box[2] - box[0] + one
    h = box[3] - box[1] + one
    w = torch.max(torch.cat((w, one)))
    h = torch.max(torch.cat((h, one)))

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))

    # Resize mask
    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))

    unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]

    # TODO : replace below with a dynamic padding when support is added in ONNX
    # pad y
    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
    # pad x
    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
    return im_mask


@torch.jit._script_if_tracing
def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
    res_append = torch.zeros(0, im_h, im_w)
    for i in range(masks.size(0)):
        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
        mask_res = mask_res.unsqueeze(0)
        res_append = torch.cat((res_append, mask_res))
    return res_append


def paste_masks_in_image(masks, boxes, img_shape, padding=1):
    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
    masks, scale = expand_masks(masks, padding=padding)
    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
    im_h, im_w = img_shape

    if torchvision._is_tracing():
        return _onnx_paste_masks_in_image_loop(
            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
        )[:, None]
    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
    if len(res) > 0:
        ret = torch.stack(res, dim=0)[:, None]
    else:
        ret = masks.new_empty((0, 1, im_h, im_w))
    return ret
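

# End-to-end sketch of the pasting pipeline (hypothetical tensors): fixed-size
# mask probabilities from maskrcnn_inference are padded, their boxes expanded
# by the matching scale, and each mask is resized and written into a
# full-image canvas:
#
#     masks = torch.rand(3, 1, 28, 28)                 # 3 instances
#     boxes = torch.tensor([[ 10.,  10.,  50.,  60.],
#                           [100., 120., 180., 200.],
#                           [  0.,   0.,  30.,  25.]])
#     full = paste_masks_in_image(masks, boxes, img_shape=(480, 640))
#     # full.shape == (3, 1, 480, 640)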


class RoIHeads(nn.Module):
    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        box_roi_pool,
        box_head,
        box_predictor,
        # Faster R-CNN training
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        bbox_reg_weights,
        # Faster R-CNN inference
        score_thresh,
        nms_thresh,
        detections_per_img,
        # Line
        line_roi_pool=None,
        line_head=None,
        line_predictor=None,
        # Point
        point_roi_pool=None,
        point_head=None,
        point_predictor=None,
        # Mask
        mask_roi_pool=None,
        mask_head=None,
        mask_predictor=None,
        # Keypoint
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
        detect_point=True,
        detect_line=True,
        detect_arc=False,
    ):
        super().__init__()

        self.box_similarity = box_ops.box_iou
        # assign ground-truth boxes for each proposal
        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

        if bbox_reg_weights is None:
            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)

        self.box_roi_pool = box_roi_pool
        self.box_head = box_head
        self.box_predictor = box_predictor

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

        self.line_roi_pool = line_roi_pool
        self.line_head = line_head
        self.line_predictor = line_predictor

        self.point_roi_pool = point_roi_pool
        self.point_head = point_head
        self.point_predictor = point_predictor

        self.mask_roi_pool = mask_roi_pool
        self.mask_head = mask_head
        self.mask_predictor = mask_predictor

        self.keypoint_roi_pool = keypoint_roi_pool
        self.keypoint_head = keypoint_head
        self.keypoint_predictor = keypoint_predictor

        self.detect_point = detect_point
        self.detect_line = detect_line
        self.detect_arc = detect_arc

        self.channel_compress = nn.Sequential(
            nn.Conv2d(256, 8, kernel_size=1),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
        )

    def has_mask(self):
        if self.mask_roi_pool is None:
            return False
        if self.mask_head is None:
            return False
        if self.mask_predictor is None:
            return False
        return True

    def has_keypoint(self):
        if self.keypoint_roi_pool is None:
            return False
        if self.keypoint_head is None:
            return False
        if self.keypoint_predictor is None:
            return False
        return True

    def has_line(self):
        # only line_head is required; line_roi_pool and line_predictor are
        # optional, depending on which line_forward* variant is used
        if self.line_head is None:
            return False
        return True

    def has_point(self):
        # only point_head is required; the roi pool and predictor checks are
        # intentionally skipped
        if self.point_head is None:
            return False
        return True

    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
        matched_idxs = []
        labels = []
        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
            if gt_boxes_in_image.numel() == 0:
                # Background image
                device = proposals_in_image.device
                clamped_matched_idxs_in_image = torch.zeros(
                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
                )
                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
            else:
                # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)

                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)

                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
                labels_in_image = labels_in_image.to(dtype=torch.int64)

                # Label background (below the low threshold)
                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_in_image[bg_inds] = 0

                # Label ignore proposals (between low and high thresholds)
                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler

            matched_idxs.append(clamped_matched_idxs_in_image)
            labels.append(labels_in_image)
        return matched_idxs, labels
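
    # Matching sketch: with fg_iou_thresh=0.5 and bg_iou_thresh=0.3
    # (hypothetical values), a proposal whose best IoU against any ground-truth
    # box is >= 0.5 keeps that box's label, one below 0.3 becomes background
    # (label 0), and one in [0.3, 0.5) is marked -1 so the sampler ignores it.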

    def subsample(self, labels):
        # type: (List[Tensor]) -> List[Tensor]
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_inds = []
        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
            sampled_inds.append(img_sampled_inds)
        return sampled_inds

    def add_gt_proposals(self, proposals, gt_boxes):
        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
        return proposals

    def check_targets(self, targets):
        # type: (Optional[List[Dict[str, Tensor]]]) -> None
        if targets is None:
            raise ValueError("targets should not be None")
        if not all(["boxes" in t for t in targets]):
            raise ValueError("Every element of targets should have a boxes key")
        if not all(["labels" in t for t in targets]):
            raise ValueError("Every element of targets should have a labels key")
        if self.has_mask():
            if not all(["masks" in t for t in targets]):
                raise ValueError("Every element of targets should have a masks key")

    def select_training_samples(
        self,
        proposals,  # type: List[Tensor]
        targets,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
        self.check_targets(targets)
        if targets is None:
            raise ValueError("targets should not be None")
        dtype = proposals[0].dtype
        device = proposals[0].device

        gt_boxes = [t["boxes"].to(dtype) for t in targets]
        gt_labels = [t["labels"] for t in targets]

        # append ground-truth bboxes to proposals
        proposals = self.add_gt_proposals(proposals, gt_boxes)

        # get matching gt indices for each proposal
        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
        # sample a fixed proportion of positive-negative proposals
        sampled_inds = self.subsample(labels)
        matched_gt_boxes = []
        num_images = len(proposals)
        for img_id in range(num_images):
            img_sampled_inds = sampled_inds[img_id]
            proposals[img_id] = proposals[img_id][img_sampled_inds]
            labels[img_id] = labels[img_id][img_sampled_inds]
            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]

            gt_boxes_in_image = gt_boxes[img_id]
            if gt_boxes_in_image.numel() == 0:
                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])

        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
        return proposals, matched_idxs, labels, regression_targets

    def postprocess_detections(
        self,
        class_logits,  # type: Tensor
        box_regression,  # type: Tensor
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
        device = class_logits.device
        num_classes = class_logits.shape[-1]

        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
        pred_boxes = self.box_coder.decode(box_regression, proposals)

        pred_scores = F.softmax(class_logits, -1)

        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
        pred_scores_list = pred_scores.split(boxes_per_image, 0)

        all_boxes = []
        all_scores = []
        all_labels = []
        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            # create labels for each prediction
            labels = torch.arange(num_classes, device=device)
            labels = labels.view(1, -1).expand_as(scores)

            # remove predictions with the background label
            boxes = boxes[:, 1:]
            scores = scores[:, 1:]
            labels = labels[:, 1:]

            # batch everything, by making every class prediction be a separate instance
            boxes = boxes.reshape(-1, 4)
            scores = scores.reshape(-1)
            labels = labels.reshape(-1)

            # remove low scoring boxes
            inds = torch.where(scores > self.score_thresh)[0]
            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]

            # remove empty boxes
            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            # non-maximum suppression, independently done per class
            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)

            # keep only topk scoring predictions
            keep = keep[: self.detections_per_img]
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            all_boxes.append(boxes)
            all_scores.append(scores)
            all_labels.append(labels)

        return all_boxes, all_scores, all_labels
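
    # Post-processing order matters: decode -> clip -> drop background column
    # -> flatten per-class candidates -> score threshold -> remove tiny boxes
    # -> class-aware NMS -> top-k. Flattening first means one proposal can
    # survive as several detections with different class labels, which is the
    # standard Faster R-CNN behaviour this head inherits.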

    def forward(
        self,
        features,  # type: Dict[str, Tensor]
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
        """
        Args:
            features (Dict[str, Tensor])
            proposals (List[Tensor[N, 4]])
            image_shapes (List[Tuple[H, W]])
            targets (List[Dict])
        """
        print(f'roihead forward!!!')
        if targets is not None:
            for t in targets:
                # TODO: https://github.com/pytorch/pytorch/issues/26731
                floating_point_types = (torch.float, torch.double, torch.half)
                if t["boxes"].dtype not in floating_point_types:
                    raise TypeError(f"target boxes must be of float type, instead got {t['boxes'].dtype}")
                if t["labels"].dtype != torch.int64:
                    raise TypeError(f"target labels must be of int64 type, instead got {t['labels'].dtype}")
                if self.has_keypoint():
                    if t["keypoints"].dtype != torch.float32:
                        raise TypeError(f"target keypoints must be of float type, instead got {t['keypoints'].dtype}")

        if self.training:
            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
        else:
            if targets is not None:
                # also sample proposals during validation so the losses below can be computed
                proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
            else:
                labels = None
                regression_targets = None
                matched_idxs = None

        device = features['0'].device
        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression = self.box_predictor(box_features)

        result: List[Dict[str, torch.Tensor]] = []
        losses = {}
        # _, C, H, W = features['0'].shape  # ignore batch_size; we only care about C, H, W
        if self.training:
            if labels is None:
                raise ValueError("labels cannot be None")
            if regression_targets is None:
                raise ValueError("regression_targets cannot be None")
            print(f'boxes compute losses')
            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
        else:
            if targets is not None:
                loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
                losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
            # note: this overwrites the sampled `labels` with per-detection labels
            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals,
                                                                image_shapes)
            num_images = len(boxes)
            for i in range(num_images):
                result.append(
                    {
                        "boxes": boxes[i],
                        "labels": labels[i],
                        "scores": scores[i],
                    }
                )

        if self.has_line() and self.detect_line:
            print(f'roi_heads forward has_line()!!!!')
            # print(f'labels:{labels}')
            line_proposals = [p["boxes"] for p in result]
            point_proposals = [p["boxes"] for p in result]
            print(f'boxes_proposals:{len(line_proposals)}')
            # if line_proposals is None or len(line_proposals) == 0:
            #     # return empty features or skip this part of the computation
            #     return torch.empty(0, C, H, W).to(features['0'].device)
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                print(f'num_images:{num_images}')
                line_proposals = []
                point_proposals = []
                arc_proposals = []
                pos_matched_idxs = []
                line_pos_matched_idxs = []
                point_pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    line_pos = torch.where(labels[img_id] == 2)[0]
                    # point_pos = torch.where(labels[img_id] == 1)[0]
                    line_proposals.append(proposals[img_id][line_pos])
                    # point_proposals.append(proposals[img_id][point_pos])
                    line_pos_matched_idxs.append(matched_idxs[img_id][line_pos])
                    # point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
                    # pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                if targets is not None:
                    pos_matched_idxs = []
                    num_images = len(proposals)
                    line_proposals = []
                    line_pos_matched_idxs = []
                    print(f'val num_images:{num_images}')
                    if matched_idxs is None:
                        raise ValueError("if in training, matched_idxs should not be None")
                    for img_id in range(num_images):
                        # pos = torch.where(labels[img_id] > 0)[0]
                        line_pos = torch.where(labels[img_id] == 2)[0]
                        line_proposals.append(proposals[img_id][line_pos])
                        line_pos_matched_idxs.append(matched_idxs[img_id][line_pos])
                else:
                    pos_matched_idxs = None

            feature_logits = self.line_forward3(features, image_shapes, line_proposals)

            loss_line = None
            loss_line_iou = None
            if self.training:
                if targets is None or pos_matched_idxs is None:
                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
                gt_lines = [t["lines"] for t in targets if "lines" in t]
                print(f'gt_lines:{gt_lines[0].shape}')
                h, w = targets[0]["img_size"]
                img_size = h
                gt_lines_tensor = torch.zeros(0, 0)
                if len(gt_lines) > 0:
                    gt_lines_tensor = torch.cat(gt_lines)
                    print(f'gt_lines_tensor:{gt_lines_tensor.shape}')
                if gt_lines_tensor.shape[0] > 0:
                    print(f'start to lines_point_pair_loss')
                    loss_line = lines_point_pair_loss(
                        feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
                    )
                    loss_line_iou = line_iou_loss(feature_logits, line_proposals, gt_lines, line_pos_matched_idxs,
                                                  img_size)
                if loss_line is None:
                    print(f'loss_line is None111')
                    loss_line = torch.tensor(0.0, device=device)
                if loss_line_iou is None:
                    print(f'loss_line_iou is None111')
                    loss_line_iou = torch.tensor(0.0, device=device)
                loss_line = {"loss_line": loss_line}
                loss_line_iou = {"loss_line_iou": loss_line_iou}
            else:
                if targets is not None:
                    h, w = targets[0]["img_size"]
                    img_size = h
                    gt_lines = [t["lines"] for t in targets if "lines" in t]
                    gt_lines_tensor = torch.zeros(0, 0)
                    if len(gt_lines) > 0:
                        gt_lines_tensor = torch.cat(gt_lines)
                    if gt_lines_tensor.shape[0] > 0 and feature_logits is not None:
                        loss_line = lines_point_pair_loss(
                            feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
                        )
                        print(f'compute_line_loss:{loss_line}')
                        loss_line_iou = line_iou_loss(feature_logits, line_proposals, gt_lines, line_pos_matched_idxs,
                                                      img_size)
                    if loss_line is None:
                        print(f'loss_line is None')
                        loss_line = torch.tensor(0.0, device=device)
                    if loss_line_iou is None:
                        print(f'loss_line_iou is None')
                        loss_line_iou = torch.tensor(0.0, device=device)
                    loss_line = {"loss_line": loss_line}
                    loss_line_iou = {"loss_line_iou": loss_line_iou}
                else:
                    loss_line = {}
                    loss_line_iou = {}

                if feature_logits is None or line_proposals is None:
                    raise ValueError(
                        "both feature_logits and line_proposals should not be None when not in training mode"
                    )
                if feature_logits is not None:
                    lines_probs, lines_scores = line_inference(feature_logits, line_proposals)
                    for line_prob, l_scores, r in zip(lines_probs, lines_scores, result):
                        r["lines"] = line_prob
                        r["lines_scores"] = l_scores

            print(f'loss_line11111:{loss_line}')
            losses.update(loss_line)
            losses.update(loss_line_iou)
            print(f'losses:{losses}')

        if self.has_point() and self.detect_point:
            print(f'roi_heads forward has_point()!!!!')
            # print(f'labels:{labels}')
            point_proposals = [p["boxes"] for p in result]
            print(f'boxes_proposals:{len(point_proposals)}')
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                print(f'num_images:{num_images}')
                point_proposals = []
                point_pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")
                for img_id in range(num_images):
                    point_pos = torch.where(labels[img_id] == 1)[0]
                    point_proposals.append(proposals[img_id][point_pos])
                    point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
            else:
                if targets is not None:
                    num_images = len(proposals)
                    point_proposals = []
                    point_pos_matched_idxs = []
                    print(f'val num_images:{num_images}')
                    if matched_idxs is None:
                        raise ValueError("if in training, matched_idxs should not be None")
                    for img_id in range(num_images):
                        point_pos = torch.where(labels[img_id] == 1)[0]
                        point_proposals.append(proposals[img_id][point_pos])
                        point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
                else:
                    pos_matched_idxs = None

            feature_logits = self.point_forward1(features, image_shapes, point_proposals)

            loss_point = None
            if self.training:
                if targets is None or point_pos_matched_idxs is None:
                    raise ValueError("both targets and point_pos_matched_idxs should not be None when in training mode")
                gt_points = [t["points"] for t in targets if "points" in t]
                print(f'gt_points:{gt_points[0].shape}')
                h, w = targets[0]["img_size"]
                img_size = h
                gt_points_tensor = torch.zeros(0, 0)
                if len(gt_points) > 0:
                    gt_points_tensor = torch.cat(gt_points)
                    print(f'gt_points_tensor:{gt_points_tensor.shape}')
                if gt_points_tensor.shape[0] > 0:
                    print(f'start to compute point_loss')
                    loss_point = compute_point_loss(feature_logits, point_proposals, gt_points, point_pos_matched_idxs)
                if loss_point is None:
                    print(f'loss_point is None111')
                    loss_point = torch.tensor(0.0, device=device)
                loss_point = {"loss_point": loss_point}
            else:
                if targets is not None:
                    h, w = targets[0]["img_size"]
                    img_size = h
                    gt_points = [t["points"] for t in targets if "points" in t]
                    gt_points_tensor = torch.zeros(0, 0)
                    if len(gt_points) > 0:
                        gt_points_tensor = torch.cat(gt_points)
                        print(f'gt_points_tensor:{gt_points_tensor.shape}')
                    if gt_points_tensor.shape[0] > 0:
                        print(f'start to compute point_loss')
                        loss_point = compute_point_loss(feature_logits, point_proposals, gt_points,
                                                        point_pos_matched_idxs, img_size)
                    if loss_point is None:
                        print(f'loss_point is None111')
                        loss_point = torch.tensor(0.0, device=device)
                    loss_point = {"loss_point": loss_point}
                else:
                    loss_point = {}

                if feature_logits is None or point_proposals is None:
                    raise ValueError(
                        "both feature_logits and point_proposals should not be None when not in training mode"
                    )
                if feature_logits is not None:
                    points_probs, points_scores = point_inference(feature_logits, point_proposals)
                    for point_prob, p_scores, r in zip(points_probs, points_scores, result):
                        r["points"] = point_prob
                        r["points_scores"] = p_scores

            print(f'loss_point:{loss_point}')
            losses.update(loss_point)
            print(f'losses:{losses}')

        if self.has_mask():
            mask_proposals = [p["boxes"] for p in result]
            if self.training:
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")
                # during training, only focus on positive boxes
                num_images = len(proposals)
                mask_proposals = []
                pos_matched_idxs = []
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    mask_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None

            if self.mask_roi_pool is not None:
                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
                mask_features = self.mask_head(mask_features)
                mask_logits = self.mask_predictor(mask_features)
            else:
                raise Exception("Expected mask_roi_pool to be not None")

            loss_mask = {}
            if self.training:
                if targets is None or pos_matched_idxs is None or mask_logits is None:
                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
                gt_masks = [t["masks"] for t in targets]
                gt_labels = [t["labels"] for t in targets]
                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
                loss_mask = {"loss_mask": rcnn_loss_mask}
            else:
                labels = [r["labels"] for r in result]
                masks_probs = maskrcnn_inference(mask_logits, labels)
                for mask_prob, r in zip(masks_probs, result):
                    r["masks"] = mask_prob

            losses.update(loss_mask)

        # keep none checks in if conditional so torchscript will conditionally
        # compile each branch
        if self.has_keypoint():
            keypoint_proposals = [p["boxes"] for p in result]
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                keypoint_proposals = []
                pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    keypoint_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None

            # note: this branch previously routed through self.line_roi_pool /
            # line_head / line_predictor; use the keypoint modules that
            # has_keypoint() actually checks for
            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
            keypoint_features = self.keypoint_head(keypoint_features)
            keypoint_logits = self.keypoint_predictor(keypoint_features)

            loss_keypoint = {}
            if self.training:
                if targets is None or pos_matched_idxs is None:
                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
                gt_keypoints = [t["keypoints"] for t in targets]
                rcnn_loss_keypoint = keypointrcnn_loss(
                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
                )
                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
            else:
                if keypoint_logits is None or keypoint_proposals is None:
                    raise ValueError(
                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                    )
                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
                    r["keypoints"] = keypoint_prob
                    r["keypoints_scores"] = kps

            losses.update(loss_keypoint)

        return result, losses
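
    # A minimal usage sketch for this head (hypothetical shapes and modules;
    # in the real model these come from the detector's backbone, RPN and the
    # configured box/line/point heads):
    #
    #     features = {'0': torch.randn(1, 256, 200, 272)}
    #     proposals = [torch.tensor([[10., 10., 80., 90.]])]
    #     image_shapes = [(800, 1088)]
    #     roi_heads.eval()
    #     detections, losses = roi_heads(features, proposals, image_shapes)
    #     # eval without targets: `losses` stays empty and each dict in
    #     # `detections` has "boxes"/"labels"/"scores", plus "lines"/"points"
    #     # entries when those branches are enabled.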

    def line_forward1(self, features, image_shapes, line_proposals):
        print(f'line_proposals:{len(line_proposals)}')
        # cs_features = features['0']
        print(f"features-0:{features['0'].shape}")
        cs_features = self.channel_compress(features['0'])
        filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
        if len(filtered_proposals) > 0:
            filtered_proposals_tensor = torch.cat(filtered_proposals)
            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
            line_proposals_tensor = torch.cat(line_proposals)
            print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
            roi_features = features_align(cs_features, line_proposals, image_shapes)
            if roi_features is not None:
                print(f'line_features from align:{roi_features.shape}')
                feature_logits = self.line_head(roi_features)
                print(f'feature_logits from line_head:{feature_logits.shape}')
                return feature_logits
        # implicitly returns None when there are no non-empty proposals

    def line_forward2(self, features, image_shapes, line_proposals):
        print(f'line_proposals:{len(line_proposals)}')
        print(f"features-0:{features['0'].shape}")
        # cs_features = self.channel_compress(features['0'])
        cs_features = features['0']
        filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
        if len(filtered_proposals) > 0:
            filtered_proposals_tensor = torch.cat(filtered_proposals)
            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
            line_proposals = filtered_proposals
            line_proposals_tensor = torch.cat(line_proposals)
            print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
            feature_logits = self.line_head(cs_features)
            print(f'feature_logits from line_head:{feature_logits.shape}')
            roi_features = features_align(cs_features, line_proposals, image_shapes)
            if roi_features is not None:
                print(f'roi_features from align:{roi_features.shape}')
            return roi_features

    def line_forward3(self, features, image_shapes, line_proposals):
        print(f'line_proposals:{len(line_proposals)}')
        print(f"features-0:{features['0'].shape}")
        # cs_features = self.channel_compress(features['0'])
        cs_features = features['0']
        filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
        if len(filtered_proposals) > 0:
            filtered_proposals_tensor = torch.cat(filtered_proposals)
            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
            line_proposals = filtered_proposals
            line_proposals_tensor = torch.cat(line_proposals)
            print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
            feature_logits = self.line_predictor(cs_features)
            print(f'feature_logits from line_predictor:{feature_logits.shape}')
            roi_features = features_align(cs_features, line_proposals, image_shapes)
            if roi_features is not None:
                print(f'roi_features from align:{roi_features.shape}')
            return roi_features

    def point_forward1(self, features, image_shapes, proposals):
        print(f'point_proposals:{len(proposals)}')
        print(f"features-0:{features['0'].shape}")
        # cs_features = self.channel_compress(features['0'])
        cs_features = features['0']
        filtered_proposals = [proposal for proposal in proposals if proposal.shape[0] > 0]
        if len(filtered_proposals) > 0:
            filtered_proposals_tensor = torch.cat(filtered_proposals)
            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
            proposals = filtered_proposals
            point_proposals_tensor = torch.cat(proposals)
            print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
            feature_logits = self.point_predictor(cs_features)
            print(f'feature_logits from point_predictor:{feature_logits.shape}')
            roi_features = features_align(cs_features, proposals, image_shapes)
            if roi_features is not None:
                print(f'roi_features from align:{roi_features.shape}')
            return roi_features