roi_heads.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879
  1. from typing import Dict, List, Optional, Tuple
  2. import torch
  3. import torch.nn.functional as F
  4. import torchvision
  5. from torch import nn, Tensor
  6. from torchvision.ops import boxes as box_ops, roi_align
  7. from . import _utils as det_utils
  8. def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
  9. # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
  10. """
  11. Computes the loss for Faster R-CNN.
  12. Args:
  13. class_logits (Tensor)
  14. box_regression (Tensor)
  15. labels (list[BoxList])
  16. regression_targets (Tensor)
  17. Returns:
  18. classification_loss (Tensor)
  19. box_loss (Tensor)
  20. """
  21. labels = torch.cat(labels, dim=0)
  22. regression_targets = torch.cat(regression_targets, dim=0)
  23. classification_loss = F.cross_entropy(class_logits, labels)
  24. # get indices that correspond to the regression targets for
  25. # the corresponding ground truth labels, to be used with
  26. # advanced indexing
  27. sampled_pos_inds_subset = torch.where(labels > 0)[0]
  28. labels_pos = labels[sampled_pos_inds_subset]
  29. N, num_classes = class_logits.shape
  30. box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
  31. box_loss = F.smooth_l1_loss(
  32. box_regression[sampled_pos_inds_subset, labels_pos],
  33. regression_targets[sampled_pos_inds_subset],
  34. beta=1 / 9,
  35. reduction="sum",
  36. )
  37. box_loss = box_loss / labels.numel()
  38. return classification_loss, box_loss
  39. def maskrcnn_inference(x, labels):
  40. # type: (Tensor, List[Tensor]) -> List[Tensor]
  41. """
  42. From the results of the CNN, post process the masks
  43. by taking the mask corresponding to the class with max
  44. probability (which are of fixed size and directly output
  45. by the CNN) and return the masks in the mask field of the BoxList.
  46. Args:
  47. x (Tensor): the mask logits
  48. labels (list[BoxList]): bounding boxes that are used as
  49. reference, one for ech image
  50. Returns:
  51. results (list[BoxList]): one BoxList for each image, containing
  52. the extra field mask
  53. """
  54. mask_prob = x.sigmoid()
  55. # select masks corresponding to the predicted classes
  56. num_masks = x.shape[0]
  57. boxes_per_image = [label.shape[0] for label in labels]
  58. labels = torch.cat(labels)
  59. index = torch.arange(num_masks, device=labels.device)
  60. mask_prob = mask_prob[index, labels][:, None]
  61. mask_prob = mask_prob.split(boxes_per_image, dim=0)
  62. return mask_prob
  63. def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
  64. # type: (Tensor, Tensor, Tensor, int) -> Tensor
  65. """
  66. Given segmentation masks and the bounding boxes corresponding
  67. to the location of the masks in the image, this function
  68. crops and resizes the masks in the position defined by the
  69. boxes. This prepares the masks for them to be fed to the
  70. loss computation as the targets.
  71. """
  72. matched_idxs = matched_idxs.to(boxes)
  73. rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
  74. gt_masks = gt_masks[:, None].to(rois)
  75. return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
  76. def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
  77. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  78. """
  79. Args:
  80. proposals (list[BoxList])
  81. mask_logits (Tensor)
  82. targets (list[BoxList])
  83. Return:
  84. mask_loss (Tensor): scalar tensor containing the loss
  85. """
  86. discretization_size = mask_logits.shape[-1]
  87. labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
  88. mask_targets = [
  89. project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
  90. ]
  91. labels = torch.cat(labels, dim=0)
  92. mask_targets = torch.cat(mask_targets, dim=0)
  93. # torch.mean (in binary_cross_entropy_with_logits) doesn't
  94. # accept empty tensors, so handle it separately
  95. if mask_targets.numel() == 0:
  96. return mask_logits.sum() * 0
  97. mask_loss = F.binary_cross_entropy_with_logits(
  98. mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
  99. )
  100. return mask_loss
  101. def keypoints_to_heatmap(keypoints, rois, heatmap_size):
  102. # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
  103. offset_x = rois[:, 0]
  104. offset_y = rois[:, 1]
  105. scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
  106. scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
  107. offset_x = offset_x[:, None]
  108. offset_y = offset_y[:, None]
  109. scale_x = scale_x[:, None]
  110. scale_y = scale_y[:, None]
  111. x = keypoints[..., 0]
  112. y = keypoints[..., 1]
  113. x_boundary_inds = x == rois[:, 2][:, None]
  114. y_boundary_inds = y == rois[:, 3][:, None]
  115. x = (x - offset_x) * scale_x
  116. x = x.floor().long()
  117. y = (y - offset_y) * scale_y
  118. y = y.floor().long()
  119. x[x_boundary_inds] = heatmap_size - 1
  120. y[y_boundary_inds] = heatmap_size - 1
  121. valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
  122. vis = keypoints[..., 2] > 0
  123. valid = (valid_loc & vis).long()
  124. lin_ind = y * heatmap_size + x
  125. heatmaps = lin_ind * valid
  126. return heatmaps, valid
