kepointrcnn.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. import math
  2. import os
  3. import sys
  4. from datetime import datetime
  5. from typing import Mapping, Any
  6. import cv2
  7. import numpy as np
  8. import torch
  9. import torchvision
  10. from torch import nn
  11. from torch.nn.modules.module import T
  12. from torchvision.io import read_image
  13. from torchvision.models import resnet50, ResNet50_Weights, resnet18, ResNet18_Weights
  14. from torchvision.models._utils import _ovewrite_value_param
  15. from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
  16. from torchvision.models.detection.anchor_utils import AnchorGenerator
  17. from torchvision.models.detection.backbone_utils import _validate_trainable_layers, _resnet_fpn_extractor
  18. from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
  19. from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor, KeypointRCNN, \
  20. KeypointRCNN_ResNet50_FPN_Weights
  21. from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
  22. from torchvision.utils import draw_bounding_boxes
  23. from torchvision.ops import misc as misc_nn_ops
  24. from typing import Optional, Any
  25. from models.config.config_tool import read_yaml
  26. from models.keypoint.trainer import train_cfg
  27. from models.wirenet._utils import overwrite_eps
  28. from torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
  29. from tools import utils
  30. os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
  31. class KeypointRCNNModel(nn.Module):
  32. def __init__(self, num_classes=2,num_keypoints=2, transforms=None):
  33. super(KeypointRCNNModel, self).__init__()
  34. ####mobile net
  35. # backbone = torchvision.models.mobilenet_v2(weights=None).features
  36. # backbone.out_channels = 1280
  37. # anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),aspect_ratios = ((0.5, 1.0, 2.0),))
  38. # roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],output_size = 7,sampling_ratio = 2)
  39. #keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],output_size = 14,sampling_ratio = 2)
  40. # self.__model= KeypointRCNN(backbone, num_classes=2, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler,keypoint_roi_pool=keypoint_roi_pooler)
  41. ####
  42. default_weights = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
  43. self.__model = keypointrcnn_resnet18_fpn(weights=None,num_classes=num_classes,
  44. num_keypoints=num_keypoints,
  45. progress=False)
  46. if transforms is None:
  47. self.transforms = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
  48. # if num_classes != 0:
  49. # self.set_num_classes(num_classes)
  50. # self.__num_classes=0
  51. self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  52. def forward(self, inputs):
  53. outputs = self.__model(inputs)
  54. return outputs
  55. def train(self, cfg):
  56. parameters = read_yaml(cfg)
  57. num_classes = parameters['num_classes']
  58. num_keypoints = parameters['num_keypoints']
  59. # print(f'num_classes:{num_classes}')
  60. # self.set_num_classes(num_classes)
  61. self.num_keypoints = num_keypoints
  62. train_cfg(self.__model, cfg)
  63. # def set_num_classes(self, num_classes):
  64. # in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
  65. # self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
  66. #
  67. # # in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
  68. # in_channels = self.__model.roi_heads.keypoint_predictor.
  69. # hidden_layer = 256
  70. # self.__model.roi_heads.mask_predictor = KeypointRCNNPredictor(in_channels, hidden_layer,
  71. # num_classes=num_classes)
  72. # self.__model.roi_heads.keypoint_predictor=KeypointRCNNPredictor(in_channels, num_keypoints=num_classes)
  73. def load_weight(self, pt_path):
  74. state_dict = torch.load(pt_path)
  75. self.__model.load_state_dict(state_dict)
  76. def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
  77. self.__model.load_state_dict(state_dict)
  78. # return super().load_state_dict(state_dict, strict)
  79. def eval(self: T) -> T:
  80. self.__model.eval()
  81. # return super().eval()
  82. def keypointrcnn_resnet18_fpn(
  83. *,
  84. weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
  85. progress: bool = True,
  86. num_classes: Optional[int] = None,
  87. num_keypoints: Optional[int] = None,
  88. weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
  89. trainable_backbone_layers: Optional[int] = None,
  90. **kwargs: Any,
  91. ) -> KeypointRCNN:
  92. """
  93. Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
  94. .. betastatus:: detection module
  95. Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
  96. The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
  97. image, and should be in ``0-1`` range. Different images can have different sizes.
  98. The behavior of the model changes depending on if it is in training or evaluation mode.
  99. During training, the model expects both the input tensors and targets (list of dictionary),
  100. containing:
  101. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  102. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  103. - labels (``Int64Tensor[N]``): the class label for each ground-truth box
  104. - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
  105. format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
  106. The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
  107. losses for both the RPN and the R-CNN, and the keypoint loss.
  108. During inference, the model requires only the input tensors, and returns the post-processed
  109. predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
  110. follows, where ``N`` is the number of detected instances:
  111. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  112. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  113. - labels (``Int64Tensor[N]``): the predicted labels for each instance
  114. - scores (``Tensor[N]``): the scores or each instance
  115. - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
  116. For more details on the output, you may refer to :ref:`instance_seg_output`.
  117. Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
  118. Example::
  119. >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
  120. >>> model.eval()
  121. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  122. >>> predictions = model(x)
  123. >>>
  124. >>> # optionally, if you want to export the model to ONNX:
  125. >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
  126. Args:
  127. weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
  128. pretrained weights to use. See
  129. :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
  130. below for more details, and possible values. By default, no
  131. pre-trained weights are used.
  132. progress (bool): If True, displays a progress bar of the download to stderr
  133. num_classes (int, optional): number of output classes of the model (including the background)
  134. num_keypoints (int, optional): number of keypoints
  135. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
  136. pretrained weights for the backbone.
  137. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
  138. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
  139. passed (the default) this value is set to 3.
  140. .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
  141. :members:
  142. """
  143. weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
  144. weights_backbone = ResNet50_Weights.verify(weights_backbone)
  145. # if weights_backbone is None:
  146. weights_backbone = ResNet18_Weights.IMAGENET1K_V1
  147. if weights is not None:
  148. # weights_backbone = None
  149. num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
  150. num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
  151. else:
  152. if num_classes is None:
  153. num_classes = 2
  154. if num_keypoints is None:
  155. num_keypoints = 17
  156. is_trained = weights is not None or weights_backbone is not None
  157. trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
  158. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  159. backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
  160. backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
  161. model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
  162. if weights is not None:
  163. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  164. if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
  165. overwrite_eps(model, 0.0)
  166. return model
  167. if __name__ == '__main__':
  168. # ins_model = MaskRCNNModel(num_classes=5)
  169. keypoint_model = KeypointRCNNModel(num_keypoints=2)
  170. # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
  171. # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
  172. keypoint_model.train(cfg='train.yaml')