import os import PIL import cv2 import numpy as np import torch from matplotlib import pyplot as plt from torch.utils.data import Dataset from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights from models.dataset_tool import masks_to_boxes, read_masks_from_txt, read_masks_from_pixels class MaskRCNNDataset(Dataset): def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='polygon'): self.data_path = dataset_path self.transforms = transforms self.img_path = os.path.join(dataset_path, "images/" + dataset_type) self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type) self.imgs = os.listdir(self.img_path) self.lbls = os.listdir(self.lbl_path) self.target_type = target_type self.deafult_transform= MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms() # print('maskrcnn inited!') def __getitem__(self, item): # print('__getitem__') img_path = os.path.join(self.img_path, self.imgs[item]) lbl_path = os.path.join(self.lbl_path, self.imgs[item][:-3] + 'txt') img = PIL.Image.open(img_path).convert('RGB') # h, w = np.array(img).shape[:2] w, h = img.size # print(f'h,w:{h, w}') target = self.read_target(item=item, lbl_path=lbl_path, shape=(h, w)) if self.transforms: img, target = self.transforms(img,target) else: img=self.deafult_transform(img) # print(f'img:{img.shape},target:{target}') return img, target def create_masks_from_polygons(self, polygons, image_shape): """创建一个与图像尺寸相同的掩码,并填充多边形轮廓""" colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))]) masks = [] for polygon_data, col in zip(polygons, colors): mask = np.zeros(image_shape[:2], dtype=np.uint8) # 将多边形顶点转换为 NumPy 数组 _, polygon = polygon_data pts = np.array(polygon, np.int32).reshape((-1, 1, 2)) # 使用 OpenCV 的 fillPoly 函数填充多边形 # print(f'color:{col[:3]}') cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255) mask = torch.from_numpy(mask) mask[mask != 0] = 1 masks.append(mask) return masks def read_target(self, item, lbl_path, shape): # print(f'lbl_path:{lbl_path}') h, w = shape labels = [] masks = [] if self.target_type == 'polygon': labels, masks = read_masks_from_txt(lbl_path, shape) elif self.target_type == 'pixel': labels, masks = read_masks_from_pixels(lbl_path, shape) target = {} target["boxes"] = masks_to_boxes(torch.stack(masks)) target["labels"] = torch.stack(labels) target["masks"] = torch.stack(masks) target["image_id"] = torch.tensor(item) target["area"] = torch.zeros(len(masks)) target["iscrowd"] = torch.zeros(len(masks)) return target def heatmap_enhance(self, img): # 直方图均衡化 img_eq = cv2.equalizeHist(img) # 自适应直方图均衡化 # clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) # img_clahe = clahe.apply(img) # 将灰度图转换为热力图 heatmap = cv2.applyColorMap(img_eq, cv2.COLORMAP_HOT) def __len__(self): return len(self.imgs)