def _onnx_heatmaps_to_keypoints(
    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
):
    """
    Tracing/ONNX counterpart of one iteration of the eager loop in heatmaps_to_keypoints.

    The heatmaps of a single roi are upsampled bicubically to the roi's (ceiled) size. The
    argmax cell of each keypoint is taken and mapped back to image coordinates with the
    "d + 0.5" discrete-to-continuous convention. Every step is written with tensor ops and
    explicit float32/int64 casts, so the traced graph stays exportable.

    In this file it is called from _onnx_heatmaps_to_keypoints_loop with
    (maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i],
    offset_x[i], offset_y[i]).

    Args:
        maps: full heatmap batch; only maps.size(1) (number of keypoints) is read.
        maps_i: heatmaps of one roi; maps_i[:, None] is fed to a bicubic interpolate,
            so presumably (K, H, W) — TODO confirm.
        roi_map_width, roi_map_height: upsampling size (ceiled roi width/height).
        widths_i, heights_i: un-ceiled roi width/height, used to correct for the ceil.
        offset_x_i, offset_y_i: roi top-left corner.

    Returns:
        xy_preds_i (Tensor): float32, shape (3, K); rows are x, y and a constant 1.
        end_scores_i (Tensor): shape (K,); upsampled heatmap value at each argmax.
    """
    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)

    # undo the rounding introduced by ceiling the roi size for the upsample
    width_correction = widths_i / roi_map_width
    height_correction = heights_i / roi_map_height

    roi_map = F.interpolate(
        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
    )[:, 0]

    # flat argmax per keypoint, then split into (column, row)
    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

    x_int = pos % w
    y_int = (pos - x_int) // w

    # cell centre (d + 0.5) scaled back from heatmap pixels to roi units
    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
        dtype=torch.float32
    )
    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
        dtype=torch.float32
    )

    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
    # NOTE(review): created without a device argument, i.e. on the default device — confirm
    # this only runs on CPU during export
    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
    xy_preds_i = torch.stack(
        [
            xy_preds_i_0.to(dtype=torch.float32),
            xy_preds_i_1.to(dtype=torch.float32),
            xy_preds_i_2.to(dtype=torch.float32),
        ],
        0,
    )

    # TODO: simplify when indexing without rank will be supported by ONNX
    # The two index_selects give a (K, K, K) tensor T with T[a, b, c] = roi_map[a, y_int[b], x_int[c]].
    # The flat index of T[k, k, k] is k * (K*K + K + 1), so picking those offsets from the
    # flattened tensor yields roi_map[k, y_int[k], x_int[k]] for every keypoint k.
    base = num_keypoints * num_keypoints + num_keypoints + 1
    ind = torch.arange(num_keypoints)
    ind = ind.to(dtype=torch.int64) * base
    end_scores_i = (
        roi_map.index_select(1, y_int.to(dtype=torch.int64))
        .index_select(2, x_int.to(dtype=torch.int64))
        .view(-1)
        .index_select(0, ind.to(dtype=torch.int64))
    )

    return xy_preds_i, end_scores_i
