123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- import os
- import PIL
- import cv2
- import numpy as np
- import torch
- from matplotlib import pyplot as plt
- from torch.utils.data import Dataset
- from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
- from models.dataset_tool import masks_to_boxes, read_masks_from_txt, read_masks_from_pixels
class MaskRCNNDataset(Dataset):
    """Instance-segmentation dataset yielding ``(img, target)`` pairs.

    Directory layout, as built in ``__init__``::

        <dataset_path>/images[/<dataset_type>]/<name>.<ext>
        <dataset_path>/labels[/<dataset_type>]/<name>.txt

    Each target is a dict with the keys ``boxes``, ``labels``, ``masks``,
    ``image_id``, ``area`` and ``iscrowd``.
    """

    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='polygon'):
        """
        Args:
            dataset_path: Root directory containing ``images/`` and ``labels/``.
            transforms: Optional callable ``(img, target) -> (img, target)``. When
                falsy, only the image is passed through ``deafult_transform``.
            dataset_type: Optional split sub-directory name. When None, files are
                read directly from ``images/`` and ``labels/``.
            target_type: ``'polygon'`` (labels read by ``read_masks_from_txt``) or
                ``'pixel'`` (labels read by ``read_masks_from_pixels``).
        """
        self.data_path = dataset_path
        self.transforms = transforms
        img_dir = os.path.join(dataset_path, "images")
        lbl_dir = os.path.join(dataset_path, "labels")
        # Append the split folder only when one is given, so the default
        # dataset_type=None works instead of failing on str + None.
        if dataset_type is not None:
            img_dir = os.path.join(img_dir, dataset_type)
            lbl_dir = os.path.join(lbl_dir, dataset_type)
        self.img_path = img_dir
        self.lbl_path = lbl_dir
        # os.listdir order is arbitrary; sort so index -> file (and image_id) is reproducible.
        self.imgs = sorted(os.listdir(self.img_path))
        self.lbls = sorted(os.listdir(self.lbl_path))
        self.target_type = target_type
        # NOTE: attribute name is misspelled ("deafult"); kept unchanged because it is
        # part of the instance's public state and may be referenced elsewhere.
        self.deafult_transform = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()

    def __getitem__(self, item):
        """Return ``(img, target)`` for the ``item``-th image (sorted file order)."""
        img_name = self.imgs[item]
        img_path = os.path.join(self.img_path, img_name)
        lbl_path = self._label_path_for(img_name)
        # The context manager releases the file handle; convert() returns a converted copy.
        with PIL.Image.open(img_path) as raw:
            img = raw.convert('RGB')
        w, h = img.size  # PIL reports (width, height)
        target = self.read_target(item=item, lbl_path=lbl_path, shape=(h, w))
        if self.transforms:
            img, target = self.transforms(img, target)
        else:
            # Only the image goes through the preset transform; the target is unchanged.
            img = self.deafult_transform(img)
        return img, target

    def _label_path_for(self, img_name):
        """Map an image file name to its label file: same stem, ``.txt`` extension, under ``lbl_path``."""
        # splitext handles any extension length (".jpeg", ".tiff") and names with extra dots.
        stem, _ = os.path.splitext(img_name)
        return os.path.join(self.lbl_path, stem + '.txt')

    def create_masks_from_polygons(self, polygons, image_shape):
        """Rasterize each polygon into its own binary mask.

        Args:
            polygons: Iterable of 2-item pairs ``(_, vertices)``. The first item is
                ignored (presumably a class label; confirm against callers).
                ``vertices`` must be reshapeable to ``(-1, 2)`` points and are
                truncated to int32.
            image_shape: Shape whose first two entries are ``(height, width)``.

        Returns:
            List of uint8 ``torch.Tensor`` masks of shape ``image_shape[:2]``, with 1
            inside the polygon and 0 elsewhere.
        """
        masks = []
        for _, vertices in polygons:
            mask = np.zeros(image_shape[:2], dtype=np.uint8)
            pts = np.array(vertices, np.int32).reshape((-1, 1, 2))
            # The masks are binary, so fill with 1 directly; no per-polygon colour is needed.
            cv2.fillPoly(mask, [pts], 1)
            masks.append(torch.from_numpy(mask))
        return masks

    def read_target(self, item, lbl_path, shape):
        """Load the annotations in ``lbl_path`` and build the target dict.

        Args:
            item: Dataset index, stored as ``image_id``.
            lbl_path: Path of the label file.
            shape: ``(height, width)`` of the image.

        Raises:
            ValueError: if ``self.target_type`` is neither ``'polygon'`` nor ``'pixel'``.
        """
        if self.target_type == 'polygon':
            labels, masks = read_masks_from_txt(lbl_path, shape)
        elif self.target_type == 'pixel':
            labels, masks = read_masks_from_pixels(lbl_path, shape)
        else:
            # Fail loudly rather than silently building a target from nothing.
            raise ValueError(
                f"unknown target_type {self.target_type!r}; expected 'polygon' or 'pixel'"
            )
        return self._build_target(item, labels, masks, shape)

    @staticmethod
    def _build_target(item, labels, masks, shape):
        """Assemble the target dict from per-instance ``labels`` and ``masks`` sequences.

        An image with no instances yields correctly-shaped empty tensors instead of
        failing in ``torch.stack``, which rejects an empty sequence.
        """
        h, w = shape
        if len(masks) == 0:
            # NOTE(review): uint8 assumes read_masks_* produce uint8 masks; confirm.
            masks_t = torch.zeros((0, h, w), dtype=torch.uint8)
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels_t = torch.zeros((0,), dtype=torch.int64)
        else:
            masks_t = torch.stack(masks)  # stacked once, reused for boxes and masks
            boxes = masks_to_boxes(masks_t)
            labels_t = torch.stack(labels)
        # Box area; assumes boxes are (N, 4) in (x1, y1, x2, y2) order, as torchvision
        # detection models require. A zero area would break COCO area-range metrics.
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        return {
            "boxes": boxes,
            "labels": labels_t,
            "masks": masks_t,
            "image_id": torch.tensor(item),
            "area": area,
            "iscrowd": torch.zeros(masks_t.shape[0], dtype=torch.int64),
        }

    def heatmap_enhance(self, img):
        """Histogram-equalize ``img`` and render it with OpenCV's HOT colormap.

        Args:
            img: 8-bit single-channel image, as required by ``cv2.equalizeHist``.

        Returns:
            The pseudo-coloured 3-channel (BGR) image produced by ``cv2.applyColorMap``.
        """
        equalized = cv2.equalizeHist(img)
        # Alternative: adaptive equalization via cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).
        return cv2.applyColorMap(equalized, cv2.COLORMAP_HOT)

    def __len__(self):
        """Number of image files found in ``img_path``."""
        return len(self.imgs)
|