kepointrcnn.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import math
  2. import os
  3. import sys
  4. from datetime import datetime
  5. from typing import Mapping, Any
  6. import cv2
  7. import numpy as np
  8. import torch
  9. import torchvision
  10. from torch import nn
  11. from torchvision.io import read_image
  12. from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
  13. from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
  14. from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
  15. from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
  16. from torchvision.utils import draw_bounding_boxes
  17. from models.config.config_tool import read_yaml
  18. from models.keypoint.trainer import train_cfg
  19. from tools import utils
  20. os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
  21. class KeypointRCNNModel(nn.Module):
  22. def __init__(self, num_classes=2,num_keypoints=2, transforms=None):
  23. super(KeypointRCNNModel, self).__init__()
  24. default_weights = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
  25. self.__model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=None,num_classes=num_classes,
  26. num_keypoints=num_keypoints,
  27. progress=False)
  28. if transforms is None:
  29. self.transforms = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
  30. # if num_classes != 0:
  31. # self.set_num_classes(num_classes)
  32. # self.__num_classes=0
  33. self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  34. def forward(self, inputs):
  35. outputs = self.__model(inputs)
  36. return outputs
  37. def train(self, cfg):
  38. parameters = read_yaml(cfg)
  39. num_classes = parameters['num_classes']
  40. num_keypoints = parameters['num_keypoints']
  41. # print(f'num_classes:{num_classes}')
  42. # self.set_num_classes(num_classes)
  43. self.num_keypoints = num_keypoints
  44. train_cfg(self.__model, cfg)
  45. # def set_num_classes(self, num_classes):
  46. # in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
  47. # self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
  48. #
  49. # # in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
  50. # in_channels = self.__model.roi_heads.keypoint_predictor.
  51. # hidden_layer = 256
  52. # self.__model.roi_heads.mask_predictor = KeypointRCNNPredictor(in_channels, hidden_layer,
  53. # num_classes=num_classes)
  54. # self.__model.roi_heads.keypoint_predictor=KeypointRCNNPredictor(in_channels, num_keypoints=num_classes)
  55. def load_weight(self, pt_path):
  56. state_dict = torch.load(pt_path)
  57. self.__model.load_state_dict(state_dict)
  58. def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
  59. self.__model.load_state_dict(state_dict)
  60. # return super().load_state_dict(state_dict, strict)
  61. if __name__ == '__main__':
  62. # ins_model = MaskRCNNModel(num_classes=5)
  63. keypoint_model = KeypointRCNNModel(num_keypoints=2)
  64. # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
  65. # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
  66. keypoint_model.train(cfg='train.yaml')