# base_dataset.py (1.2 KB)

  1. from abc import ABC, abstractmethod
  2. import torch
  3. from torch import nn, Tensor
  4. from torch.utils.data import Dataset
  5. from torch.utils.data.dataset import T_co
  6. from torchvision.transforms import functional as F
  7. class BaseDataset(Dataset, ABC):
  8. def __init__(self,dataset_path):
  9. self.default_transform=DefaultTransform()
  10. pass
  11. def __getitem__(self, index) -> T_co:
  12. pass
  13. @abstractmethod
  14. def read_target(self,item,lbl_path,extra=None):
  15. pass
  16. """显示数据集指定图片"""
  17. @abstractmethod
  18. def show(self,idx):
  19. pass
  20. """
  21. 显示数据集指定名字的图片
  22. """
  23. @abstractmethod
  24. def show_img(self,img_path):
  25. pass
  26. class DefaultTransform(nn.Module):
  27. def forward(self, img: Tensor) -> Tensor:
  28. if not isinstance(img, Tensor):
  29. img = F.pil_to_tensor(img)
  30. return F.convert_image_dtype(img, torch.float)
  31. def __repr__(self) -> str:
  32. return self.__class__.__name__ + "()"
  33. def describe(self) -> str:
  34. return (
  35. "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
  36. "The images are rescaled to ``[0.0, 1.0]``."
  37. )