import math
import os
import sys
from datetime import datetime
from typing import Any, Mapping, Union

import cv2
import numpy as np
import torch
import torchvision
from torch import nn
from torchvision.io import read_image
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.utils import draw_bounding_boxes
from torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES

from models.config.config_tool import read_yaml
from models.keypoint.trainer import train_cfg
from tools import utils

# Debugging aid: asks CUDA to launch kernels synchronously so that errors are
# reported at the call that caused them. Kept as a module-level side effect
# because importers of this module may rely on it.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


class KeypointRCNNModel(nn.Module):
    """Wrapper around torchvision's ``keypointrcnn_resnet50_fpn`` detector.

    The wrapped network is created with ``weights=None`` and the requested
    number of classes and keypoints. The wrapper adds config-driven training
    (via the project's ``train_cfg``) and checkpoint-loading helpers.

    Attributes:
        transforms: Preprocessing callable; defaults to the transforms bundled
            with ``KeypointRCNN_ResNet50_FPN_Weights.DEFAULT``.
        num_keypoints: Number of keypoints the model was built with, or the
            value read from the config on the last training run.
        device: ``cuda`` when available, otherwise ``cpu``. The wrapped model
            is not moved to this device by the constructor.
    """

    def __init__(self, num_classes=2, num_keypoints=2, transforms=None):
        """Build the wrapped Keypoint R-CNN.

        Args:
            num_classes: Number of output classes passed to torchvision.
            num_keypoints: Number of keypoints predicted per instance.
            transforms: Optional preprocessing callable. When ``None`` the
                default transforms of the torchvision weights enum are used.
        """
        super().__init__()
        detection = torchvision.models.detection
        self.__model = detection.keypointrcnn_resnet50_fpn(
            weights=None,
            num_classes=num_classes,
            num_keypoints=num_keypoints,
            progress=False,
        )
        # Always define the attribute: previously a caller-supplied transform
        # was silently dropped and ``self.transforms`` was left undefined.
        if transforms is None:
            transforms = detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
        self.transforms = transforms
        self.num_keypoints = num_keypoints
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    def forward(self, inputs):
        """Run the wrapped detector on ``inputs`` and return its outputs unchanged."""
        return self.__model(inputs)

    def train(self, cfg: Union[str, bool] = True):
        """Switch train/eval mode, or run a config-driven training session.

        This method shadows ``nn.Module.train``. PyTorch itself calls
        ``self.train(False)`` from ``eval()``, so boolean arguments are
        forwarded to the base implementation to keep mode switching working.

        Args:
            cfg: Either a bool (standard ``nn.Module.train(mode)`` semantics)
                or a YAML config accepted by ``read_yaml`` / ``train_cfg``
                (presumably a file path; the config must contain a
                ``num_keypoints`` entry).

        Returns:
            ``self`` when called with a bool, otherwise ``None``.

        Raises:
            KeyError: If the config has no ``num_keypoints`` entry.
        """
        if isinstance(cfg, bool):
            return super().train(cfg)

        parameters = read_yaml(cfg)
        self.num_keypoints = parameters['num_keypoints']
        train_cfg(self.__model, cfg)

    # TODO: replacing the box/keypoint predictor heads for a custom number of
    # classes (the former ``set_num_classes`` sketch) was never completed.

    def load_weight(self, pt_path):
        """Load a state dict saved at ``pt_path`` into the wrapped model.

        Args:
            pt_path: Path (or file-like object) of a checkpoint produced by
                ``torch.save`` containing the wrapped model's state dict.
        """
        # Map tensors to CPU so checkpoints written on a GPU machine also load
        # on CPU-only hosts; load_state_dict copies values into the existing
        # parameters, wherever those live.
        # NOTE(security): torch.load unpickles data - only load trusted files.
        state_dict = torch.load(pt_path, map_location='cpu')
        self.__model.load_state_dict(state_dict)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load ``state_dict`` directly into the wrapped model.

        Keys are expected in the wrapped torchvision model's namespace, not
        prefixed with this wrapper's attribute name.

        Args:
            state_dict: Parameter and buffer mapping for the wrapped model.
            strict: Whether the keys must match the wrapped model's exactly.

        Returns:
            The value returned by the wrapped model's ``load_state_dict``
            (reports missing and unexpected keys).
        """
        return self.__model.load_state_dict(state_dict, strict=strict)


if __name__ == '__main__':
    keypoint_model = KeypointRCNNModel(num_keypoints=2)
    keypoint_model.train(cfg='train.yaml')