kepointrcnn.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import math
  2. import os
  3. import sys
  4. from datetime import datetime
  5. from typing import Mapping, Any
  6. import cv2
  7. import numpy as np
  8. import torch
  9. import torchvision
  10. from torch import nn
  11. from torchvision.io import read_image
  12. from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
  13. from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
  14. from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
  15. from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
  16. from torchvision.utils import draw_bounding_boxes
  17. from torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
  18. from models.config.config_tool import read_yaml
  19. from models.keypoint.trainer import train_cfg
  20. from tools import utils
  21. os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
  22. class KeypointRCNNModel(nn.Module):
  23. def __init__(self, num_classes=2,num_keypoints=2, transforms=None):
  24. super(KeypointRCNNModel, self).__init__()
  25. default_weights = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
  26. self.__model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=None,num_classes=num_classes,
  27. num_keypoints=num_keypoints,
  28. progress=False)
  29. if transforms is None:
  30. self.transforms = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
  31. # if num_classes != 0:
  32. # self.set_num_classes(num_classes)
  33. # self.__num_classes=0
  34. self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  35. def forward(self, inputs):
  36. outputs = self.__model(inputs)
  37. return outputs
  38. def train(self, cfg):
  39. parameters = read_yaml(cfg)
  40. num_classes = parameters['num_classes']
  41. num_keypoints = parameters['num_keypoints']
  42. # print(f'num_classes:{num_classes}')
  43. # self.set_num_classes(num_classes)
  44. self.num_keypoints = num_keypoints
  45. train_cfg(self.__model, cfg)
  46. # def set_num_classes(self, num_classes):
  47. # in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
  48. # self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
  49. #
  50. # # in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
  51. # in_channels = self.__model.roi_heads.keypoint_predictor.
  52. # hidden_layer = 256
  53. # self.__model.roi_heads.mask_predictor = KeypointRCNNPredictor(in_channels, hidden_layer,
  54. # num_classes=num_classes)
  55. # self.__model.roi_heads.keypoint_predictor=KeypointRCNNPredictor(in_channels, num_keypoints=num_classes)
  56. def load_weight(self, pt_path):
  57. state_dict = torch.load(pt_path)
  58. self.__model.load_state_dict(state_dict)
  59. def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
  60. self.__model.load_state_dict(state_dict)
  61. # return super().load_state_dict(state_dict, strict)
  62. if __name__ == '__main__':
  63. # ins_model = MaskRCNNModel(num_classes=5)
  64. keypoint_model = KeypointRCNNModel(num_keypoints=2)
  65. # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
  66. # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
  67. keypoint_model.train(cfg='train.yaml')