12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- import math
- import os
- import sys
- from datetime import datetime
- from typing import Mapping, Any
- import cv2
- import numpy as np
- import torch
- import torchvision
- from torch import nn
- from torchvision.io import read_image
- from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
- from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
- from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
- from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
- from torchvision.utils import draw_bounding_boxes
- from torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
- from models.config.config_tool import read_yaml
- from models.keypoint.trainer import train_cfg
- from tools import utils
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
class KeypointRCNNModel(nn.Module):
    """Wrapper module around torchvision's ``keypointrcnn_resnet50_fpn``.

    The wrapped detector is built with ``weights=None`` (no pretrained
    detection weights are requested) and is stored in a private attribute.
    This class adds config-driven training and checkpoint-loading helpers
    on top of it.

    Attributes:
        transforms: Preprocessing callable; defaults to the transforms shipped
            with ``KeypointRCNN_ResNet50_FPN_Weights.DEFAULT``.
        device: ``cuda`` when available, otherwise ``cpu``. The wrapped model
            is NOT moved to this device here.
        num_keypoints: Number of keypoints; overwritten from the config file
            when ``train(cfg)`` is run.
    """

    def __init__(self, num_classes=2, num_keypoints=2, transforms=None):
        """Build the wrapped Keypoint R-CNN.

        Args:
            num_classes: Forwarded to ``keypointrcnn_resnet50_fpn``.
            num_keypoints: Forwarded to ``keypointrcnn_resnet50_fpn``.
            transforms: Optional preprocessing callable. When ``None`` the
                default transforms of the torchvision weights enum are used.
        """
        super().__init__()
        detection = torchvision.models.detection
        self.__model = detection.keypointrcnn_resnet50_fpn(
            weights=None,
            num_classes=num_classes,
            num_keypoints=num_keypoints,
            progress=False,
        )
        # Always assign the attribute: previously a caller-supplied transform
        # was silently dropped and ``self.transforms`` did not exist at all.
        if transforms is None:
            transforms = detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
        self.transforms = transforms
        # Available before train() is ever called; train(cfg) overwrites it.
        self.num_keypoints = num_keypoints
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    def forward(self, inputs):
        """Run the wrapped detector on ``inputs`` and return its output unchanged."""
        return self.__model(inputs)

    def train(self, cfg=True):
        """Switch train/eval mode, or run config-driven training.

        This method shadows ``nn.Module.train``. PyTorch itself calls
        ``self.train(False)`` from ``eval()``, so boolean arguments are
        delegated to the standard implementation; anything else is treated
        as a training config path.

        Args:
            cfg: ``bool`` to set the module's training mode (standard
                ``nn.Module.train`` semantics), or a YAML config path that is
                read with ``read_yaml`` and passed to ``train_cfg``.

        Returns:
            ``self`` when called with a ``bool``; ``None`` after a
            config-driven training run.

        Raises:
            KeyError: If the config has no ``num_keypoints`` entry.
        """
        if isinstance(cfg, bool):
            return super().train(cfg)

        parameters = read_yaml(cfg)
        self.num_keypoints = parameters['num_keypoints']
        train_cfg(self.__model, cfg)

    def load_weight(self, pt_path):
        """Load a state dict saved at ``pt_path`` into the wrapped detector.

        Args:
            pt_path: Path (or file-like object) accepted by ``torch.load``.
        """
        # map_location lets checkpoints saved on GPU load on CPU-only hosts.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(pt_path, map_location=self.device)
        self.__model.load_state_dict(state_dict)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load ``state_dict`` directly into the wrapped detector.

        Keys are expected WITHOUT this wrapper's attribute prefix, i.e. in the
        layout produced by the wrapped model's own ``state_dict()``.

        Args:
            state_dict: Parameter/buffer mapping for the wrapped detector.
            strict: Forwarded to the wrapped model's ``load_state_dict``
                (previously accepted but ignored).

        Returns:
            Whatever the wrapped model's ``load_state_dict`` returns (the
            missing/unexpected key report).
        """
        return self.__model.load_state_dict(state_dict, strict=strict)
-
if __name__ == '__main__':
    # Script entry point: build a two-keypoint model and launch training
    # using the settings in train.yaml.
    KeypointRCNNModel(num_keypoints=2).train(cfg='train.yaml')
|