@torch.jit._script_if_tracing
def _onnx_heatmaps_to_keypoints_loop(
    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
):
    """
    Tracing/ONNX variant of the per-roi loop in heatmaps_to_keypoints.

    Calls _onnx_heatmaps_to_keypoints for every roi. The outputs are grown with torch.cat
    instead of being written into preallocated tensors.
    NOTE(review): the decorator presumably scripts this function when it is reached under
    tracing, so the data-dependent loop over rois is kept instead of being unrolled by the
    tracer — confirm against the torch.jit docs.

    Args:
        maps (Tensor): heatmaps; maps[i] belongs to roi i.
        rois (Tensor): boxes; only rois.size(0) is read here.
        widths_ceil, heights_ceil (Tensor): per-roi upsampling sizes.
        widths, heights, offset_x, offset_y (Tensor): per-roi box geometry.
        num_keypoints (Tensor): 0-dim tensor; converted with int() for the output shapes.

    Returns:
        xy_preds (Tensor): float32, shape (#rois, 3, num_keypoints).
        end_scores (Tensor): float32, shape (#rois, num_keypoints).
    """
    # start from empty accumulators and concatenate one roi at a time
    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)

    for i in range(int(rois.size(0))):
        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
        )
        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
        end_scores = torch.cat(
            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
        )
    return xy_preds, end_scores
  183. def heatmaps_to_keypoints(maps, rois):
  184. """Extract predicted keypoint locations from heatmaps. Output has shape
  185. (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
  186. for each keypoint.
  187. """
  188. # This function converts a discrete image coordinate in a HEATMAP_SIZE x
  189. # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
  190. # consistency with keypoints_to_heatmap_labels by using the conversion from
  191. # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
  192. # continuous coordinate.
  193. offset_x = rois[:, 0]
  194. offset_y = rois[:, 1]
  195. widths = rois[:, 2] - rois[:, 0]
  196. heights = rois[:, 3] - rois[:, 1]
  197. widths = widths.clamp(min=1)
  198. heights = heights.clamp(min=1)
  199. widths_ceil = widths.ceil()
  200. heights_ceil = heights.ceil()
  201. num_keypoints = maps.shape[1]
  202. if torchvision._is_tracing():
  203. xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
  204. maps,
  205. rois,
  206. widths_ceil,
  207. heights_ceil,
  208. widths,
  209. heights,
  210. offset_x,
  211. offset_y,
  212. torch.scalar_tensor(num_keypoints, dtype=torch.int64),
  213. )
  214. return xy_preds.permute(0, 2, 1), end_scores
  215. xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
  216. end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
  217. for i in range(len(rois)):
  218. roi_map_width = int(widths_ceil[i].item())
  219. roi_map_height = int(heights_ceil[i].item())
  220. width_correction = widths[i] / roi_map_width
  221. height_correction = heights[i] / roi_map_height
  222. roi_map = F.interpolate(
  223. maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
  224. )[:, 0]
  225. # roi_map_probs = scores_to_probs(roi_map.copy())
  226. w = roi_map.shape[2]
  227. pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  228. x_int = pos % w
  229. y_int = torch.div(pos - x_int, w, rounding_mode="floor")
  230. # assert (roi_map_probs[k, y_int, x_int] ==
  231. # roi_map_probs[k, :, :].max())
  232. x = (x_int.float() + 0.5) * width_correction
  233. y = (y_int.float() + 0.5) * height_correction
  234. xy_preds[i, 0, :] = x + offset_x[i]
  235. xy_preds[i, 1, :] = y + offset_y[i]
  236. xy_preds[i, 2, :] = 1
  237. end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
  238. return xy_preds.permute(0, 2, 1), end_scores
  239. def roi_line_loss(keypoints, rois, heatmap_size):
  240. pass
  241. def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
  242. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  243. N, K, H, W = keypoint_logits.shape
  244. if H != W:
  245. raise ValueError(
  246. f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  247. )
  248. discretization_size = H
  249. heatmaps = []
  250. valid = []
  251. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
  252. kp = gt_kp_in_image[midx]
  253. heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
  254. heatmaps.append(heatmaps_per_image.view(-1))
  255. valid.append(valid_per_image.view(-1))
  256. keypoint_targets = torch.cat(heatmaps, dim=0)
  257. valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
  258. valid = torch.where(valid)[0]
  259. # torch.mean (in binary_cross_entropy_with_logits) doesn't
  260. # accept empty tensors, so handle it sepaartely
  261. if keypoint_targets.numel() == 0 or len(valid) == 0:
  262. return keypoint_logits.sum() * 0
  263. keypoint_logits = keypoint_logits.view(N * K, H * W)
  264. keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
  265. return keypoint_loss
  266. def keypointrcnn_inference(x, boxes):
  267. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  268. kp_probs = []
  269. kp_scores = []
  270. boxes_per_image = [box.size(0) for box in boxes]
  271. x2 = x.split(boxes_per_image, dim=0)
  272. for xx, bb in zip(x2, boxes):
  273. kp_prob, scores = heatmaps_to_keypoints(xx, bb)
  274. kp_probs.append(kp_prob)
  275. kp_scores.append(scores)
  276. return kp_probs, kp_scores
