maskrcnn_dataset.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import os
  2. import PIL
  3. import cv2
  4. import numpy as np
  5. import torch
  6. from matplotlib import pyplot as plt
  7. from torch.utils.data import Dataset
  8. from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
  9. from models.dataset_tool import masks_to_boxes, read_masks_from_txt, read_masks_from_pixels
  10. class MaskRCNNDataset(Dataset):
  11. def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='polygon'):
  12. self.data_path = dataset_path
  13. self.transforms = transforms
  14. self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
  15. self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
  16. self.imgs = os.listdir(self.img_path)
  17. self.lbls = os.listdir(self.lbl_path)
  18. self.target_type = target_type
  19. self.deafult_transform= MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
  20. # print('maskrcnn inited!')
  21. def __getitem__(self, item):
  22. # print('__getitem__')
  23. img_path = os.path.join(self.img_path, self.imgs[item])
  24. lbl_path = os.path.join(self.lbl_path, self.imgs[item][:-3] + 'txt')
  25. img = PIL.Image.open(img_path).convert('RGB')
  26. # h, w = np.array(img).shape[:2]
  27. w, h = img.size
  28. # print(f'h,w:{h, w}')
  29. target = self.read_target(item=item, lbl_path=lbl_path, shape=(h, w))
  30. if self.transforms:
  31. img, target = self.transforms(img,target)
  32. else:
  33. img=self.deafult_transform(img)
  34. # print(f'img:{img.shape},target:{target}')
  35. return img, target
  36. def create_masks_from_polygons(self, polygons, image_shape):
  37. """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
  38. colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
  39. masks = []
  40. for polygon_data, col in zip(polygons, colors):
  41. mask = np.zeros(image_shape[:2], dtype=np.uint8)
  42. # 将多边形顶点转换为 NumPy 数组
  43. _, polygon = polygon_data
  44. pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
  45. # 使用 OpenCV 的 fillPoly 函数填充多边形
  46. # print(f'color:{col[:3]}')
  47. cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
  48. mask = torch.from_numpy(mask)
  49. mask[mask != 0] = 1
  50. masks.append(mask)
  51. return masks
  52. def read_target(self, item, lbl_path, shape):
  53. # print(f'lbl_path:{lbl_path}')
  54. h, w = shape
  55. labels = []
  56. masks = []
  57. if self.target_type == 'polygon':
  58. labels, masks = read_masks_from_txt(lbl_path, shape)
  59. elif self.target_type == 'pixel':
  60. labels, masks = read_masks_from_pixels(lbl_path, shape)
  61. target = {}
  62. target["boxes"] = masks_to_boxes(torch.stack(masks))
  63. target["labels"] = torch.stack(labels)
  64. target["masks"] = torch.stack(masks)
  65. target["image_id"] = torch.tensor(item)
  66. target["area"] = torch.zeros(len(masks))
  67. target["iscrowd"] = torch.zeros(len(masks))
  68. return target
  69. def heatmap_enhance(self, img):
  70. # 直方图均衡化
  71. img_eq = cv2.equalizeHist(img)
  72. # 自适应直方图均衡化
  73. # clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
  74. # img_clahe = clahe.apply(img)
  75. # 将灰度图转换为热力图
  76. heatmap = cv2.applyColorMap(img_eq, cv2.COLORMAP_HOT)
  77. def __len__(self):
  78. return len(self.imgs)