# maskrcnn.py (5.6 KB)
  1. import math
  2. import os
  3. import sys
  4. from datetime import datetime
  5. from typing import Mapping, Any
  6. import cv2
  7. import numpy as np
  8. import torch
  9. import torchvision
  10. from torch import nn
  11. from torchvision.io import read_image
  12. from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
  13. from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
  14. from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
  15. from torchvision.utils import draw_bounding_boxes
  16. from models.config.config_tool import read_yaml
  17. from models.ins_detect.trainer import train_cfg
  18. from tools import utils
  19. class MaskRCNNModel(nn.Module):
  20. def __init__(self, num_classes=0, transforms=None):
  21. super(MaskRCNNModel, self).__init__()
  22. self.__model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
  23. weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
  24. if transforms is None:
  25. self.transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
  26. if num_classes != 0:
  27. self.set_num_classes(num_classes)
  28. # self.__num_classes=0
  29. self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  30. def forward(self, inputs):
  31. outputs = self.__model(inputs)
  32. return outputs
  33. def train(self, cfg):
  34. parameters = read_yaml(cfg)
  35. num_classes=parameters['num_classes']
  36. # print(f'num_classes:{num_classes}')
  37. self.set_num_classes(num_classes)
  38. train_cfg(self.__model, cfg)
  39. def set_num_classes(self, num_classes):
  40. in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
  41. self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
  42. in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
  43. hidden_layer = 256
  44. self.__model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer,
  45. num_classes=num_classes)
  46. def load_weight(self, pt_path):
  47. state_dict = torch.load(pt_path)
  48. self.__model.load_state_dict(state_dict)
  49. def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
  50. self.__model.load_state_dict(state_dict)
  51. # return super().load_state_dict(state_dict, strict)
  52. def predict(self, src, show_box=True, show_mask=True):
  53. self.__model.eval()
  54. img = read_image(src)
  55. img = self.transforms(img)
  56. img = img.to(self.device)
  57. result = self.__model([img])
  58. print(f'result:{result}')
  59. masks = result[0]['masks']
  60. boxes = result[0]['boxes']
  61. # cv2.imshow('mask',masks[0].cpu().detach().numpy())
  62. boxes = boxes.cpu().detach()
  63. drawn_boxes = draw_bounding_boxes((img * 255).to(torch.uint8), boxes, colors="red", width=5)
  64. print(f'drawn_boxes:{drawn_boxes.shape}')
  65. boxed_img = drawn_boxes.permute(1, 2, 0).numpy()
  66. # boxed_img=cv2.resize(boxed_img,(800,800))
  67. # cv2.imshow('boxes',boxed_img)
  68. mask = masks[0].cpu().detach().permute(1, 2, 0).numpy()
  69. mask = cv2.resize(mask, (800, 800))
  70. # cv2.imshow('mask',mask)
  71. img = img.cpu().detach().permute(1, 2, 0).numpy()
  72. masked_img = self.overlay_masks_on_image(boxed_img, masks)
  73. masked_img = cv2.resize(masked_img, (800, 800))
  74. cv2.imshow('img_masks', masked_img)
  75. # show_img_boxes_masks(img, boxes, masks)
  76. cv2.waitKey(0)
  77. def generate_colors(self, n):
  78. """
  79. 生成n个均匀分布在HSV色彩空间中的颜色,并转换成BGR色彩空间。
  80. :param n: 需要的颜色数量
  81. :return: 一个包含n个颜色的列表,每个颜色为BGR格式的元组
  82. """
  83. hsv_colors = [(i / n * 180, 1 / 3 * 255, 2 / 3 * 255) for i in range(n)]
  84. bgr_colors = [tuple(map(int, cv2.cvtColor(np.uint8([[hsv]]), cv2.COLOR_HSV2BGR)[0][0])) for hsv in hsv_colors]
  85. return bgr_colors
  86. def overlay_masks_on_image(self, image, masks, alpha=0.6):
  87. """
  88. 在原图上叠加多个掩码,每个掩码使用不同的颜色。
  89. :param image: 原图 (NumPy 数组)
  90. :param masks: 掩码列表 (每个都是 NumPy 数组,二值图像)
  91. :param colors: 颜色列表 (每个颜色都是 (B, G, R) 格式的元组)
  92. :param alpha: 掩码的透明度 (0.0 到 1.0)
  93. :return: 叠加了多个掩码的图像
  94. """
  95. colors = self.generate_colors(len(masks))
  96. if len(masks) != len(colors):
  97. raise ValueError("The number of masks and colors must be the same.")
  98. # 复制原图,避免修改原始图像
  99. overlay = image.copy()
  100. for mask, color in zip(masks, colors):
  101. # 确保掩码是二值图像
  102. mask = mask.cpu().detach().permute(1, 2, 0).numpy()
  103. binary_mask = (mask > 0).astype(np.uint8) * 255 # 你可以根据实际情况调整阈值
  104. # 创建彩色掩码
  105. colored_mask = np.zeros_like(image)
  106. colored_mask[:] = color
  107. colored_mask = cv2.bitwise_and(colored_mask, colored_mask, mask=binary_mask)
  108. # 将彩色掩码与当前的叠加图像混合
  109. overlay = cv2.addWeighted(overlay, 1 - alpha, colored_mask, alpha, 0)
  110. return overlay
  111. if __name__ == '__main__':
  112. # ins_model = MaskRCNNModel(num_classes=5)
  113. ins_model = MaskRCNNModel()
  114. # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
  115. # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
  116. ins_model.train(cfg='train.yaml')