def _onnx_expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    """
    Tracing/ONNX variant of expand_boxes: scale every box about its centre by ``scale``.

    The result is assembled with torch.stack instead of the in-place column assignments used
    by expand_boxes. The half extents are cast to float32 before scaling. Under tracing,
    ``scale`` may be a Tensor (see expand_masks) despite the float annotation.

    Args:
        boxes (Tensor): (N, 4) boxes as (x1, y1, x2, y2).
        scale: multiplicative factor for width and height.

    Returns:
        Tensor: (N, 4) expanded boxes.
    """
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half = w_half.to(dtype=torch.float32) * scale
    h_half = h_half.to(dtype=torch.float32) * scale

    boxes_exp0 = x_c - w_half
    boxes_exp1 = y_c - h_half
    boxes_exp2 = x_c + w_half
    boxes_exp3 = y_c + h_half
    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
    return boxes_exp
  291. # the next two functions should be merged inside Masker
  292. # but are kept here for the moment while we need them
  293. # temporarily for paste_mask_in_image
  294. def expand_boxes(boxes, scale):
  295. # type: (Tensor, float) -> Tensor
  296. if torchvision._is_tracing():
  297. return _onnx_expand_boxes(boxes, scale)
  298. w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
  299. h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
  300. x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
  301. y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
  302. w_half *= scale
  303. h_half *= scale
  304. boxes_exp = torch.zeros_like(boxes)
  305. boxes_exp[:, 0] = x_c - w_half
  306. boxes_exp[:, 2] = x_c + w_half
  307. boxes_exp[:, 1] = y_c - h_half
  308. boxes_exp[:, 3] = y_c + h_half
  309. return boxes_exp
  310. @torch.jit.unused
  311. def expand_masks_tracing_scale(M, padding):
  312. # type: (int, int) -> float
  313. return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
  314. def expand_masks(mask, padding):
  315. # type: (Tensor, int) -> Tuple[Tensor, float]
  316. M = mask.shape[-1]
  317. if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why
  318. scale = expand_masks_tracing_scale(M, padding)
  319. else:
  320. scale = float(M + 2 * padding) / M
  321. padded_mask = F.pad(mask, (padding,) * 4)
  322. return padded_mask, scale
