fasterrcnn_resnet50.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import torch
  2. import torch.nn as nn
  3. import torchvision
  4. from typing import Dict, List, Optional, Tuple
  5. import torch.nn.functional as F
  6. from torchvision.ops import MultiScaleRoIAlign
  7. from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
  8. from torchvision.models.detection.transform import GeneralizedRCNNTransform
  9. def get_model(num_classes):
  10. # 加载预训练的ResNet-50 FPN backbone
  11. model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
  12. # 获取分类器的输入特征数
  13. in_features = model.roi_heads.box_predictor.cls_score.in_features
  14. # 替换分类器以适应新的类别数量
  15. model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
  16. return model
  17. def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
  18. # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
  19. """
  20. Computes the loss for Faster R-CNN.
  21. Args:
  22. class_logits (Tensor)
  23. box_regression (Tensor)
  24. labels (list[BoxList])
  25. regression_targets (Tensor)
  26. Returns:
  27. classification_loss (Tensor)
  28. box_loss (Tensor)
  29. """
  30. labels = torch.cat(labels, dim=0)
  31. regression_targets = torch.cat(regression_targets, dim=0)
  32. classification_loss = F.cross_entropy(class_logits, labels)
  33. # get indices that correspond to the regression targets for
  34. # the corresponding ground truth labels, to be used with
  35. # advanced indexing
  36. sampled_pos_inds_subset = torch.where(labels > 0)[0]
  37. labels_pos = labels[sampled_pos_inds_subset]
  38. N, num_classes = class_logits.shape
  39. box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
  40. box_loss = F.smooth_l1_loss(
  41. box_regression[sampled_pos_inds_subset, labels_pos],
  42. regression_targets[sampled_pos_inds_subset],
  43. beta=1 / 9,
  44. reduction="sum",
  45. )
  46. box_loss = box_loss / labels.numel()
  47. return classification_loss, box_loss
  48. class Fasterrcnn_resnet50(nn.Module):
  49. def __init__(self, num_classes=5, num_stacks=1):
  50. super(Fasterrcnn_resnet50, self).__init__()
  51. self.model = get_model(num_classes=5)
  52. self.backbone = self.model.backbone
  53. self.box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=16, sampling_ratio=2)
  54. out_channels = self.backbone.out_channels
  55. resolution = self.box_roi_pool.output_size[0]
  56. representation_size = 1024
  57. self.box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
  58. self.box_predictor = FastRCNNPredictor(representation_size, num_classes)
  59. # 多任务输出层
  60. self.score_layers = nn.ModuleList([
  61. nn.Sequential(
  62. nn.Conv2d(256, 128, kernel_size=3, padding=1),
  63. nn.BatchNorm2d(128),
  64. nn.ReLU(inplace=True),
  65. nn.Conv2d(128, num_classes, kernel_size=1)
  66. )
  67. for _ in range(num_stacks)
  68. ])
  69. def forward(self, x, target1, train_or_val, image_shapes=(512, 512)):
  70. transform = GeneralizedRCNNTransform(min_size=512, max_size=1333, image_mean=[0.485, 0.456, 0.406],
  71. image_std=[0.229, 0.224, 0.225])
  72. images, targets = transform(x, target1)
  73. x_ = self.backbone(images.tensors)
  74. # x_ = self.backbone(x) # '0' '1' '2' '3' 'pool'
  75. # print(f'backbone:{self.backbone}')
  76. # print(f'Fasterrcnn_resnet50 x_:{x_}')
  77. feature_ = x_['0'] # 图片特征
  78. outputs = []
  79. for score_layer in self.score_layers:
  80. output = score_layer(feature_)
  81. outputs.append(output) # 多头
  82. if train_or_val == "training":
  83. loss_box = self.model(x, target1)
  84. return outputs, feature_, loss_box
  85. else:
  86. box_all = self.model(x, target1)
  87. return outputs, feature_, box_all
  88. def fasterrcnn_resnet50(**kwargs):
  89. model = Fasterrcnn_resnet50(
  90. num_classes=kwargs.get("num_classes", 5),
  91. num_stacks=kwargs.get("num_stacks", 1)
  92. )
  93. return model