maskrcnn.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. from typing import Mapping, Any
  2. import cv2
  3. import numpy as np
  4. import torch
  5. from torch import nn
  6. from torchvision.io import read_image
  7. from torchvision import models
  8. from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
  9. from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
  10. from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
  11. from torchvision.utils import draw_bounding_boxes
  12. from models.config.config_tool import read_yaml
  13. from models.ins_detect.trainer import train_cfg
  14. from tools import utils
  15. class MaskRCNNModel(nn.Module):
  16. def __init__(self, num_classes=0, transforms=None):
  17. super(MaskRCNNModel, self).__init__()
  18. self.__model =models.detection.maskrcnn_resnet50_fpn_v2(
  19. weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
  20. if transforms is None:
  21. self.transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
  22. if num_classes != 0:
  23. self.set_num_classes(num_classes)
  24. # self.__num_classes=0
  25. self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  26. def forward(self, inputs):
  27. outputs = self.__model(inputs)
  28. return outputs
  29. def train(self, cfg):
  30. parameters = read_yaml(cfg)
  31. num_classes=parameters['num_classes']
  32. # print(f'num_classes:{num_classes}')
  33. self.set_num_classes(num_classes)
  34. train_cfg(self.__model, cfg)
  35. def set_num_classes(self, num_classes):
  36. in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
  37. self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
  38. in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
  39. hidden_layer = 256
  40. self.__model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer,
  41. num_classes=num_classes)
  42. def load_weight(self, pt_path):
  43. state_dict = torch.load(pt_path)
  44. self.__model.load_state_dict(state_dict)
  45. def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
  46. self.__model.load_state_dict(state_dict)
  47. # return super().load_state_dict(state_dict, strict)
  48. def predict(self, src, show_box=True, show_mask=True):
  49. self.__model.eval()
  50. img = read_image(src)
  51. img = self.transforms(img)
  52. img = img.to(self.device)
  53. result = self.__model([img])
  54. print(f'result:{result}')
  55. masks = result[0]['masks']
  56. boxes = result[0]['boxes']
  57. # cv2.imshow('ins',masks[0].cpu().detach().numpy())
  58. boxes = boxes.cpu().detach()
  59. drawn_boxes = draw_bounding_boxes((img * 255).to(torch.uint8), boxes, colors="red", width=5)
  60. print(f'drawn_boxes:{drawn_boxes.shape}')
  61. boxed_img = drawn_boxes.permute(1, 2, 0).numpy()
  62. # boxed_img=cv2.resize(boxed_img,(800,800))
  63. # cv2.imshow('boxes',boxed_img)
  64. mask = masks[0].cpu().detach().permute(1, 2, 0).numpy()
  65. mask = cv2.resize(mask, (800, 800))
  66. # cv2.imshow('ins',ins)
  67. img = img.cpu().detach().permute(1, 2, 0).numpy()
  68. masked_img = self.overlay_masks_on_image(boxed_img, masks)
  69. masked_img = cv2.resize(masked_img, (800, 800))
  70. cv2.imshow('img_masks', masked_img)
  71. # show_img_boxes_masks(img, boxes, masks)
  72. cv2.waitKey(0)
  73. def generate_colors(self, n):
  74. """
  75. 生成n个均匀分布在HSV色彩空间中的颜色,并转换成BGR色彩空间。
  76. :param n: 需要的颜色数量
  77. :return: 一个包含n个颜色的列表,每个颜色为BGR格式的元组
  78. """
  79. hsv_colors = [(i / n * 180, 1 / 3 * 255, 2 / 3 * 255) for i in range(n)]
  80. bgr_colors = [tuple(map(int, cv2.cvtColor(np.uint8([[hsv]]), cv2.COLOR_HSV2BGR)[0][0])) for hsv in hsv_colors]
  81. return bgr_colors
  82. def overlay_masks_on_image(self, image, masks, alpha=0.6):
  83. """
  84. 在原图上叠加多个掩码,每个掩码使用不同的颜色。
  85. :param image: 原图 (NumPy 数组)
  86. :param masks: 掩码列表 (每个都是 NumPy 数组,二值图像)
  87. :param colors: 颜色列表 (每个颜色都是 (B, G, R) 格式的元组)
  88. :param alpha: 掩码的透明度 (0.0 到 1.0)
  89. :return: 叠加了多个掩码的图像
  90. """
  91. colors = self.generate_colors(len(masks))
  92. if len(masks) != len(colors):
  93. raise ValueError("The number of masks and colors must be the same.")
  94. # 复制原图,避免修改原始图像
  95. overlay = image.copy()
  96. for mask, color in zip(masks, colors):
  97. # 确保掩码是二值图像
  98. mask = mask.cpu().detach().permute(1, 2, 0).numpy()
  99. binary_mask = (mask > 0).astype(np.uint8) * 255 # 你可以根据实际情况调整阈值
  100. # 创建彩色掩码
  101. colored_mask = np.zeros_like(image)
  102. colored_mask[:] = color
  103. colored_mask = cv2.bitwise_and(colored_mask, colored_mask, mask=binary_mask)
  104. # 将彩色掩码与当前的叠加图像混合
  105. overlay = cv2.addWeighted(overlay, 1 - alpha, colored_mask, alpha, 0)
  106. return overlay
  107. if __name__ == '__main__':
  108. # ins_model = MaskRCNNModel(num_classes=5)
  109. ins_model = MaskRCNNModel()
  110. # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
  111. # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
  112. ins_model.train(cfg='train.yaml')