base_detection_net.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. """
  2. Implements the Generalized R-CNN framework
  3. """
  4. import warnings
  5. from collections import OrderedDict
  6. from typing import Dict, List, Optional, Tuple, Union
  7. import torch
  8. from torch import nn, Tensor
  9. from libs.vision_libs.utils import _log_api_usage_once
  10. from models.base.base_model import BaseModel
  11. import matplotlib.pyplot as plt
  12. class BaseDetectionNet(BaseModel):
  13. """
  14. Main class for Generalized R-CNN.
  15. Args:
  16. backbone (nn.Module):
  17. rpn (nn.Module):
  18. roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
  19. detections / masks from it.
  20. transform (nn.Module): performs the data transformation from the inputs to feed into
  21. the model
  22. """
  23. def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
  24. super().__init__()
  25. _log_api_usage_once(self)
  26. self.transform = transform
  27. self.backbone = backbone
  28. self.rpn = rpn
  29. self.roi_heads = roi_heads
  30. # used only on torchscript mode
  31. self._has_warned = False
  32. @torch.jit.unused
  33. def eager_outputs(self, losses, detections,targets=None):
  34. # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
  35. if self.training:
  36. return losses
  37. else:
  38. if targets is not None:
  39. print(f'returned (detections,losses):{losses}')
  40. return detections,losses
  41. else:
  42. return detections
  43. # return detections,losses
  44. def forward(self, images, targets=None):
  45. # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
  46. """
  47. Args:
  48. images (list[Tensor]): images to be processed
  49. targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
  50. Returns:
  51. result (list[BoxList] or dict[Tensor]): the output from the model.
  52. During training, it returns a dict[Tensor] which contains the losses.
  53. During testing, it returns list[BoxList] contains additional fields
  54. like `scores`, `labels` and `mask` (for Mask R-CNN models).
  55. """
  56. if self.training:
  57. if targets is None:
  58. torch._assert(False, "targets should not be none when in training mode")
  59. else:
  60. for target in targets:
  61. boxes = target["boxes"]
  62. if isinstance(boxes, torch.Tensor):
  63. torch._assert(
  64. len(boxes.shape) == 2 and boxes.shape[-1] == 4,
  65. f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
  66. )
  67. else:
  68. torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
  69. original_image_sizes: List[Tuple[int, int]] = []
  70. for img in images:
  71. val = img.shape[-2:]
  72. torch._assert(
  73. len(val) == 2,
  74. f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
  75. )
  76. original_image_sizes.append((val[0], val[1]))
  77. images, targets = self.transform(images, targets)
  78. # Check for degenerate boxes
  79. # TODO: Move this to a function
  80. if targets is not None:
  81. for target_idx, target in enumerate(targets):
  82. boxes = target["boxes"]
  83. degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
  84. if degenerate_boxes.any():
  85. # print the first degenerate box
  86. bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
  87. degen_bb: List[float] = boxes[bb_idx].tolist()
  88. torch._assert(
  89. False,
  90. "All bounding boxes should have positive height and width."
  91. f" Found invalid box {degen_bb} for target at index {target_idx}.",
  92. )
  93. features = self.backbone(images.tensors)
  94. if isinstance(features, torch.Tensor):
  95. features = OrderedDict([("0", features)])
  96. proposals, proposal_losses = self.rpn(images, features, targets)
  97. # print(f'proposals:{proposals[0].shape}')
  98. detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
  99. detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) # type: ignore[operator]
  100. # ->multi task head
  101. # ->learner,->vectorize
  102. losses = {}
  103. losses.update(detector_losses)
  104. losses.update(proposal_losses)
  105. if torch.jit.is_scripting():
  106. if not self._has_warned:
  107. warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
  108. self._has_warned = True
  109. return losses, detections
  110. else:
  111. return self.eager_outputs(losses, detections,targets)