base_detection_net.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. """
  2. Implements the Generalized R-CNN framework
  3. """
  4. import warnings
  5. from collections import OrderedDict
  6. from typing import Dict, List, Optional, Tuple, Union
  7. import torch
  8. from torch import nn, Tensor
  9. from libs.vision_libs.utils import _log_api_usage_once
  10. from models.base.base_model import BaseModel
  11. from models.line_detect.trainer import Trainer
  class BaseDetectionNet(BaseModel):
      """
      Main class for Generalized R-CNN.

      Args:
          backbone (nn.Module): feature extractor applied to ``images.tensors``. It may
              return a single Tensor (wrapped as ``OrderedDict([("0", tensor)])``) or an
              ordered mapping of feature maps.
          rpn (nn.Module): region proposal network, called as
              ``rpn(images, features, targets)`` and returning
              ``(proposals, proposal_losses)``.
          roi_heads (nn.Module): takes the features + the proposals from the RPN and
              computes detections / masks from them; called as
              ``roi_heads(features, proposals, image_sizes, targets)`` and returning
              ``(detections, detector_losses)``.
          transform (nn.Module): performs the data transformation from the inputs to feed
              into the model. Called as ``transform(images, targets)`` (returning the
              transformed ``(images, targets)``, where ``images`` exposes ``.tensors`` and
              ``.image_sizes``) and must also provide
              ``postprocess(detections, image_sizes, original_image_sizes)``.
      """

      # The following methods are empty placeholders. Presumably they exist to satisfy
      # the BaseModel interface -- TODO(review): confirm against BaseModel.

      def get_loss(self, Loss, results, inputs, device):
          pass

      def get_optimizer(self, cfg_pipeline):
          pass

      def preprocess(self, cfg_pipeline):
          pass

      def transform(self, cfg_pipeline):
          # NOTE: because this method is defined on the class, ``self.transform`` resolves
          # to it rather than to the ``transform`` submodule registered in ``__init__``
          # (nn.Module stores submodules in ``self._modules``, which is only consulted
          # after normal attribute lookup fails). ``forward`` therefore obtains the
          # submodule through ``_transform_module()``.
          pass

      def inference_begin(self, data):
          pass

      def inference_preprocess(self):
          pass

      def inference_end(self, inputs, results):
          pass

      def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
          super().__init__()
          _log_api_usage_once(self)
          # Registered as a submodule (in ``self._modules``) by ``nn.Module.__setattr__``.
          self.transform = transform
          self.backbone = backbone
          self.rpn = rpn
          self.roi_heads = roi_heads
          # used only on torchscript mode
          self._has_warned = False
          # Trainer used by ``train(cfg)``; created lazily on first use.
          self.__trainer = None

      @torch.jit.unused
      def eager_outputs(self, losses, detections):
          # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
          """Return ``losses`` in training mode and ``detections`` otherwise."""
          if self.training:
              return losses
          return detections

      def train(self, cfg=True):
          """Switch train/eval mode, or hand the model to a :class:`Trainer`.

          ``nn.Module.eval()`` and ``nn.Module.train()`` dispatch through
          ``self.train(<bool>)``. Boolean arguments are therefore forwarded to
          :meth:`torch.nn.Module.train`, so mode switching keeps working.

          Args:
              cfg: a ``bool`` selects training (``True``) or evaluation (``False``) mode.
                  Any other value is treated as a training request and the model is
                  passed to a :class:`Trainer`.

          Returns:
              ``self`` when switching modes (as ``nn.Module.train`` does), else ``None``.
          """
          if isinstance(cfg, bool):
              return nn.Module.train(self, cfg)
          if self.__trainer is None:
              # NOTE(review): assumes Trainer() needs no constructor arguments -- confirm.
              self.__trainer = Trainer()
          # NOTE(review): ``cfg`` is not forwarded; the literal "test" is passed instead.
          # Confirm what Trainer.train expects as its second argument.
          self.__trainer.train(self, "test")

      def load_weight(self, pt_path, map_location=None):
          """Load a state dict saved with ``torch.save`` from ``pt_path`` into this model.

          Args:
              pt_path: path to a checkpoint file whose content is a bare state dict.
              map_location: forwarded to ``torch.load``; e.g. ``"cpu"`` loads GPU-saved
                  weights on a CPU-only machine. ``None`` keeps torch.load's default.
          """
          # torch.load unpickles the file: only load checkpoints from trusted sources.
          state_dict = torch.load(pt_path, map_location=map_location)
          self.load_state_dict(state_dict)

      def _transform_module(self):
          """Return the ``transform`` submodule passed to ``__init__``.

          ``self.transform`` resolves to the ``transform`` stub method defined on this
          class, so the registered submodule is read from ``self._modules`` directly.
          """
          return self._modules["transform"]

      @staticmethod
      def _check_training_targets(targets):
          # type: (Optional[List[Dict[str, Tensor]]]) -> None
          """Assert that targets are given and every ``boxes`` entry is an [N, 4] Tensor."""
          if targets is None:
              torch._assert(False, "targets should not be none when in training mode")
          else:
              for target in targets:
                  boxes = target["boxes"]
                  if isinstance(boxes, torch.Tensor):
                      torch._assert(
                          len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                          f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                      )
                  else:
                      torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

      @staticmethod
      def _check_degenerate_boxes(targets):
          # type: (List[Dict[str, Tensor]]) -> None
          """Assert that, for every box, columns 2: are strictly greater than columns :2.

          The first offending box (if any) is reported in the assertion message.
          """
          for target_idx, target in enumerate(targets):
              boxes = target["boxes"]
              degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
              if degenerate_boxes.any():
                  # print the first degenerate box
                  bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                  degen_bb: List[float] = boxes[bb_idx].tolist()
                  torch._assert(
                      False,
                      "All bounding boxes should have positive height and width."
                      f" Found invalid box {degen_bb} for target at index {target_idx}.",
                  )

      def forward(self, images, targets=None):
          # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
          """
          Args:
              images (list[Tensor]): images to be processed
              targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)

          Returns:
              result (list[BoxList] or dict[Tensor]): the output from the model.
                  During training, it returns a dict[Tensor] which contains the losses.
                  During testing, it returns list[BoxList] contains additional fields
                  like `scores`, `labels` and `mask` (for Mask R-CNN models).
                  When scripted, a ``(losses, detections)`` tuple is always returned.
          """
          if self.training:
              self._check_training_targets(targets)

          # Record input sizes so detections can be mapped back after the transform.
          original_image_sizes: List[Tuple[int, int]] = []
          for img in images:
              val = img.shape[-2:]
              torch._assert(
                  len(val) == 2,
                  f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
              )
              original_image_sizes.append((val[0], val[1]))

          transform = self._transform_module()
          images, targets = transform(images, targets)

          if targets is not None:
              self._check_degenerate_boxes(targets)

          features = self.backbone(images.tensors)
          if isinstance(features, torch.Tensor):
              features = OrderedDict([("0", features)])
          proposals, proposal_losses = self.rpn(images, features, targets)
          detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
          detections = transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]

          # Planned extensions: -> multi task head; -> learner, -> vectorize

          losses = {}
          losses.update(detector_losses)
          losses.update(proposal_losses)

          if torch.jit.is_scripting():
              if not self._has_warned:
                  warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                  self._has_warned = True
              return losses, detections
          else:
              return self.eager_outputs(losses, detections)