def paste_mask_in_image(mask, box, im_h, im_w):
    # type: (Tensor, Tensor, int, int) -> Tensor
    """
    Resize one mask to its box and paste it into an all-zero (im_h, im_w) canvas.

    Args:
        mask (Tensor): single mask; paste_masks_in_image passes m[0]. It is expanded to
            (1, 1, H, W) for the bilinear resize.
        box (Tensor): (x1, y1, x2, y2), used directly in slicing. paste_masks_in_image casts
            boxes to int64 before calling this function.
        im_h (int): output height.
        im_w (int): output width.

    Returns:
        Tensor: (im_h, im_w) tensor with the resized mask at the box position and zeros
        elsewhere. The part of the box outside the image is cropped away.
    """
    # boxes are treated as inclusive pixel ranges, so a box spans x2 - x1 + 1 pixels
    TO_REMOVE = 1
    w = int(box[2] - box[0] + TO_REMOVE)
    h = int(box[3] - box[1] + TO_REMOVE)
    w = max(w, 1)
    h = max(h, 1)

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, -1, -1))

    # Resize mask
    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
    # clip the (inclusive) box to the image, then copy the matching window of the mask
    x_0 = max(box[0], 0)
    x_1 = min(box[2] + 1, im_w)
    y_0 = max(box[1], 0)
    y_1 = min(box[3] + 1, im_h)

    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
    return im_mask
def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
    """
    Tracing/ONNX counterpart of paste_mask_in_image.

    It performs the same steps: resize the mask to the inclusive box size, crop the part
    inside the image, and place it on a zero canvas. The implementation differs in two ways:
    - clamping uses torch.max/torch.min over concatenated tensors instead of Python max/min;
    - the canvas is built by concatenating zero blocks around the crop instead of slice
      assignment.

    Args:
        mask (Tensor): single 2D mask (expanded to (1, 1, H, W) below).
        box (Tensor): int64 (x1, y1, x2, y2), as passed by _onnx_paste_masks_in_image_loop.
        im_h, im_w (Tensor): output size as 0-dim int64 tensors (they are unsqueezed below).

    Returns:
        Tensor: (im_h, im_w) mask; the padding blocks use the default dtype and the crop is
        cast to float32.
    """
    one = torch.ones(1, dtype=torch.int64)
    zero = torch.zeros(1, dtype=torch.int64)

    # inclusive box size, clamped to at least 1 pixel
    w = box[2] - box[0] + one
    h = box[3] - box[1] + one
    w = torch.max(torch.cat((w, one)))
    h = torch.max(torch.cat((h, one)))

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))

    # Resize mask
    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    # clip the box to the image with tensor min/max
    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))

    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]

    # TODO : replace below with a dynamic padding when support is added in ONNX
    # NOTE(review): the zero blocks are created without device/dtype arguments, i.e. on the
    # default device — confirm export only runs on CPU

    # pad y
    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
    # pad x
    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
    return im_mask
