123456789101112131415161718192021222324252627282930313233343536373839404142434445464748 |
- from abc import ABC, abstractmethod
- import torch
- from torch import nn, Tensor
- from torch.utils.data import Dataset
- from torch.utils.data.dataset import T_co
- from torchvision.transforms import functional as F
class BaseDataset(Dataset, ABC):
    """Abstract base class for this project's image datasets.

    Combines ``torch.utils.data.Dataset`` with ``abc.ABC``: a subclass cannot
    be instantiated until it implements :meth:`read_target`, :meth:`show`
    and :meth:`show_img`. It is also expected to override ``__getitem__``.
    """

    def __init__(self, dataset_path):
        """Set up state shared by all datasets.

        Args:
            dataset_path: Location of the dataset. This base class accepts it
                but neither stores nor reads it; subclasses are responsible
                for using it.
        """
        # Shared transform that converts a PIL image or tensor into a
        # ``torch.float`` tensor (see ``DefaultTransform.forward``).
        self.default_transform = DefaultTransform()

    def __getitem__(self, index) -> "T_co":
        """Return the sample at ``index``; subclasses must override this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Fail loudly instead of silently handing ``None`` back to callers
        # such as a DataLoader, where the missing override would be hard to
        # trace. The annotation is a string so the class body does not
        # evaluate ``T_co`` at definition time.
        raise NotImplementedError(
            "Subclasses of BaseDataset should implement __getitem__."
        )

    @abstractmethod
    def read_target(self, item, lbl_path, extra=None):
        """Read the target (label) data for one sample.

        Args:
            item: Sample identifier/record; its exact form is defined by the
                subclass. NOTE(review): confirm against implementations.
            lbl_path: Presumably the path of the label file to read.
            extra: Optional subclass-specific extra information.
        """

    @abstractmethod
    def show(self, idx):
        """Display the dataset image at position ``idx``."""

    @abstractmethod
    def show_img(self, img_path):
        """Display the dataset image identified by name/path ``img_path``."""
class DefaultTransform(nn.Module):
    """Turn an image into a ``torch.float`` tensor.

    PIL images are first converted to tensors with ``pil_to_tensor``. The
    tensor is then passed through ``convert_image_dtype``, which rescales
    integer images into [0.0, 1.0].
    """

    def forward(self, img: Tensor) -> Tensor:
        """Return ``img`` converted to a ``torch.float`` image tensor."""
        if isinstance(img, Tensor):
            as_tensor = img
        else:
            # Any non-tensor input is handed to pil_to_tensor (PIL images).
            as_tensor = F.pil_to_tensor(img)
        return F.convert_image_dtype(as_tensor, torch.float)

    def __repr__(self) -> str:
        """Return ``ClassName()``, since the transform has no parameters."""
        return f"{type(self).__name__}()"

    def describe(self) -> str:
        """Return a human-readable summary of accepted inputs and output range."""
        accepted_inputs = (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
        )
        output_range = "The images are rescaled to ``[0.0, 1.0]``."
        return accepted_inputs + output_range
|