# fasterrcnn_resnet50.py (2.8 KB)
  1. import torch
  2. import torch.nn as nn
  3. import torchvision
  4. from torchvision.models.detection.transform import GeneralizedRCNNTransform
  5. # from .detection.transform import GeneralizedRCNNTransform
  6. def get_model(num_classes):
  7. # 加载预训练的ResNet-50 FPN backbone
  8. model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
  9. # 获取分类器的输入特征数
  10. in_features = model.roi_heads.box_predictor.cls_score.in_features
  11. # 替换分类器以适应新的类别数量
  12. model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
  13. return model
  14. class Fasterrcnn_resnet50(nn.Module):
  15. def __init__(self, num_classes=5, num_stacks=1):
  16. super(Fasterrcnn_resnet50, self).__init__()
  17. self.model = get_model(num_classes=5)
  18. self.backbone = self.model.backbone
  19. # self.box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=16, sampling_ratio=2)
  20. # out_channels = self.backbone.out_channels
  21. # resolution = self.box_roi_pool.output_size[0]
  22. # representation_size = 1024
  23. # self.box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
  24. #
  25. # self.box_predictor = FastRCNNPredictor(representation_size, num_classes)
  26. # 多任务输出层
  27. self.score_layers = nn.ModuleList([
  28. nn.Sequential(
  29. nn.Conv2d(256, 128, kernel_size=3, padding=1),
  30. nn.BatchNorm2d(128),
  31. nn.ReLU(inplace=True),
  32. nn.Conv2d(128, num_classes, kernel_size=1)
  33. )
  34. for _ in range(num_stacks)
  35. ])
  36. def forward(self, x, target1, train_or_val, image_shapes=(512, 512)):
  37. transform = GeneralizedRCNNTransform(min_size=512, max_size=1333, image_mean=[0.485, 0.456, 0.406],
  38. image_std=[0.229, 0.224, 0.225])
  39. images, targets = transform(x, target1)
  40. x_ = self.backbone(images.tensors)
  41. # x_ = self.backbone(x) # '0' '1' '2' '3' 'pool'
  42. # print(f'backbone:{self.backbone}')
  43. # print(f'Fasterrcnn_resnet50 x_:{x_}')
  44. feature_ = x_['0'] # 图片特征
  45. outputs = []
  46. for score_layer in self.score_layers:
  47. output = score_layer(feature_)
  48. outputs.append(output) # 多头
  49. if train_or_val == "training":
  50. loss_box = self.model(x, target1)
  51. return outputs, feature_, loss_box
  52. else:
  53. box_all = self.model(x, target1)
  54. return outputs, feature_, box_all
  55. def fasterrcnn_resnet50(**kwargs):
  56. model = Fasterrcnn_resnet50(
  57. num_classes=kwargs.get("num_classes", 5),
  58. num_stacks=kwargs.get("num_stacks", 1)
  59. )
  60. return model