@torch.jit._script_if_tracing
def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
    """
    Tracing/ONNX loop that pastes every mask with _onnx_paste_mask_in_image.

    Results are stacked by repeated torch.cat onto an empty (0, im_h, im_w) tensor.
    NOTE(review): the decorator presumably scripts this when reached under tracing so the
    loop over masks is not unrolled — confirm.

    Args:
        masks (Tensor): padded masks; masks[i][0] is the mask of detection i.
        boxes (Tensor): int64 expanded boxes, one per mask.
        im_h, im_w (Tensor): output size as 0-dim int64 tensors.

    Returns:
        Tensor: shape (#masks, im_h, im_w). The accumulator is created with default
        device/dtype.
    """
    res_append = torch.zeros(0, im_h, im_w)
    for i in range(masks.size(0)):
        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
        mask_res = mask_res.unsqueeze(0)
        res_append = torch.cat((res_append, mask_res))
    return res_append
  377. def paste_masks_in_image(masks, boxes, img_shape, padding=1):
  378. # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
  379. masks, scale = expand_masks(masks, padding=padding)
  380. boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
  381. im_h, im_w = img_shape
  382. if torchvision._is_tracing():
  383. return _onnx_paste_masks_in_image_loop(
  384. masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
  385. )[:, None]
  386. res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
  387. if len(res) > 0:
  388. ret = torch.stack(res, dim=0)[:, None]
  389. else:
  390. ret = masks.new_empty((0, 1, im_h, im_w))
  391. return ret
  392. class RoIHeads(nn.Module):
  393. __annotations__ = {
  394. "box_coder": det_utils.BoxCoder,
  395. "proposal_matcher": det_utils.Matcher,
  396. "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
  397. }
  398. def __init__(
  399. self,
  400. box_roi_pool,
  401. box_head,
  402. box_predictor,
  403. # Faster R-CNN training
  404. fg_iou_thresh,
  405. bg_iou_thresh,
  406. batch_size_per_image,
  407. positive_fraction,
  408. bbox_reg_weights,
  409. # Faster R-CNN inference
  410. score_thresh,
  411. nms_thresh,
  412. detections_per_img,
  413. # Mask
  414. mask_roi_pool=None,
  415. mask_head=None,
  416. mask_predictor=None,
  417. keypoint_roi_pool=None,
  418. keypoint_head=None,
  419. keypoint_predictor=None,
  420. ):
  421. super().__init__()
  422. self.box_similarity = box_ops.box_iou
  423. # assign ground-truth boxes for each proposal
  424. self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
  425. self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
  426. if bbox_reg_weights is None:
  427. bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
  428. self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
  429. self.box_roi_pool = box_roi_pool
  430. self.box_head = box_head
  431. self.box_predictor = box_predictor
  432. self.score_thresh = score_thresh
  433. self.nms_thresh = nms_thresh
  434. self.detections_per_img = detections_per_img
  435. self.mask_roi_pool = mask_roi_pool
  436. self.mask_head = mask_head
  437. self.mask_predictor = mask_predictor
  438. self.keypoint_roi_pool = keypoint_roi_pool
  439. self.keypoint_head = keypoint_head
  440. self.keypoint_predictor = keypoint_predictor
  441. def has_mask(self):
  442. if self.mask_roi_pool is None:
  443. return False
  444. if self.mask_head is None:
  445. return False
  446. if self.mask_predictor is None:
  447. return False
  448. return True
  449. def has_keypoint(self):
  450. if self.keypoint_roi_pool is None:
  451. return False
  452. if self.keypoint_head is None:
  453. return False
  454. if self.keypoint_predictor is None:
  455. return False
  456. return True
  457. def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
  458. # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  459. matched_idxs = []
  460. labels = []
  461. for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
  462. if gt_boxes_in_image.numel() == 0:
  463. # Background image
  464. device = proposals_in_image.device
  465. clamped_matched_idxs_in_image = torch.zeros(
  466. (proposals_in_image.shape[0],), dtype=torch.int64, device=device
  467. )
  468. labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
  469. else:
  470. # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
  471. match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
  472. matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
  473. clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
  474. labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
  475. labels_in_image = labels_in_image.to(dtype=torch.int64)
  476. # Label background (below the low threshold)
  477. bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
  478. labels_in_image[bg_inds] = 0
  479. # Label ignore proposals (between low and high thresholds)
  480. ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
  481. labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler
  482. matched_idxs.append(clamped_matched_idxs_in_image)
  483. labels.append(labels_in_image)
  484. return matched_idxs, labels
  485. def subsample(self, labels):
  486. # type: (List[Tensor]) -> List[Tensor]
  487. sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
  488. sampled_inds = []
  489. for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
  490. img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
  491. sampled_inds.append(img_sampled_inds)
  492. return sampled_inds
  493. def add_gt_proposals(self, proposals, gt_boxes):
  494. # type: (List[Tensor], List[Tensor]) -> List[Tensor]
  495. proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
  496. return proposals
  497. def check_targets(self, targets):
  498. # type: (Optional[List[Dict[str, Tensor]]]) -> None
  499. if targets is None:
  500. raise ValueError("targets should not be None")
  501. if not all(["boxes" in t for t in targets]):
  502. raise ValueError("Every element of targets should have a boxes key")
  503. if not all(["labels" in t for t in targets]):
  504. raise ValueError("Every element of targets should have a labels key")
  505. if self.has_mask():
  506. if not all(["masks" in t for t in targets]):
  507. raise ValueError("Every element of targets should have a masks key")
  508. def select_training_samples(
  509. self,
  510. proposals, # type: List[Tensor]
  511. targets, # type: Optional[List[Dict[str, Tensor]]]
  512. ):
  513. # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
  514. self.check_targets(targets)
  515. if targets is None:
  516. raise ValueError("targets should not be None")
  517. dtype = proposals[0].dtype
  518. device = proposals[0].device
  519. gt_boxes = [t["boxes"].to(dtype) for t in targets]
  520. gt_labels = [t["labels"] for t in targets]
  521. # append ground-truth bboxes to propos
  522. proposals = self.add_gt_proposals(proposals, gt_boxes)
  523. # get matching gt indices for each proposal
  524. matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
  525. # sample a fixed proportion of positive-negative proposals
  526. sampled_inds = self.subsample(labels)
  527. matched_gt_boxes = []
  528. num_images = len(proposals)
  529. for img_id in range(num_images):
  530. img_sampled_inds = sampled_inds[img_id]
  531. proposals[img_id] = proposals[img_id][img_sampled_inds]
  532. labels[img_id] = labels[img_id][img_sampled_inds]
  533. matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
  534. gt_boxes_in_image = gt_boxes[img_id]
  535. if gt_boxes_in_image.numel() == 0:
  536. gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
  537. matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
  538. regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
  539. return proposals, matched_idxs, labels, regression_targets
  540. def postprocess_detections(
  541. self,
  542. class_logits, # type: Tensor
  543. box_regression, # type: Tensor
  544. proposals, # type: List[Tensor]
  545. image_shapes, # type: List[Tuple[int, int]]
  546. ):
  547. # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
  548. device = class_logits.device
  549. num_classes = class_logits.shape[-1]
  550. boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
  551. pred_boxes = self.box_coder.decode(box_regression, proposals)
  552. pred_scores = F.softmax(class_logits, -1)
  553. pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
  554. pred_scores_list = pred_scores.split(boxes_per_image, 0)
  555. all_boxes = []
  556. all_scores = []
  557. all_labels = []
  558. for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
  559. boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
  560. # create labels for each prediction
  561. labels = torch.arange(num_classes, device=device)
  562. labels = labels.view(1, -1).expand_as(scores)
  563. # remove predictions with the background label
  564. boxes = boxes[:, 1:]
  565. scores = scores[:, 1:]
  566. labels = labels[:, 1:]
  567. # batch everything, by making every class prediction be a separate instance
  568. boxes = boxes.reshape(-1, 4)
  569. scores = scores.reshape(-1)
  570. labels = labels.reshape(-1)
  571. # remove low scoring boxes
  572. inds = torch.where(scores > self.score_thresh)[0]
  573. boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
  574. # remove empty boxes
  575. keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
  576. boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
  577. # non-maximum suppression, independently done per class
  578. keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
  579. # keep only topk scoring predictions
  580. keep = keep[: self.detections_per_img]
  581. boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
  582. all_boxes.append(boxes)
  583. all_scores.append(scores)
  584. all_labels.append(labels)
  585. return all_boxes, all_scores, all_labels
  586. def forward(
  587. self,
  588. features, # type: Dict[str, Tensor]
  589. proposals, # type: List[Tensor]
  590. image_shapes, # type: List[Tuple[int, int]]
  591. targets=None, # type: Optional[List[Dict[str, Tensor]]]
  592. ):
  593. # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
  594. """
  595. Args:
  596. features (List[Tensor])
  597. proposals (List[Tensor[N, 4]])
  598. image_shapes (List[Tuple[H, W]])
  599. targets (List[Dict])
  600. """
  601. if targets is not None:
  602. for t in targets:
  603. # TODO: https://github.com/pytorch/pytorch/issues/26731
  604. floating_point_types = (torch.float, torch.double, torch.half)
  605. if not t["boxes"].dtype in floating_point_types:
  606. raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
  607. if not t["labels"].dtype == torch.int64:
  608. raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
  609. if self.has_keypoint():
  610. if not t["keypoints"].dtype == torch.float32:
  611. raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
  612. if self.training:
  613. proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
  614. else:
  615. labels = None
  616. regression_targets = None
  617. matched_idxs = None
  618. box_features = self.box_roi_pool(features, proposals, image_shapes)
  619. box_features = self.box_head(box_features)
  620. class_logits, box_regression = self.box_predictor(box_features)
  621. result: List[Dict[str, torch.Tensor]] = []
  622. losses = {}
  623. if self.training:
  624. if labels is None:
  625. raise ValueError("labels cannot be None")
  626. if regression_targets is None:
  627. raise ValueError("regression_targets cannot be None")
  628. loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
  629. losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
  630. else:
  631. boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
  632. num_images = len(boxes)
  633. for i in range(num_images):
  634. result.append(
  635. {
  636. "boxes": boxes[i],
  637. "labels": labels[i],
  638. "scores": scores[i],
  639. }
  640. )
  641. if self.has_mask():
  642. mask_proposals = [p["boxes"] for p in result]
  643. if self.training:
  644. if matched_idxs is None:
  645. raise ValueError("if in training, matched_idxs should not be None")
  646. # during training, only focus on positive boxes
  647. num_images = len(proposals)
  648. mask_proposals = []
  649. pos_matched_idxs = []
  650. for img_id in range(num_images):
  651. pos = torch.where(labels[img_id] > 0)[0]
  652. mask_proposals.append(proposals[img_id][pos])
  653. pos_matched_idxs.append(matched_idxs[img_id][pos])
  654. else:
  655. pos_matched_idxs = None
  656. if self.mask_roi_pool is not None:
  657. mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
  658. mask_features = self.mask_head(mask_features)
  659. mask_logits = self.mask_predictor(mask_features)
  660. else:
  661. raise Exception("Expected mask_roi_pool to be not None")
  662. loss_mask = {}
  663. if self.training:
  664. if targets is None or pos_matched_idxs is None or mask_logits is None:
  665. raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
  666. gt_masks = [t["masks"] for t in targets]
  667. gt_labels = [t["labels"] for t in targets]
  668. rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
  669. loss_mask = {"loss_mask": rcnn_loss_mask}
  670. else:
  671. labels = [r["labels"] for r in result]
  672. masks_probs = maskrcnn_inference(mask_logits, labels)
  673. for mask_prob, r in zip(masks_probs, result):
  674. r["masks"] = mask_prob
  675. losses.update(loss_mask)
  676. # keep none checks in if conditional so torchscript will conditionally
  677. # compile each branch
  678. if (
  679. self.keypoint_roi_pool is not None
  680. and self.keypoint_head is not None
  681. and self.keypoint_predictor is not None
  682. ):
  683. keypoint_proposals = [p["boxes"] for p in result]
  684. if self.training:
  685. # during training, only focus on positive boxes
  686. num_images = len(proposals)
  687. keypoint_proposals = []
  688. pos_matched_idxs = []
  689. if matched_idxs is None:
  690. raise ValueError("if in trainning, matched_idxs should not be None")
  691. for img_id in range(num_images):
  692. pos = torch.where(labels[img_id] > 0)[0]
  693. keypoint_proposals.append(proposals[img_id][pos])
  694. pos_matched_idxs.append(matched_idxs[img_id][pos])
  695. else:
  696. pos_matched_idxs = None
  697. keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
  698. keypoint_features = self.keypoint_head(keypoint_features)
  699. keypoint_logits = self.keypoint_predictor(keypoint_features)
  700. loss_keypoint = {}
  701. if self.training:
  702. if targets is None or pos_matched_idxs is None:
  703. raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
  704. gt_keypoints = [t["keypoints"] for t in targets]
  705. rcnn_loss_keypoint = keypointrcnn_loss(
  706. keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
  707. )
  708. loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
  709. else:
  710. if keypoint_logits is None or keypoint_proposals is None:
  711. raise ValueError(
  712. "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
  713. )
  714. keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
  715. for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
  716. r["keypoints"] = keypoint_prob
  717. r["keypoints_scores"] = kps
  718. losses.update(loss_keypoint)
  719. return result, losses