kepointrcnn.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. import math
  2. import os
  3. import sys
  4. from collections import OrderedDict
  5. from datetime import datetime
  6. from typing import Mapping
  7. import cv2
  8. import numpy as np
  9. import torch
  10. import torchvision
  11. from PIL import Image
  12. from matplotlib import pyplot as plt
  13. from torch import nn
  14. from torch.nn.modules.module import T
  15. from torchvision.io import read_image
  16. from torchvision.models import resnet50, ResNet50_Weights, resnet18, ResNet18_Weights
  17. from torchvision.models._utils import _ovewrite_value_param
  18. from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
  19. from torchvision.models.detection.anchor_utils import AnchorGenerator
  20. from torchvision.models.detection.backbone_utils import _validate_trainable_layers, _resnet_fpn_extractor
  21. from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
  22. from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor, KeypointRCNN, \
  23. KeypointRCNN_ResNet50_FPN_Weights
  24. from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
  25. from torchvision.ops.feature_pyramid_network import LastLevelMaxPool
  26. from torchvision.utils import draw_bounding_boxes
  27. from torchvision.ops import misc as misc_nn_ops, FeaturePyramidNetwork
  28. from typing import Optional, Any
  29. from models.config.config_tool import read_yaml
  30. from models.keypoint.trainer import train_cfg
  31. from models.wirenet._utils import overwrite_eps
  32. # from timm import create_model
  33. from torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
  34. from tools import utils
  35. os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
  36. class KeypointRCNNModel(nn.Module):
  37. def __init__(self, num_classes=2,num_keypoints=2, transforms=None):
  38. super(KeypointRCNNModel, self).__init__()
  39. ####mobile net
  40. # backbone = torchvision.models.mobilenet_v2(weights=None).features
  41. # backbone.out_channels = 1280
  42. # anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),aspect_ratios = ((0.5, 1.0, 2.0),))
  43. # roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],output_size = 7,sampling_ratio = 2)
  44. #keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],output_size = 14,sampling_ratio = 2)
  45. # self.__model= KeypointRCNN(backbone, num_classes=2, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler,keypoint_roi_pool=keypoint_roi_pooler)
  46. ####
  47. # 加载 EfficientNet 模型并移除分类头
  48. # backbone = create_model('tf_efficientnet_b0', pretrained=True, features_only=True)
  49. # backbone_out_channels =backbone.feature_info.channels() # 获取所有阶段的通道数
  50. #
  51. #
  52. # # 构建 FPN
  53. # fpn = FeaturePyramidNetwork(
  54. # in_channels_list=backbone_out_channels,
  55. # out_channels=256,
  56. # extra_blocks=LastLevelMaxPool()
  57. # )
  58. #
  59. # # 将 EfficientNet 和 FPN 组合成一个新的 backbone
  60. # self.body = nn.Sequential(
  61. # backbone,
  62. # fpn
  63. # )
  64. default_weights = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
  65. self.__model = keypointrcnn_resnet18_fpn(weights=None,num_classes=num_classes,
  66. num_keypoints=num_keypoints,
  67. progress=False)
  68. # self.__model.backbone.body = nn.Sequential(OrderedDict([
  69. # ('body', self.body),
  70. # ('fpn', fpn)
  71. # ]))
  72. if transforms is None:
  73. self.transforms = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
  74. # if num_classes != 0:
  75. # self.set_num_classes(num_classes)
  76. # self.__num_classes=0
  77. self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  78. def forward(self, inputs):
  79. outputs = self.__model(inputs)
  80. return outputs
  81. def train(self, cfg):
  82. parameters = read_yaml(cfg)
  83. num_classes = parameters['num_classes']
  84. num_keypoints = parameters['num_keypoints']
  85. # print(f'num_classes:{num_classes}')
  86. # self.set_num_classes(num_classes)
  87. self.num_keypoints = num_keypoints
  88. train_cfg(self.__model, cfg)
  89. # def set_num_classes(self, num_classes):
  90. # in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
  91. # self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
  92. #
  93. # # in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
  94. # in_channels = self.__model.roi_heads.keypoint_predictor.
  95. # hidden_layer = 256
  96. # self.__model.roi_heads.mask_predictor = KeypointRCNNPredictor(in_channels, hidden_layer,
  97. # num_classes=num_classes)
  98. # self.__model.roi_heads.keypoint_predictor=KeypointRCNNPredictor(in_channels, num_keypoints=num_classes)
  99. def load_weight(self, pt_path):
  100. state_dict = torch.load(pt_path)
  101. self.__model.load_state_dict(state_dict)
  102. def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
  103. self.__model.load_state_dict(state_dict)
  104. # return super().load_state_dict(state_dict, strict)
  105. def eval(self: T) -> T:
  106. self.__model.eval()
  107. # return super().eval()
  108. def predict(self, img, show=True, save=False, save_path=None):
  109. """
  110. 对输入图像进行关键点检测预测。
  111. 参数:
  112. img (str or PIL.Image): 输入图像的路径或 PIL.Image 对象。
  113. show (bool): 是否显示预测结果,默认为 True。
  114. save (bool): 是否保存预测结果,默认为 False。
  115. 返回:
  116. dict: 包含预测结果的字典。
  117. """
  118. if isinstance(img, str):
  119. img = Image.open(img).convert("RGB")
  120. self.__model.eval()
  121. # 预处理图像
  122. img_tensor = self.transforms(img)
  123. with torch.no_grad():
  124. predictions = self.__model([img_tensor])
  125. print(f'predictions:{predictions}')
  126. # 后处理预测结果
  127. boxes = predictions[0]['boxes'].cpu().numpy()
  128. keypoints = predictions[0]['keypoints'].cpu().numpy()
  129. # 可视化预测结果
  130. if show or save:
  131. fig, ax = plt.subplots(figsize=(10, 10))
  132. ax.imshow(np.array(img))
  133. for box in boxes:
  134. x0, y0, x1, y1 = box
  135. ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1))
  136. for (a, b) in keypoints:
  137. ax.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=1)
  138. ax.scatter(a[0], a[1], c='red', s=2)
  139. ax.scatter(b[0], b[1], c='red', s=2)
  140. if show:
  141. plt.show()
  142. if save:
  143. fig.savefig(save_path)
  144. print(f"Prediction saved to {save_path}")
  145. plt.close(fig)
  146. def keypointrcnn_resnet18_fpn(
  147. *,
  148. weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
  149. progress: bool = True,
  150. num_classes: Optional[int] = None,
  151. num_keypoints: Optional[int] = None,
  152. weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
  153. trainable_backbone_layers: Optional[int] = None,
  154. **kwargs: Any,
  155. ) -> KeypointRCNN:
  156. """
  157. Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
  158. .. betastatus:: detection module
  159. Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
  160. The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
  161. image, and should be in ``0-1`` range. Different images can have different sizes.
  162. The behavior of the model changes depending on if it is in training or evaluation mode.
  163. During training, the model expects both the input tensors and targets (list of dictionary),
  164. containing:
  165. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  166. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  167. - labels (``Int64Tensor[N]``): the class label for each ground-truth box
  168. - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
  169. format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
  170. The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
  171. losses for both the RPN and the R-CNN, and the keypoint loss.
  172. During inference, the model requires only the input tensors, and returns the post-processed
  173. predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
  174. follows, where ``N`` is the number of detected instances:
  175. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  176. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  177. - labels (``Int64Tensor[N]``): the predicted labels for each instance
  178. - scores (``Tensor[N]``): the scores or each instance
  179. - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
  180. For more details on the output, you may refer to :ref:`instance_seg_output`.
  181. Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
  182. Example::
  183. >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
  184. >>> model.eval()
  185. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  186. >>> predictions = model(x)
  187. >>>
  188. >>> # optionally, if you want to export the model to ONNX:
  189. >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
  190. Args:
  191. weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
  192. pretrained weights to use. See
  193. :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
  194. below for more details, and possible values. By default, no
  195. pre-trained weights are used.
  196. progress (bool): If True, displays a progress bar of the download to stderr
  197. num_classes (int, optional): number of output classes of the model (including the background)
  198. num_keypoints (int, optional): number of keypoints
  199. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
  200. pretrained weights for the backbone.
  201. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
  202. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
  203. passed (the default) this value is set to 3.
  204. .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
  205. :members:
  206. """
  207. weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
  208. weights_backbone = ResNet50_Weights.verify(weights_backbone)
  209. # if weights_backbone is None:
  210. weights_backbone = ResNet18_Weights.IMAGENET1K_V1
  211. if weights is not None:
  212. # weights_backbone = None
  213. num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
  214. num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
  215. else:
  216. if num_classes is None:
  217. num_classes = 2
  218. if num_keypoints is None:
  219. num_keypoints = 17
  220. is_trained = weights is not None or weights_backbone is not None
  221. trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
  222. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  223. backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
  224. backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
  225. model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
  226. if weights is not None:
  227. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  228. if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
  229. overwrite_eps(model, 0.0)
  230. return model
  231. if __name__ == '__main__':
  232. # ins_model = MaskRCNNModel(num_classes=5)
  233. keypoint_model = KeypointRCNNModel(num_keypoints=2)
  234. wts_path='./train_results/20241227_231659/weights/best.pt'
  235. # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
  236. # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
  237. keypoint_model.train(cfg='train.yaml')
  238. # keypoint_model.load_weight(wts_path)
  239. # img_path=r"F:\DevTools\datasets\renyaun\1012\images\2024-09-23-10-02-15_SaveImage.png"
  240. # keypoint_model.predict(img_path)