rpn.py

from typing import Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from ...ops import boxes as box_ops, Conv2dNormActivation
from . import _utils as det_utils

# Import AnchorGenerator to keep compatibility.
from .anchor_utils import AnchorGenerator  # noqa: F401
from .image_list import ImageList
class RPNHead(nn.Module):
    """
    Adds a simple RPN Head with classification and regression heads.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        conv_depth (int, optional): number of convolutions
    """

    _version = 2

    def __init__(self, in_channels: int, num_anchors: int, conv_depth=1) -> None:
        super().__init__()
        convs = []
        for _ in range(conv_depth):
            convs.append(Conv2dNormActivation(in_channels, in_channels, kernel_size=3, norm_layer=None))
        self.conv = nn.Sequential(*convs)
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)
        self.line_pred = nn.Conv2d(in_channels, num_anchors * 2, kernel_size=1, stride=1)

        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]
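    # Note (added for clarity): for an input feature map of shape [N, in_channels, H, W],
    # cls_logits yields [N, num_anchors, H, W], bbox_pred yields [N, num_anchors * 4, H, W],
    # and line_pred yields [N, num_anchors * 2, H, W]. line_pred is defined here but is not
    # used by forward() below.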
    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            for type in ["weight", "bias"]:
                old_key = f"{prefix}conv.{type}"
                new_key = f"{prefix}conv.0.0.{type}"
                if old_key in state_dict:
                    state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
    def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        logits = []
        bbox_reg = []
        for feature in x:
            t = self.conv(feature)
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg
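# A minimal shape sketch for RPNHead.forward (illustrative values, not taken from this repo):
# with num_anchors=3 and two feature maps of sizes [N, C, 32, 32] and [N, C, 16, 16], it
# returns logits of shapes [N, 3, 32, 32] and [N, 3, 16, 16] and bbox_reg of shapes
# [N, 12, 32, 32] and [N, 12, 16, 16].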
def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor:
    layer = layer.view(N, -1, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)
    layer = layer.reshape(N, -1, C)
    return layer
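# For example (a hedged sketch): permute_and_flatten takes a prediction map of shape
# [N, A*C, H, W], views it as [N, A, C, H, W], permutes it to [N, H, W, A, C], and reshapes
# it to [N, H*W*A, C], so entries are ordered per spatial location and then per anchor.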
def concat_box_prediction_layers(box_cls: List[Tensor], box_regression: List[Tensor]) -> Tuple[Tensor, Tensor]:
    box_cls_flattened = []
    box_regression_flattened = []
    # for each feature level, permute the outputs to make them be in the
    # same format as the labels. Note that the labels are computed for
    # all feature levels concatenated, so we keep the same representation
    # for the objectness and the box_regression
    for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
        N, AxC, H, W = box_cls_per_level.shape
        Ax4 = box_regression_per_level.shape[1]
        A = Ax4 // 4
        C = AxC // A
        box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W)
        box_cls_flattened.append(box_cls_per_level)

        box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W)
        box_regression_flattened.append(box_regression_per_level)
    # concatenate on the first dimension (representing the feature levels), to
    # take into account the way the labels were generated (with all feature maps
    # being concatenated as well)
    box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2)
    box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4)
    return box_cls, box_regression
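# Illustrative example (not from this repo): with N=2 images, A=3 anchors per location,
# C=1 objectness channel, and two levels of sizes 32x32 and 16x16, box_cls has shape
# [2 * (32*32 + 16*16) * 3, 1] = [7680, 1] and box_regression has shape [7680, 4].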
class RegionProposalNetwork(torch.nn.Module):
    """
    Implements Region Proposal Network (RPN).

    Args:
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): module that computes the objectness and regression deltas
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }
    def __init__(
        self,
        anchor_generator: AnchorGenerator,
        head: nn.Module,
        # Faster-RCNN Training
        fg_iou_thresh: float,
        bg_iou_thresh: float,
        batch_size_per_image: int,
        positive_fraction: float,
        # Faster-RCNN Inference
        pre_nms_top_n: Dict[str, int],
        post_nms_top_n: Dict[str, int],
        nms_thresh: float,
        score_thresh: float = 0.0,
    ) -> None:
        super().__init__()
        self.anchor_generator = anchor_generator
        self.head = head
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # used during training
        self.box_similarity = box_ops.box_iou

        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh,
            bg_iou_thresh,
            allow_low_quality_matches=True,
        )

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

        # used during testing
        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.score_thresh = score_thresh
        self.min_size = 1e-3
    def pre_nms_top_n(self) -> int:
        if self.training:
            return self._pre_nms_top_n["training"]
        return self._pre_nms_top_n["testing"]

    def post_nms_top_n(self) -> int:
        if self.training:
            return self._post_nms_top_n["training"]
        return self._post_nms_top_n["testing"]
    def assign_targets_to_anchors(
        self, anchors: List[Tensor], targets: List[Dict[str, Tensor]]
    ) -> Tuple[List[Tensor], List[Tensor]]:
        labels = []
        matched_gt_boxes = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            gt_boxes = targets_per_image["boxes"]

            if gt_boxes.numel() == 0:
                # Background image (negative example)
                device = anchors_per_image.device
                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
            else:
                match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
                matched_idxs = self.proposal_matcher(match_quality_matrix)
                # get the targets corresponding GT for each proposal
                # NB: need to clamp the indices because we can have a single
                # GT in the image, and matched_idxs can be -2, which goes
                # out of bounds
                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                labels_per_image = matched_idxs >= 0
                labels_per_image = labels_per_image.to(dtype=torch.float32)

                # Background (negative examples)
                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_per_image[bg_indices] = 0.0

                # discard indices that are between thresholds
                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = -1.0

            labels.append(labels_per_image)
            matched_gt_boxes.append(matched_gt_boxes_per_image)
        return labels, matched_gt_boxes
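    # Label convention produced above: 1.0 marks foreground anchors, 0.0 marks background
    # anchors, and -1.0 marks anchors between the two IoU thresholds; the -1.0 entries are
    # ignored by the fg/bg sampler in compute_loss().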
    def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -> Tensor:
        r = []
        offset = 0
        for ob in objectness.split(num_anchors_per_level, 1):
            num_anchors = ob.shape[1]
            pre_nms_top_n = det_utils._topk_min(ob, self.pre_nms_top_n(), 1)
            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
            r.append(top_n_idx + offset)
            offset += num_anchors
        return torch.cat(r, dim=1)
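    # The indices returned above are offset per feature level so that they index directly
    # into the level-concatenated objectness tensor used by filter_proposals().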
    def filter_proposals(
        self,
        proposals: Tensor,
        objectness: Tensor,
        image_shapes: List[Tuple[int, int]],
        num_anchors_per_level: List[int],
    ) -> Tuple[List[Tensor], List[Tensor]]:
        num_images = proposals.shape[0]
        device = proposals.device
        # do not backprop through objectness
        objectness = objectness.detach()
        objectness = objectness.reshape(num_images, -1)

        levels = [
            torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level)
        ]
        levels = torch.cat(levels, 0)
        levels = levels.reshape(1, -1).expand_as(objectness)

        # select top_n boxes independently per level before applying nms
        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

        image_range = torch.arange(num_images, device=device)
        batch_idx = image_range[:, None]

        objectness = objectness[batch_idx, top_n_idx]
        levels = levels[batch_idx, top_n_idx]
        proposals = proposals[batch_idx, top_n_idx]

        objectness_prob = torch.sigmoid(objectness)

        final_boxes = []
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)

            # remove small boxes
            keep = box_ops.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # remove low scoring boxes
            # use >= for Backwards compatibility
            keep = torch.where(scores >= self.score_thresh)[0]
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # non-maximum suppression, independently done per level
            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

            # keep only topk scoring predictions
            keep = keep[: self.post_nms_top_n()]
            boxes, scores = boxes[keep], scores[keep]

            final_boxes.append(boxes)
            final_scores.append(scores)
        return final_boxes, final_scores
    def compute_loss(
        self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            objectness (Tensor)
            pred_bbox_deltas (Tensor)
            labels (List[Tensor])
            regression_targets (List[Tensor])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
        sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        box_loss = F.smooth_l1_loss(
            pred_bbox_deltas[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1 / 9,
            reduction="sum",
        ) / (sampled_inds.numel())

        objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds])

        return objectness_loss, box_loss
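    # Note: the box regression loss above is computed on sampled positive anchors only,
    # but is normalized by the total number of sampled anchors (positives + negatives),
    # as in the torchvision reference RPN.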
    def forward(
        self,
        images: ImageList,
        features: Dict[str, Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Tuple[List[Tensor], Dict[str, Tensor]]:
        """
        Args:
            images (ImageList): images for which we want to compute the predictions
            features (Dict[str, Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the list
                corresponds to a different feature level
            targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
                If provided, each element in the dict should contain a field `boxes`,
                with the locations of the ground-truth boxes.

        Returns:
            boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                image.
            losses (Dict[str, Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        # RPN uses all feature maps that are available
        features = list(features.values())
        objectness, pred_bbox_deltas = self.head(features)
        # for obj in objectness:
        #     print(f'objectness:{obj.shape}')
        # for pred_bbox in pred_bbox_deltas:
        #     print(f'pred_bbox:{pred_bbox.shape}')
        anchors = self.anchor_generator(images, features)

        num_images = len(anchors)
        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
        num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
        # apply pred_bbox_deltas to anchors to obtain the decoded proposals
        # note that we detach the deltas because Faster R-CNN does not backprop through
        # the proposals
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        # print(f'box_coder.decode proposals:{proposals.shape}')
        proposals = proposals.view(num_images, -1, 4)
        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
        # print(f'boxes:{boxes[0].shape},scores:{scores[0].shape}')

        # lines = self.lines_generator(features, 300)
        #
        # # Merge all line segments into a single Tensor (assuming batch_size=2)
        # lines_all = torch.cat(lines, dim=0)  # [Total_Lines, 4]
        #
        # # Keep only the line segments that lie inside the proposal boxes
        # lines = self.filter_lines_inside_boxes(lines_all, boxes)
        # print(f'filter_lines:{lines}')

        losses = {}
        if self.training:
            if targets is None:
                raise ValueError("targets should not be None")
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                objectness, pred_bbox_deltas, labels, regression_targets
            )
            losses = {
                "loss_objectness": loss_objectness,
                "loss_rpn_box_reg": loss_rpn_box_reg,
            }
        # print(f'boxes:{boxes[0].shape}')
        return boxes, losses
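    # A minimal usage sketch (hedged; assumes the torchvision-style AnchorGenerator and
    # ImageList imported above, with purely illustrative sizes and thresholds):
    #
    #   anchor_generator = AnchorGenerator(sizes=((32, 64, 128),),
    #                                      aspect_ratios=((0.5, 1.0, 2.0),))
    #   head = RPNHead(256, anchor_generator.num_anchors_per_location()[0])
    #   rpn = RegionProposalNetwork(
    #       anchor_generator, head,
    #       fg_iou_thresh=0.7, bg_iou_thresh=0.3,
    #       batch_size_per_image=256, positive_fraction=0.5,
    #       pre_nms_top_n={"training": 2000, "testing": 1000},
    #       post_nms_top_n={"training": 2000, "testing": 1000},
    #       nms_thresh=0.7,
    #   )
    #   boxes, losses = rpn(images, {"0": feature_map})  # images: ImageList, feature_map: [N, 256, H, W]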
    def lines_generator(self, features: torch.Tensor, topk=300):
        """
        Args:
            features (Tensor): shape [B, C, H, W], where C >= 3
                - features[:, 0]: jmap (junction map)
                - features[:, 1:3]: joff (offsets in x and y)
            topk (int): number of highest-scoring junction points to keep

        Returns:
            lines_batch (List[Tensor]): one [N, 4] Tensor per image with that image's line segments
        """
        # use the first feature level
        features = features[0]
        B, _, H, W = features.shape
        lines_batch = []

        jmap = features[:, 0]  # shape: [B, H, W]
        joff = features[:, 1:3]  # shape: [B, 2, H, W]

        for b in range(B):
            jmap_b = jmap[b]  # shape: [H, W]
            joff_b = joff[b]  # shape: [2, H, W]

            # flatten the junction map and keep the top-k hottest locations
            val_k, idx_k = torch.topk(jmap_b.view(-1), k=topk)
            ys = idx_k // W  # row indices
            xs = idx_k % W  # column indices

            # read the sub-pixel offsets at those locations
            dx = joff_b[0, ys, xs]
            dy = joff_b[1, ys, xs]

            # refine the junction coordinates with the offsets
            points = torch.stack([
                xs.float() + dx,
                ys.float() + dy
            ], dim=1)  # shape: [topk, 2]

            # connect every pair of junctions into a candidate line segment
            num_points = points.shape[0]
            if num_points < 2:
                lines_batch.append(torch.empty((0, 4), device=features.device))
                continue

            idx_i, idx_j = torch.triu_indices(num_points, num_points, offset=1)
            point_i = points[idx_i]
            point_j = points[idx_j]
            lines = torch.cat([point_i, point_j], dim=1)  # shape: [N, 4]
            lines_batch.append(lines)
        print(f'lines_batch:{lines_batch[0].shape}')
        return lines_batch
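    # For example (illustrative): with topk=300 candidate junctions, the pairwise
    # combination above yields up to 300 * 299 / 2 = 44850 candidate line segments per
    # image, each encoded as [x1, y1, x2, y2].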
    def filter_lines_inside_boxes(self, lines: torch.Tensor, boxes: List[torch.Tensor]):
        """
        Args:
            lines: [N, 4] line segments in [x1, y1, x2, y2] format
            boxes: List of [K_i, 4] proposal boxes, one entry per image

        Returns:
            filtered_lines_per_image: List[Tensor], the line segments that lie inside a box of each image
        """
        filtered_lines = []
        for box in boxes:
            # box shape: [K, 4]
            line_masks = []
            for i in range(box.shape[0]):
                bx0, by0, bx1, by1 = box[i]
                # endpoints of every line segment
                x1, y1, x2, y2 = lines[:, 0], lines[:, 1], lines[:, 2], lines[:, 3]
                # check whether both endpoints lie inside this box
                in_box1 = (x1 >= bx0) & (y1 >= by0) & (x1 <= bx1) & (y1 <= by1)
                in_box2 = (x2 >= bx0) & (y2 >= by0) & (x2 <= bx1) & (y2 <= by1)
                mask = in_box1 & in_box2  # both endpoints inside the box
                line_masks.append(mask)

            if len(line_masks) == 0:
                filtered_lines.append(torch.empty((0, 4), device=lines.device))
            else:
                combined_mask = torch.stack(line_masks).any(dim=0)  # a line only needs to fall inside one box
                filtered_line = lines[combined_mask]
                filtered_lines.append(filtered_line)
        return filtered_lines
def non_maximum_suppression(a):
    ap = F.max_pool2d(a, 3, stride=1, padding=1)
    mask = (a == ap).float().clamp(min=0.0)
    return a * mask
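# Hedged note: non_maximum_suppression keeps only the local maxima of a heatmap by zeroing
# every position that is not the maximum of its 3x3 neighbourhood. It is not called in this
# file, but this kind of pooling-based NMS is commonly applied to a junction map (jmap)
# before the top-k selection done in lines_generator(), e.g. in L-CNN-style detectors.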