base_detection_net.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. """
  2. Implements the Generalized R-CNN framework
  3. """
  4. import warnings
  5. from collections import OrderedDict
  6. from typing import Dict, List, Optional, Tuple, Union
  7. import torch
  8. from torch import nn, Tensor
  9. from libs.vision_libs.utils import _log_api_usage_once
  10. from models.base.base_model import BaseModel
  11. class BaseDetectionNet(BaseModel):
  12. """
  13. Main class for Generalized R-CNN.
  14. Args:
  15. backbone (nn.Module):
  16. rpn (nn.Module):
  17. roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
  18. detections / masks from it.
  19. transform (nn.Module): performs the data transformation from the inputs to feed into
  20. the model
  21. """
  22. def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
  23. super().__init__()
  24. _log_api_usage_once(self)
  25. self.transform = transform
  26. self.backbone = backbone
  27. self.rpn = rpn
  28. self.roi_heads = roi_heads
  29. # used only on torchscript mode
  30. self._has_warned = False
  31. @torch.jit.unused
  32. def eager_outputs(self, losses, detections):
  33. # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
  34. if self.training:
  35. return losses
  36. return detections
  37. def forward(self, images, targets=None):
  38. # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
  39. """
  40. Args:
  41. images (list[Tensor]): images to be processed
  42. targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
  43. Returns:
  44. result (list[BoxList] or dict[Tensor]): the output from the model.
  45. During training, it returns a dict[Tensor] which contains the losses.
  46. During testing, it returns list[BoxList] contains additional fields
  47. like `scores`, `labels` and `mask` (for Mask R-CNN models).
  48. """
  49. if self.training:
  50. if targets is None:
  51. torch._assert(False, "targets should not be none when in training mode")
  52. else:
  53. for target in targets:
  54. boxes = target["boxes"]
  55. if isinstance(boxes, torch.Tensor):
  56. torch._assert(
  57. len(boxes.shape) == 2 and boxes.shape[-1] == 4,
  58. f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
  59. )
  60. else:
  61. torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
  62. original_image_sizes: List[Tuple[int, int]] = []
  63. for img in images:
  64. val = img.shape[-2:]
  65. torch._assert(
  66. len(val) == 2,
  67. f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
  68. )
  69. original_image_sizes.append((val[0], val[1]))
  70. images, targets = self.transform(images, targets)
  71. # Check for degenerate boxes
  72. # TODO: Move this to a function
  73. if targets is not None:
  74. for target_idx, target in enumerate(targets):
  75. boxes = target["boxes"]
  76. degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
  77. if degenerate_boxes.any():
  78. # print the first degenerate box
  79. bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
  80. degen_bb: List[float] = boxes[bb_idx].tolist()
  81. torch._assert(
  82. False,
  83. "All bounding boxes should have positive height and width."
  84. f" Found invalid box {degen_bb} for target at index {target_idx}.",
  85. )
  86. features = self.backbone(images.tensors)
  87. if isinstance(features, torch.Tensor):
  88. features = OrderedDict([("0", features)])
  89. proposals, proposal_losses = self.rpn(images, features, targets)
  90. detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
  91. detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) # type: ignore[operator]
  92. # ->multi task head
  93. # ->learner,->vectorize
  94. losses = {}
  95. losses.update(detector_losses)
  96. losses.update(proposal_losses)
  97. if torch.jit.is_scripting():
  98. if not self._has_warned:
  99. warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
  100. self._has_warned = True
  101. return losses, detections
  102. else:
  103. return self.eager_outputs(losses, detections)