- import math
- import numbers
- import random
- import warnings
- from collections.abc import Sequence
- from typing import List, Optional, Tuple, Union
- import torch
- from torch import Tensor
- try:
- import accimage
- except ImportError:
- accimage = None
- from ..utils import _log_api_usage_once
- from . import functional as F
- from .functional import _interpolation_modes_from_int, InterpolationMode
- __all__ = [
- "Compose",
- "ToTensor",
- "PILToTensor",
- "ConvertImageDtype",
- "ToPILImage",
- "Normalize",
- "Resize",
- "CenterCrop",
- "Pad",
- "Lambda",
- "RandomApply",
- "RandomChoice",
- "RandomOrder",
- "RandomCrop",
- "RandomHorizontalFlip",
- "RandomVerticalFlip",
- "RandomResizedCrop",
- "FiveCrop",
- "TenCrop",
- "LinearTransformation",
- "ColorJitter",
- "RandomRotation",
- "RandomAffine",
- "Grayscale",
- "RandomGrayscale",
- "RandomPerspective",
- "RandomErasing",
- "GaussianBlur",
- "InterpolationMode",
- "RandomInvert",
- "RandomPosterize",
- "RandomSolarize",
- "RandomAdjustSharpness",
- "RandomAutocontrast",
- "RandomEqualize",
- "ElasticTransform",
- ]
- class Compose:
- """Composes several transforms together. This transform does not support torchscript.
- Please see the note below.
- Args:
- transforms (list of ``Transform`` objects): list of transforms to compose.
- Example:
- >>> transforms.Compose([
- >>> transforms.CenterCrop(10),
- >>> transforms.PILToTensor(),
- >>> transforms.ConvertImageDtype(torch.float),
- >>> ])
- .. note::
- In order to script the transformations, please use ``torch.nn.Sequential`` as below.
- >>> transforms = torch.nn.Sequential(
- >>> transforms.CenterCrop(10),
- >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
- >>> )
- >>> scripted_transforms = torch.jit.script(transforms)
- Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not require
- `lambda` functions or ``PIL.Image``.
- """
- def __init__(self, transforms):
- if not torch.jit.is_scripting() and not torch.jit.is_tracing():
- _log_api_usage_once(self)
- self.transforms = transforms
- def __call__(self, img):
- for t in self.transforms:
- img = t(img)
- return img
- def __repr__(self) -> str:
- format_string = self.__class__.__name__ + "("
- for t in self.transforms:
- format_string += "\n"
- format_string += f" {t}"
- format_string += "\n)"
- return format_string
- class ToTensor:
- """Convert a PIL Image or ndarray to tensor and scale the values accordingly.
- This transform does not support torchscript.
- Converts a PIL Image or numpy.ndarray (H x W x C) in the range
- [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
- if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
- or if the numpy.ndarray has dtype = np.uint8
- In the other cases, tensors are returned without scaling.
- .. note::
- Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
- transforming target image masks. See the `references`_ for implementing the transforms for image masks.
- .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
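- Example (an illustrative sketch; the array shape and dtype below are arbitrary):
- >>> import numpy as np
- >>> img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)  # H x W x C, uint8
- >>> tensor = ToTensor()(img)  # float tensor of shape (3, 32, 32) with values in [0.0, 1.0]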
- """
- def __init__(self) -> None:
- _log_api_usage_once(self)
- def __call__(self, pic):
- """
- Args:
- pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
- Returns:
- Tensor: Converted image.
- """
- return F.to_tensor(pic)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}()"
- class PILToTensor:
- """Convert a PIL Image to a tensor of the same type - this does not scale values.
- This transform does not support torchscript.
- Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
- """
- def __init__(self) -> None:
- _log_api_usage_once(self)
- def __call__(self, pic):
- """
- .. note::
- A deep copy of the underlying array is performed.
- Args:
- pic (PIL Image): Image to be converted to tensor.
- Returns:
- Tensor: Converted image.
- """
- return F.pil_to_tensor(pic)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}()"
- class ConvertImageDtype(torch.nn.Module):
- """Convert a tensor image to the given ``dtype`` and scale the values accordingly.
- This function does not support PIL Image.
- Args:
- dtype (torch.dtype): Desired data type of the output
- .. note::
- When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
- If converted back and forth, this mismatch has no effect.
- Raises:
- RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
- well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
- overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
- of the integer ``dtype``.
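- Example (an illustrative sketch; the tensor shape below is arbitrary):
- >>> import torch
- >>> img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
- >>> out = ConvertImageDtype(torch.float32)(img)  # values rescaled from [0, 255] to [0.0, 1.0]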
- """
- def __init__(self, dtype: torch.dtype) -> None:
- super().__init__()
- _log_api_usage_once(self)
- self.dtype = dtype
- def forward(self, image):
- return F.convert_image_dtype(image, self.dtype)
- class ToPILImage:
- """Convert a tensor or an ndarray to PIL Image
- This transform does not support torchscript.
- Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
- H x W x C to a PIL Image while adjusting the value range depending on the ``mode``.
- Args:
- mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
- If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
- - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
- - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
- - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
- - If the input has 1 channel, the ``mode`` is determined by the data type (i.e. ``int``, ``float``, ``short``).
- .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
- """
- def __init__(self, mode=None):
- _log_api_usage_once(self)
- self.mode = mode
- def __call__(self, pic):
- """
- Args:
- pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
- Returns:
- PIL Image: Image converted to PIL Image.
- """
- return F.to_pil_image(pic, self.mode)
- def __repr__(self) -> str:
- format_string = self.__class__.__name__ + "("
- if self.mode is not None:
- format_string += f"mode={self.mode}"
- format_string += ")"
- return format_string
- class Normalize(torch.nn.Module):
- """Normalize a tensor image with mean and standard deviation.
- This transform does not support PIL Image.
- Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],...,std[n])`` for ``n``
- channels, this transform will normalize each channel of the input
- ``torch.*Tensor`` i.e.,
- ``output[channel] = (input[channel] - mean[channel]) / std[channel]``
- .. note::
- This transform acts out of place, i.e., it does not mutate the input tensor.
- Args:
- mean (sequence): Sequence of means for each channel.
- std (sequence): Sequence of standard deviations for each channel.
- inplace (bool, optional): Bool to make this operation in-place.
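- Example (an illustrative sketch; the mean/std values and tensor shape are arbitrary):
- >>> import torch
- >>> normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- >>> out = normalize(torch.rand(3, 224, 224))  # each channel shifted and scaled, out of place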
- """
- def __init__(self, mean, std, inplace=False):
- super().__init__()
- _log_api_usage_once(self)
- self.mean = mean
- self.std = std
- self.inplace = inplace
- def forward(self, tensor: Tensor) -> Tensor:
- """
- Args:
- tensor (Tensor): Tensor image to be normalized.
- Returns:
- Tensor: Normalized Tensor image.
- """
- return F.normalize(tensor, self.mean, self.std, self.inplace)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
- class Resize(torch.nn.Module):
- """Resize the input image to the given size.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means a maximum of two leading dimensions
- Args:
- size (sequence or int): Desired output size. If size is a sequence like
- (h, w), output size will be matched to this. If size is an int,
- smaller edge of the image will be matched to this number.
- i.e, if height > width, then image will be rescaled to
- (size * height / width, size).
- .. note::
- In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
- interpolation (InterpolationMode): Desired interpolation enum defined by
- :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
- If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
- ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
- The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
- max_size (int, optional): The maximum allowed for the longer edge of
- the resized image. If the longer edge of the image is greater
- than ``max_size`` after being resized according to ``size``,
- ``size`` will be overruled so that the longer edge is equal to
- ``max_size``.
- As a result, the smaller edge may be shorter than ``size``. This
- is only supported if ``size`` is an int (or a sequence of length
- 1 in torchscript mode).
- antialias (bool, optional): Whether to apply antialiasing.
- It only affects **tensors** with bilinear or bicubic modes and it is
- ignored otherwise: on PIL images, antialiasing is always applied on
- bilinear or bicubic modes; on other modes (for PIL images and
- tensors), antialiasing makes no sense and this parameter is ignored.
- Possible values are:
- - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
- Other modes aren't affected. This is probably what you want to use.
- - ``False``: will not apply antialiasing for tensors on any mode. PIL
- images are still antialiased on bilinear or bicubic modes, because
- PIL does not support disabling antialiasing.
- - ``None``: equivalent to ``False`` for tensors and ``True`` for
- PIL images. This value exists for legacy reasons and you probably
- don't want to use it unless you really know what you are doing.
- The default value changed from ``None`` to ``True`` in
- v0.17, for the PIL and Tensor backends to be consistent.
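- Example (an illustrative sketch; the sizes below are arbitrary):
- >>> import torch
- >>> img = torch.rand(3, 300, 400)
- >>> by_short_edge = Resize(256)(img)  # shorter edge becomes 256, aspect ratio preserved
- >>> exact = Resize((224, 224))(img)   # output is exactly (h, w) = (224, 224)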
- """
- def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=True):
- super().__init__()
- _log_api_usage_once(self)
- if not isinstance(size, (int, Sequence)):
- raise TypeError(f"Size should be int or sequence. Got {type(size)}")
- if isinstance(size, Sequence) and len(size) not in (1, 2):
- raise ValueError("If size is a sequence, it should have 1 or 2 values")
- self.size = size
- self.max_size = max_size
- if isinstance(interpolation, int):
- interpolation = _interpolation_modes_from_int(interpolation)
- self.interpolation = interpolation
- self.antialias = antialias
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be scaled.
- Returns:
- PIL Image or Tensor: Rescaled image.
- """
- return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
- def __repr__(self) -> str:
- detail = f"(size={self.size}, interpolation={self.interpolation.value}, max_size={self.max_size}, antialias={self.antialias})"
- return f"{self.__class__.__name__}{detail}"
- class CenterCrop(torch.nn.Module):
- """Crops the given image at the center.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
- If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
- Args:
- size (sequence or int): Desired output size of the crop. If size is an
- int instead of sequence like (h, w), a square crop (size, size) is
- made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
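- Example (an illustrative sketch; the sizes below are arbitrary):
- >>> import torch
- >>> out = CenterCrop(224)(torch.rand(3, 300, 400))  # central (3, 224, 224) patch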
- """
- def __init__(self, size):
- super().__init__()
- _log_api_usage_once(self)
- self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be cropped.
- Returns:
- PIL Image or Tensor: Cropped image.
- """
- return F.center_crop(img, self.size)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(size={self.size})"
- class Pad(torch.nn.Module):
- """Pad the given image on all sides with the given "pad" value.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
- at most 3 leading dimensions for mode edge,
- and an arbitrary number of leading dimensions for mode constant
- Args:
- padding (int or sequence): Padding on each border. If a single int is provided this
- is used to pad all borders. If sequence of length 2 is provided this is the padding
- on left/right and top/bottom respectively. If a sequence of length 4 is provided
- this is the padding for the left, top, right and bottom borders respectively.
- .. note::
- In torchscript mode padding as single int is not supported, use a sequence of
- length 1: ``[padding, ]``.
- fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
- length 3, it is used to fill R, G, B channels respectively.
- This value is only used when the padding_mode is constant.
- Only number is supported for torch Tensor.
- Only int or tuple value is supported for PIL Image.
- padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
- Default is constant.
- - constant: pads with a constant value, this value is specified with fill
- - edge: pads with the last value at the edge of the image.
- If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2
- - reflect: pads with reflection of image without repeating the last value on the edge.
- For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
- will result in [3, 2, 1, 2, 3, 4, 3, 2]
- - symmetric: pads with reflection of image repeating the last value on the edge.
- For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
- will result in [2, 1, 1, 2, 3, 4, 4, 3]
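- Example (an illustrative sketch; the padding values below are arbitrary):
- >>> import torch
- >>> img = torch.rand(3, 28, 28)
- >>> out = Pad(2)(img)                               # constant zero padding, output (3, 32, 32)
- >>> out = Pad((4, 2), padding_mode="reflect")(img)  # 4 px left/right, 2 px top/bottom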
- """
- def __init__(self, padding, fill=0, padding_mode="constant"):
- super().__init__()
- _log_api_usage_once(self)
- if not isinstance(padding, (numbers.Number, tuple, list)):
- raise TypeError("Got inappropriate padding arg")
- if not isinstance(fill, (numbers.Number, tuple, list)):
- raise TypeError("Got inappropriate fill arg")
- if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
- raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
- if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
- raise ValueError(
- f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
- )
- self.padding = padding
- self.fill = fill
- self.padding_mode = padding_mode
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be padded.
- Returns:
- PIL Image or Tensor: Padded image.
- """
- return F.pad(img, self.padding, self.fill, self.padding_mode)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})"
- class Lambda:
- """Apply a user-defined lambda as a transform. This transform does not support torchscript.
- Args:
- lambd (function): Lambda/function to be used for transform.
- """
- def __init__(self, lambd):
- _log_api_usage_once(self)
- if not callable(lambd):
- raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")
- self.lambd = lambd
- def __call__(self, img):
- return self.lambd(img)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}()"
- class RandomTransforms:
- """Base class for a list of transformations with randomness
- Args:
- transforms (sequence): list of transformations
- """
- def __init__(self, transforms):
- _log_api_usage_once(self)
- if not isinstance(transforms, Sequence):
- raise TypeError("Argument transforms should be a sequence")
- self.transforms = transforms
- def __call__(self, *args, **kwargs):
- raise NotImplementedError()
- def __repr__(self) -> str:
- format_string = self.__class__.__name__ + "("
- for t in self.transforms:
- format_string += "\n"
- format_string += f" {t}"
- format_string += "\n)"
- return format_string
- class RandomApply(torch.nn.Module):
- """Apply randomly a list of transformations with a given probability.
- .. note::
- In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
- transforms as shown below:
- >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
- >>> transforms.ColorJitter(),
- >>> ]), p=0.3)
- >>> scripted_transforms = torch.jit.script(transforms)
- Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not require
- `lambda` functions or ``PIL.Image``.
- Args:
- transforms (sequence or torch.nn.Module): list of transformations
- p (float): probability
- """
- def __init__(self, transforms, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.transforms = transforms
- self.p = p
- def forward(self, img):
- if self.p < torch.rand(1):
- return img
- for t in self.transforms:
- img = t(img)
- return img
- def __repr__(self) -> str:
- format_string = self.__class__.__name__ + "("
- format_string += f"\n p={self.p}"
- for t in self.transforms:
- format_string += "\n"
- format_string += f" {t}"
- format_string += "\n)"
- return format_string
- class RandomOrder(RandomTransforms):
- """Apply a list of transformations in a random order. This transform does not support torchscript."""
- def __call__(self, img):
- order = list(range(len(self.transforms)))
- random.shuffle(order)
- for i in order:
- img = self.transforms[i](img)
- return img
- class RandomChoice(RandomTransforms):
- """Apply single transformation randomly picked from a list. This transform does not support torchscript."""
- def __init__(self, transforms, p=None):
- super().__init__(transforms)
- if p is not None and not isinstance(p, Sequence):
- raise TypeError("Argument p should be a sequence")
- self.p = p
- def __call__(self, *args):
- t = random.choices(self.transforms, weights=self.p)[0]
- return t(*args)
- def __repr__(self) -> str:
- return f"{super().__repr__()}(p={self.p})"
- class RandomCrop(torch.nn.Module):
- """Crop the given image at a random location.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions,
- but if non-constant padding is used, the input is expected to have at most 2 leading dimensions
- Args:
- size (sequence or int): Desired output size of the crop. If size is an
- int instead of sequence like (h, w), a square crop (size, size) is
- made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
- padding (int or sequence, optional): Optional padding on each border
- of the image. Default is None. If a single int is provided this
- is used to pad all borders. If sequence of length 2 is provided this is the padding
- on left/right and top/bottom respectively. If a sequence of length 4 is provided
- this is the padding for the left, top, right and bottom borders respectively.
- .. note::
- In torchscript mode padding as single int is not supported, use a sequence of
- length 1: ``[padding, ]``.
- pad_if_needed (boolean): It will pad the image if smaller than the
- desired size to avoid raising an exception. Since cropping is done
- after padding, the padding seems to be done at a random offset.
- fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
- length 3, it is used to fill R, G, B channels respectively.
- This value is only used when the padding_mode is constant.
- Only number is supported for torch Tensor.
- Only int or tuple value is supported for PIL Image.
- padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
- Default is constant.
- - constant: pads with a constant value, this value is specified with fill
- - edge: pads with the last value at the edge of the image.
- If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2
- - reflect: pads with reflection of image without repeating the last value on the edge.
- For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
- will result in [3, 2, 1, 2, 3, 4, 3, 2]
- - symmetric: pads with reflection of image repeating the last value on the edge.
- For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
- will result in [2, 1, 1, 2, 3, 4, 4, 3]
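- Example (an illustrative sketch; the common CIFAR-style setting below is just one choice):
- >>> import torch
- >>> aug = RandomCrop(32, padding=4)
- >>> out = aug(torch.rand(3, 32, 32))  # padded to 40x40, then a random 32x32 window is cropped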
- """
- @staticmethod
- def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
- """Get parameters for ``crop`` for a random crop.
- Args:
- img (PIL Image or Tensor): Image to be cropped.
- output_size (tuple): Expected output size of the crop.
- Returns:
- tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
- """
- _, h, w = F.get_dimensions(img)
- th, tw = output_size
- if h < th or w < tw:
- raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
- if w == tw and h == th:
- return 0, 0, h, w
- i = torch.randint(0, h - th + 1, size=(1,)).item()
- j = torch.randint(0, w - tw + 1, size=(1,)).item()
- return i, j, th, tw
- def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
- super().__init__()
- _log_api_usage_once(self)
- self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
- self.padding = padding
- self.pad_if_needed = pad_if_needed
- self.fill = fill
- self.padding_mode = padding_mode
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be cropped.
- Returns:
- PIL Image or Tensor: Cropped image.
- """
- if self.padding is not None:
- img = F.pad(img, self.padding, self.fill, self.padding_mode)
- _, height, width = F.get_dimensions(img)
- # pad the width if needed
- if self.pad_if_needed and width < self.size[1]:
- padding = [self.size[1] - width, 0]
- img = F.pad(img, padding, self.fill, self.padding_mode)
- # pad the height if needed
- if self.pad_if_needed and height < self.size[0]:
- padding = [0, self.size[0] - height]
- img = F.pad(img, padding, self.fill, self.padding_mode)
- i, j, h, w = self.get_params(img, self.size)
- return F.crop(img, i, j, h, w)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(size={self.size}, padding={self.padding})"
- class RandomHorizontalFlip(torch.nn.Module):
- """Horizontally flip the given image randomly with a given probability.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading
- dimensions
- Args:
- p (float): probability of the image being flipped. Default value is 0.5
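- Example (an illustrative sketch; the tensor shape is arbitrary):
- >>> import torch
- >>> out = RandomHorizontalFlip(p=0.5)(torch.rand(3, 224, 224))  # flipped with probability 0.5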
- """
- def __init__(self, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be flipped.
- Returns:
- PIL Image or Tensor: Randomly flipped image.
- """
- if torch.rand(1) < self.p:
- return F.hflip(img)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
- class RandomVerticalFlip(torch.nn.Module):
- """Vertically flip the given image randomly with a given probability.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading
- dimensions
- Args:
- p (float): probability of the image being flipped. Default value is 0.5
- """
- def __init__(self, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be flipped.
- Returns:
- PIL Image or Tensor: Randomly flipped image.
- """
- if torch.rand(1) < self.p:
- return F.vflip(img)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
- class RandomPerspective(torch.nn.Module):
- """Performs a random perspective transformation of the given image with a given probability.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
- Args:
- distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
- Default is 0.5.
- p (float): probability of the image being transformed. Default is 0.5.
- interpolation (InterpolationMode): Desired interpolation enum defined by
- :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
- If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
- The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
- fill (sequence or number): Pixel fill value for the area outside the transformed
- image. Default is ``0``. If given a number, that value is used for all bands.
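- Example (an illustrative sketch; the parameter values below are arbitrary):
- >>> import torch
- >>> aug = RandomPerspective(distortion_scale=0.6, p=1.0)
- >>> out = aug(torch.rand(3, 224, 224))  # always transformed here because p=1.0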
- """
- def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
- super().__init__()
- _log_api_usage_once(self)
- self.p = p
- if isinstance(interpolation, int):
- interpolation = _interpolation_modes_from_int(interpolation)
- self.interpolation = interpolation
- self.distortion_scale = distortion_scale
- if fill is None:
- fill = 0
- elif not isinstance(fill, (Sequence, numbers.Number)):
- raise TypeError("Fill should be either a sequence or a number.")
- self.fill = fill
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be Perspectively transformed.
- Returns:
- PIL Image or Tensor: Randomly transformed image.
- """
- fill = self.fill
- channels, height, width = F.get_dimensions(img)
- if isinstance(img, Tensor):
- if isinstance(fill, (int, float)):
- fill = [float(fill)] * channels
- else:
- fill = [float(f) for f in fill]
- if torch.rand(1) < self.p:
- startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
- return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
- return img
- @staticmethod
- def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
- """Get parameters for ``perspective`` for a random perspective transform.
- Args:
- width (int): width of the image.
- height (int): height of the image.
- distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
- Returns:
- List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
- List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
- """
- half_height = height // 2
- half_width = width // 2
- topleft = [
- int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
- int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
- ]
- topright = [
- int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
- int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
- ]
- botright = [
- int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
- int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
- ]
- botleft = [
- int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
- int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
- ]
- startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
- endpoints = [topleft, topright, botright, botleft]
- return startpoints, endpoints
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
- class RandomResizedCrop(torch.nn.Module):
- """Crop a random portion of image and resize it to a given size.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
- A crop of the original image is made: the crop has a random area (H * W)
- and a random aspect ratio. This crop is finally resized to the given
- size. This is popularly used to train the Inception networks.
- Args:
- size (int or sequence): expected output size of the crop, for each edge. If size is an
- int instead of sequence like (h, w), a square output size ``(size, size)`` is
- made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
- .. note::
- In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
- scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
- before resizing. The scale is defined with respect to the area of the original image.
- ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
- resizing.
- interpolation (InterpolationMode): Desired interpolation enum defined by
- :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
- If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
- ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
- The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
- antialias (bool, optional): Whether to apply antialiasing.
- It only affects **tensors** with bilinear or bicubic modes and it is
- ignored otherwise: on PIL images, antialiasing is always applied on
- bilinear or bicubic modes; on other modes (for PIL images and
- tensors), antialiasing makes no sense and this parameter is ignored.
- Possible values are:
- - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
- Other modes aren't affected. This is probably what you want to use.
- - ``False``: will not apply antialiasing for tensors on any mode. PIL
- images are still antialiased on bilinear or bicubic modes, because
- PIL does not support disabling antialiasing.
- - ``None``: equivalent to ``False`` for tensors and ``True`` for
- PIL images. This value exists for legacy reasons and you probably
- don't want to use it unless you really know what you are doing.
- The default value changed from ``None`` to ``True`` in
- v0.17, for the PIL and Tensor backends to be consistent.
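- Example (an illustrative sketch; the size and ranges shown are the common defaults):
- >>> import torch
- >>> aug = RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0))
- >>> out = aug(torch.rand(3, 500, 600))  # output is always (3, 224, 224)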
- """
- def __init__(
- self,
- size,
- scale=(0.08, 1.0),
- ratio=(3.0 / 4.0, 4.0 / 3.0),
- interpolation=InterpolationMode.BILINEAR,
- antialias: Optional[bool] = True,
- ):
- super().__init__()
- _log_api_usage_once(self)
- self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
- if not isinstance(scale, Sequence):
- raise TypeError("Scale should be a sequence")
- if not isinstance(ratio, Sequence):
- raise TypeError("Ratio should be a sequence")
- if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
- warnings.warn("Scale and ratio should be of kind (min, max)")
- if isinstance(interpolation, int):
- interpolation = _interpolation_modes_from_int(interpolation)
- self.interpolation = interpolation
- self.antialias = antialias
- self.scale = scale
- self.ratio = ratio
- @staticmethod
- def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
- """Get parameters for ``crop`` for a random sized crop.
- Args:
- img (PIL Image or Tensor): Input image.
- scale (list): range of the crop area, as a fraction of the original image area
- ratio (list): range of aspect ratios of the crop, before resizing
- Returns:
- tuple: params (i, j, h, w) to be passed to ``crop`` for a random
- sized crop.
- """
- _, height, width = F.get_dimensions(img)
- area = height * width
- log_ratio = torch.log(torch.tensor(ratio))
- for _ in range(10):
- target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
- aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
- w = int(round(math.sqrt(target_area * aspect_ratio)))
- h = int(round(math.sqrt(target_area / aspect_ratio)))
- if 0 < w <= width and 0 < h <= height:
- i = torch.randint(0, height - h + 1, size=(1,)).item()
- j = torch.randint(0, width - w + 1, size=(1,)).item()
- return i, j, h, w
- # Fallback to central crop
- in_ratio = float(width) / float(height)
- if in_ratio < min(ratio):
- w = width
- h = int(round(w / min(ratio)))
- elif in_ratio > max(ratio):
- h = height
- w = int(round(h * max(ratio)))
- else: # whole image
- w = width
- h = height
- i = (height - h) // 2
- j = (width - w) // 2
- return i, j, h, w
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be cropped and resized.
- Returns:
- PIL Image or Tensor: Randomly cropped and resized image.
- """
- i, j, h, w = self.get_params(img, self.scale, self.ratio)
- return F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias)
- def __repr__(self) -> str:
- interpolate_str = self.interpolation.value
- format_string = self.__class__.__name__ + f"(size={self.size}"
- format_string += f", scale={tuple(round(s, 4) for s in self.scale)}"
- format_string += f", ratio={tuple(round(r, 4) for r in self.ratio)}"
- format_string += f", interpolation={interpolate_str}"
- format_string += f", antialias={self.antialias})"
- return format_string
- class FiveCrop(torch.nn.Module):
- """Crop the given image into four corners and the central crop.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading
- dimensions
- .. Note::
- This transform returns a tuple of images and there may be a mismatch in the number of
- inputs and targets your Dataset returns. See below for an example of how to deal with
- this.
- Args:
- size (sequence or int): Desired output size of the crop. If size is an ``int``
- instead of sequence like (h, w), a square crop of size (size, size) is made.
- If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
- Example:
- >>> transform = Compose([
- >>> FiveCrop(size), # this is a list of PIL Images
- >>> Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
- >>> ])
- >>> #In your test loop you can do the following:
- >>> input, target = batch # input is a 5d tensor, target is 2d
- >>> bs, ncrops, c, h, w = input.size()
- >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
- >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
- """
- def __init__(self, size):
- super().__init__()
- _log_api_usage_once(self)
- self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be cropped.
- Returns:
- tuple of 5 images. Image can be PIL Image or Tensor
- """
- return F.five_crop(img, self.size)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(size={self.size})"
- class TenCrop(torch.nn.Module):
- """Crop the given image into four corners and the central crop plus the flipped version of
- these (horizontal flipping is used by default).
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading
- dimensions
- .. Note::
- This transform returns a tuple of images and there may be a mismatch in the number of
- inputs and targets your Dataset returns. See below for an example of how to deal with
- this.
- Args:
- size (sequence or int): Desired output size of the crop. If size is an
- int instead of sequence like (h, w), a square crop (size, size) is
- made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
- vertical_flip (bool): Use vertical flipping instead of horizontal
- Example:
- >>> transform = Compose([
- >>> TenCrop(size), # this is a tuple of PIL Images
- >>> Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
- >>> ])
- >>> #In your test loop you can do the following:
- >>> input, target = batch # input is a 5d tensor, target is 2d
- >>> bs, ncrops, c, h, w = input.size()
- >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
- >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
- """
- def __init__(self, size, vertical_flip=False):
- super().__init__()
- _log_api_usage_once(self)
- self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
- self.vertical_flip = vertical_flip
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be cropped.
- Returns:
- tuple of 10 images. Image can be PIL Image or Tensor
- """
- return F.ten_crop(img, self.size, self.vertical_flip)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(size={self.size}, vertical_flip={self.vertical_flip})"
- class LinearTransformation(torch.nn.Module):
- """Transform a tensor image with a square transformation matrix and a mean_vector computed
- offline.
- This transform does not support PIL Image.
- Given transformation_matrix and mean_vector, this transform flattens the torch.*Tensor,
- subtracts mean_vector from it, computes the dot product with the transformation
- matrix, and reshapes the tensor to its original shape.
- Applications:
- whitening transformation: Suppose X is a matrix of zero-centered data with one
- flattened sample per row. Then compute the data covariance matrix [D x D] with
- torch.mm(X.t(), X), perform SVD on this matrix, and pass it as transformation_matrix.
- Args:
- transformation_matrix (Tensor): tensor [D x D], D = C x H x W
- mean_vector (Tensor): tensor [D], D = C x H x W
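- Example (an illustrative sketch; the identity matrix and zero mean below stand in for a
- real whitening matrix and mean vector computed offline):
- >>> import torch
- >>> d = 3 * 8 * 8  # D = C x H x W
- >>> t = LinearTransformation(torch.eye(d), torch.zeros(d))
- >>> out = t(torch.rand(3, 8, 8))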
- """
- def __init__(self, transformation_matrix, mean_vector):
- super().__init__()
- _log_api_usage_once(self)
- if transformation_matrix.size(0) != transformation_matrix.size(1):
- raise ValueError(
- "transformation_matrix should be square. Got "
- f"{tuple(transformation_matrix.size())} rectangular matrix."
- )
- if mean_vector.size(0) != transformation_matrix.size(0):
- raise ValueError(
- f"mean_vector should have the same length {mean_vector.size(0)}"
- f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
- )
- if transformation_matrix.device != mean_vector.device:
- raise ValueError(
- f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
- )
- if transformation_matrix.dtype != mean_vector.dtype:
- raise ValueError(
- f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
- )
- self.transformation_matrix = transformation_matrix
- self.mean_vector = mean_vector
- def forward(self, tensor: Tensor) -> Tensor:
- """
- Args:
- tensor (Tensor): Tensor image to be whitened.
- Returns:
- Tensor: Transformed image.
- """
- shape = tensor.shape
- n = shape[-3] * shape[-2] * shape[-1]
- if n != self.transformation_matrix.shape[0]:
- raise ValueError(
- "Input tensor and transformation matrix have incompatible shape."
- + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
- + f"{self.transformation_matrix.shape[0]}"
- )
- if tensor.device.type != self.mean_vector.device.type:
- raise ValueError(
- "Input tensor should be on the same device as transformation matrix and mean vector. "
- f"Got {tensor.device} vs {self.mean_vector.device}"
- )
- flat_tensor = tensor.view(-1, n) - self.mean_vector
- transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
- transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
- tensor = transformed_tensor.view(shape)
- return tensor
- def __repr__(self) -> str:
- s = (
- f"{self.__class__.__name__}(transformation_matrix="
- f"{self.transformation_matrix.tolist()}"
- f", mean_vector={self.mean_vector.tolist()})"
- )
- return s
- class ColorJitter(torch.nn.Module):
- """Randomly change the brightness, contrast, saturation and hue of an image.
- If the image is torch Tensor, it is expected
- to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
- If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
- Args:
- brightness (float or tuple of float (min, max)): How much to jitter brightness.
- brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
- or the given [min, max]. Should be non-negative numbers.
- contrast (float or tuple of float (min, max)): How much to jitter contrast.
- contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
- or the given [min, max]. Should be non-negative numbers.
- saturation (float or tuple of float (min, max)): How much to jitter saturation.
- saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
- or the given [min, max]. Should be non-negative numbers.
- hue (float or tuple of float (min, max)): How much to jitter hue.
- hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
- Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
- To jitter hue, the pixel values of the input image have to be non-negative for conversion to HSV space;
- thus it does not work if you normalize your image to an interval with negative values,
- or use an interpolation that generates negative values before using this function.
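- Example (an illustrative sketch; the jitter strengths below are arbitrary):
- >>> import torch
- >>> aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
- >>> out = aug(torch.rand(3, 224, 224))  # factors are re-sampled on every call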
- """
- def __init__(
- self,
- brightness: Union[float, Tuple[float, float]] = 0,
- contrast: Union[float, Tuple[float, float]] = 0,
- saturation: Union[float, Tuple[float, float]] = 0,
- hue: Union[float, Tuple[float, float]] = 0,
- ) -> None:
- super().__init__()
- _log_api_usage_once(self)
- self.brightness = self._check_input(brightness, "brightness")
- self.contrast = self._check_input(contrast, "contrast")
- self.saturation = self._check_input(saturation, "saturation")
- self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
- @torch.jit.unused
- def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
- if isinstance(value, numbers.Number):
- if value < 0:
- raise ValueError(f"If {name} is a single number, it must be non negative.")
- value = [center - float(value), center + float(value)]
- if clip_first_on_zero:
- value[0] = max(value[0], 0.0)
- elif isinstance(value, (tuple, list)) and len(value) == 2:
- value = [float(value[0]), float(value[1])]
- else:
- raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")
- if not bound[0] <= value[0] <= value[1] <= bound[1]:
- raise ValueError(f"{name} values should be between {bound}, but got {value}.")
- # if value is 0 or (1., 1.) for brightness/contrast/saturation
- # or (0., 0.) for hue, do nothing
- if value[0] == value[1] == center:
- return None
- else:
- return tuple(value)
- @staticmethod
- def get_params(
- brightness: Optional[List[float]],
- contrast: Optional[List[float]],
- saturation: Optional[List[float]],
- hue: Optional[List[float]],
- ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
- """Get the parameters for the randomized transform to be applied on image.
- Args:
- brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
- uniformly. Pass None to turn off the transformation.
- contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
- uniformly. Pass None to turn off the transformation.
- saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
- uniformly. Pass None to turn off the transformation.
- hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
- Pass None to turn off the transformation.
- Returns:
- tuple: The parameters used to apply the randomized transform
- along with their random order.
- """
- fn_idx = torch.randperm(4)
- b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
- c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
- s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
- h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
- return fn_idx, b, c, s, h
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Input image.
- Returns:
- PIL Image or Tensor: Color jittered image.
- """
- fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
- self.brightness, self.contrast, self.saturation, self.hue
- )
- for fn_id in fn_idx:
- if fn_id == 0 and brightness_factor is not None:
- img = F.adjust_brightness(img, brightness_factor)
- elif fn_id == 1 and contrast_factor is not None:
- img = F.adjust_contrast(img, contrast_factor)
- elif fn_id == 2 and saturation_factor is not None:
- img = F.adjust_saturation(img, saturation_factor)
- elif fn_id == 3 and hue_factor is not None:
- img = F.adjust_hue(img, hue_factor)
- return img
- def __repr__(self) -> str:
- s = (
- f"{self.__class__.__name__}("
- f"brightness={self.brightness}"
- f", contrast={self.contrast}"
- f", saturation={self.saturation}"
- f", hue={self.hue})"
- )
- return s
- class RandomRotation(torch.nn.Module):
- """Rotate the image by angle.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
- Args:
- degrees (sequence or number): Range of degrees to select from.
- If degrees is a number instead of sequence like (min, max), the range of degrees
- will be (-degrees, +degrees).
- interpolation (InterpolationMode): Desired interpolation enum defined by
- :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
- If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
- The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
- expand (bool, optional): Optional expansion flag.
- If true, expands the output to make it large enough to hold the entire rotated image.
- If false or omitted, make the output image the same size as the input image.
- Note that the expand flag assumes rotation around the center and no translation.
- center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
- Default is the center of the image.
- fill (sequence or number): Pixel fill value for the area outside the rotated
- image. Default is ``0``. If given a number, that value is used for all bands.
- .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
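- Example (an illustrative sketch; the angle range and fill value are arbitrary):
- >>> import torch
- >>> aug = RandomRotation(degrees=30, interpolation=InterpolationMode.BILINEAR, fill=0)
- >>> out = aug(torch.rand(3, 224, 224))  # rotated by an angle sampled from [-30, 30]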
- """
- def __init__(self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0):
- super().__init__()
- _log_api_usage_once(self)
- if isinstance(interpolation, int):
- interpolation = _interpolation_modes_from_int(interpolation)
- self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
- if center is not None:
- _check_sequence_input(center, "center", req_sizes=(2,))
- self.center = center
- self.interpolation = interpolation
- self.expand = expand
- if fill is None:
- fill = 0
- elif not isinstance(fill, (Sequence, numbers.Number)):
- raise TypeError("Fill should be either a sequence or a number.")
- self.fill = fill
- @staticmethod
- def get_params(degrees: List[float]) -> float:
- """Get parameters for ``rotate`` for a random rotation.
- Returns:
- float: angle parameter to be passed to ``rotate`` for random rotation.
- """
- angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
- return angle
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be rotated.
- Returns:
- PIL Image or Tensor: Rotated image.
- """
- fill = self.fill
- channels, _, _ = F.get_dimensions(img)
- if isinstance(img, Tensor):
- if isinstance(fill, (int, float)):
- fill = [float(fill)] * channels
- else:
- fill = [float(f) for f in fill]
- angle = self.get_params(self.degrees)
- return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill)
- def __repr__(self) -> str:
- interpolate_str = self.interpolation.value
- format_string = self.__class__.__name__ + f"(degrees={self.degrees}"
- format_string += f", interpolation={interpolate_str}"
- format_string += f", expand={self.expand}"
- if self.center is not None:
- format_string += f", center={self.center}"
- if self.fill is not None:
- format_string += f", fill={self.fill}"
- format_string += ")"
- return format_string
- class RandomAffine(torch.nn.Module):
- """Random affine transformation of the image keeping center invariant.
- If the image is torch Tensor, it is expected
- to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
- Args:
- degrees (sequence or number): Range of degrees to select from.
- If degrees is a number instead of sequence like (min, max), the range of degrees
- will be (-degrees, +degrees). Set to 0 to deactivate rotations.
- translate (tuple, optional): tuple of maximum absolute fraction for horizontal
- and vertical translations. For example translate=(a, b), then horizontal shift
- is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
- randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
- scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
- randomly sampled from the range a <= scale <= b. Will keep original scale by default.
- shear (sequence or number, optional): Range of degrees to select from.
- If shear is a number, a shear parallel to the x-axis in the range (-shear, +shear)
- will be applied. Else if shear is a sequence of 2 values a shear parallel to the x-axis in the
- range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
- an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
- Will not apply shear by default.
- interpolation (InterpolationMode): Desired interpolation enum defined by
- :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
- If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
- The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR``, are accepted as well.
- fill (sequence or number): Pixel fill value for the area outside the transformed
- image. Default is ``0``. If given a number, the value is used for all bands.
- center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
- Default is the center of the image.
- .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
- """
- def __init__(
- self,
- degrees,
- translate=None,
- scale=None,
- shear=None,
- interpolation=InterpolationMode.NEAREST,
- fill=0,
- center=None,
- ):
- super().__init__()
- _log_api_usage_once(self)
- if isinstance(interpolation, int):
- interpolation = _interpolation_modes_from_int(interpolation)
- self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
- if translate is not None:
- _check_sequence_input(translate, "translate", req_sizes=(2,))
- for t in translate:
- if not (0.0 <= t <= 1.0):
- raise ValueError("translation values should be between 0 and 1")
- self.translate = translate
- if scale is not None:
- _check_sequence_input(scale, "scale", req_sizes=(2,))
- for s in scale:
- if s <= 0:
- raise ValueError("scale values should be positive")
- self.scale = scale
- if shear is not None:
- self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
- else:
- self.shear = shear
- self.interpolation = interpolation
- if fill is None:
- fill = 0
- elif not isinstance(fill, (Sequence, numbers.Number)):
- raise TypeError("Fill should be either a sequence or a number.")
- self.fill = fill
- if center is not None:
- _check_sequence_input(center, "center", req_sizes=(2,))
- self.center = center
- @staticmethod
- def get_params(
- degrees: List[float],
- translate: Optional[List[float]],
- scale_ranges: Optional[List[float]],
- shears: Optional[List[float]],
- img_size: List[int],
- ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
- """Get parameters for affine transformation
- Returns:
- params to be passed to the affine transformation
- """
- angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
- if translate is not None:
- max_dx = float(translate[0] * img_size[0])
- max_dy = float(translate[1] * img_size[1])
- tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
- ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
- translations = (tx, ty)
- else:
- translations = (0, 0)
- if scale_ranges is not None:
- scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
- else:
- scale = 1.0
- shear_x = shear_y = 0.0
- if shears is not None:
- shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
- if len(shears) == 4:
- shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())
- shear = (shear_x, shear_y)
- return angle, translations, scale, shear
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be transformed.
- Returns:
- PIL Image or Tensor: Affine transformed image.
- """
- fill = self.fill
- channels, height, width = F.get_dimensions(img)
- if isinstance(img, Tensor):
- if isinstance(fill, (int, float)):
- fill = [float(fill)] * channels
- else:
- fill = [float(f) for f in fill]
- img_size = [width, height] # flip for keeping BC on get_params call
- ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
- return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)
- def __repr__(self) -> str:
- s = f"{self.__class__.__name__}(degrees={self.degrees}"
- s += f", translate={self.translate}" if self.translate is not None else ""
- s += f", scale={self.scale}" if self.scale is not None else ""
- s += f", shear={self.shear}" if self.shear is not None else ""
- s += f", interpolation={self.interpolation.value}" if self.interpolation != InterpolationMode.NEAREST else ""
- s += f", fill={self.fill}" if self.fill != 0 else ""
- s += f", center={self.center}" if self.center is not None else ""
- s += ")"
- return s
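- # Usage sketch (illustrative only; the parameter values are arbitrary examples):
- #     >>> affiner = RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10)
- #     >>> out = affiner(img)  # a fresh angle/translation/scale/shear is sampled per call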
- class Grayscale(torch.nn.Module):
- """Convert image to grayscale.
- If the image is torch Tensor, it is expected
- to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
- Args:
- num_output_channels (int): number of channels desired for the output image (1 or 3)
- Returns:
- PIL Image: Grayscale version of the input.
- - If ``num_output_channels == 1`` : returned image is single channel
- - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b
- """
- def __init__(self, num_output_channels=1):
- super().__init__()
- _log_api_usage_once(self)
- self.num_output_channels = num_output_channels
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be converted to grayscale.
- Returns:
- PIL Image or Tensor: Grayscaled image.
- """
- return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(num_output_channels={self.num_output_channels})"
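- # Usage sketch (illustrative only; assumes a 3-channel input image):
- #     >>> gray = Grayscale(num_output_channels=3)(img)  # output has r == g == b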
- class RandomGrayscale(torch.nn.Module):
- """Randomly convert image to grayscale with a probability of p (default 0.1).
- If the image is torch Tensor, it is expected
- to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
- Args:
- p (float): probability that image should be converted to grayscale.
- Returns:
- PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
- with probability (1-p).
- - If input image is 1 channel: grayscale version is 1 channel
- - If input image is 3 channel: grayscale version is 3 channel with r == g == b
- """
- def __init__(self, p=0.1):
- super().__init__()
- _log_api_usage_once(self)
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be converted to grayscale.
- Returns:
- PIL Image or Tensor: Randomly grayscaled image.
- """
- num_output_channels, _, _ = F.get_dimensions(img)
- if torch.rand(1) < self.p:
- return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
- class RandomErasing(torch.nn.Module):
- """Randomly selects a rectangle region in a torch.Tensor image and erases its pixels.
- This transform does not support PIL Image.
- 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896
- Args:
- p: probability that the random erasing operation will be performed.
- scale: range of proportion of erased area against input image.
- ratio: range of aspect ratio of erased area.
- value: erasing value. Default is 0. If a single int, it is used to
- erase all pixels. If a tuple of length 3, it is used to erase
- R, G, B channels respectively.
- If the str ``'random'``, each pixel is erased with random values.
- inplace: boolean to make this transform inplace. Default set to False.
- Returns:
- Erased Image.
- Example:
- >>> transform = transforms.Compose([
- >>> transforms.RandomHorizontalFlip(),
- >>> transforms.PILToTensor(),
- >>> transforms.ConvertImageDtype(torch.float),
- >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
- >>> transforms.RandomErasing(),
- >>> ])
- """
- def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
- super().__init__()
- _log_api_usage_once(self)
- if not isinstance(value, (numbers.Number, str, tuple, list)):
- raise TypeError("Argument value should be either a number or str or a sequence")
- if isinstance(value, str) and value != "random":
- raise ValueError("If value is str, it should be 'random'")
- if not isinstance(scale, (tuple, list)):
- raise TypeError("Scale should be a sequence")
- if not isinstance(ratio, (tuple, list)):
- raise TypeError("Ratio should be a sequence")
- if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
- warnings.warn("Scale and ratio should be of kind (min, max)")
- if scale[0] < 0 or scale[1] > 1:
- raise ValueError("Scale should be between 0 and 1")
- if p < 0 or p > 1:
- raise ValueError("Random erasing probability should be between 0 and 1")
- self.p = p
- self.scale = scale
- self.ratio = ratio
- self.value = value
- self.inplace = inplace
- @staticmethod
- def get_params(
- img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
- ) -> Tuple[int, int, int, int, Tensor]:
- """Get parameters for ``erase`` for a random erasing.
- Args:
- img (Tensor): Tensor image to be erased.
- scale (sequence): range of proportion of erased area against input image.
- ratio (sequence): range of aspect ratio of erased area.
- value (list, optional): erasing value. If None, it is interpreted as "random"
- (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
- i.e. ``value[0]``.
- Returns:
- tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
- """
- img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
- area = img_h * img_w
- log_ratio = torch.log(torch.tensor(ratio))
- for _ in range(10):
- erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
- aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
- h = int(round(math.sqrt(erase_area * aspect_ratio)))
- w = int(round(math.sqrt(erase_area / aspect_ratio)))
- if not (h < img_h and w < img_w):
- continue
- if value is None:
- v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
- else:
- v = torch.tensor(value)[:, None, None]
- i = torch.randint(0, img_h - h + 1, size=(1,)).item()
- j = torch.randint(0, img_w - w + 1, size=(1,)).item()
- return i, j, h, w, v
- # Return original image
- return 0, 0, img_h, img_w, img
- def forward(self, img):
- """
- Args:
- img (Tensor): Tensor image to be erased.
- Returns:
- img (Tensor): Erased Tensor image.
- """
- if torch.rand(1) < self.p:
- # cast self.value to script acceptable type
- if isinstance(self.value, (int, float)):
- value = [float(self.value)]
- elif isinstance(self.value, str):
- value = None
- elif isinstance(self.value, (list, tuple)):
- value = [float(v) for v in self.value]
- else:
- value = self.value
- if value is not None and not (len(value) in (1, img.shape[-3])):
- raise ValueError(
- "If value is a sequence, it should have either a single value or "
- f"{img.shape[-3]} (number of input channels)"
- )
- x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
- return F.erase(img, x, y, h, w, v, self.inplace)
- return img
- def __repr__(self) -> str:
- s = (
- f"{self.__class__.__name__}"
- f"(p={self.p}, "
- f"scale={self.scale}, "
- f"ratio={self.ratio}, "
- f"value={self.value}, "
- f"inplace={self.inplace})"
- )
- return s
- class GaussianBlur(torch.nn.Module):
- """Blurs image with randomly chosen Gaussian blur.
- If the image is torch Tensor, it is expected
- to have [..., C, H, W] shape, where ... means at most one leading dimension.
- Args:
- kernel_size (int or sequence): Size of the Gaussian kernel.
- sigma (float or tuple of float (min, max)): Standard deviation to be used for
- creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
- of float (min, max), sigma is chosen uniformly at random to lie in the
- given range.
- Returns:
- PIL Image or Tensor: Gaussian blurred version of the input image.
- """
- def __init__(self, kernel_size, sigma=(0.1, 2.0)):
- super().__init__()
- _log_api_usage_once(self)
- self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
- for ks in self.kernel_size:
- if ks <= 0 or ks % 2 == 0:
- raise ValueError("Kernel size value should be an odd and positive number.")
- if isinstance(sigma, numbers.Number):
- if sigma <= 0:
- raise ValueError("If sigma is a single number, it must be positive.")
- sigma = (sigma, sigma)
- elif isinstance(sigma, Sequence) and len(sigma) == 2:
- if not 0.0 < sigma[0] <= sigma[1]:
- raise ValueError("sigma values should be positive and of the form (min, max).")
- else:
- raise ValueError("sigma should be a single number or a list/tuple with length 2.")
- self.sigma = sigma
- @staticmethod
- def get_params(sigma_min: float, sigma_max: float) -> float:
- """Choose sigma for random gaussian blurring.
- Args:
- sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
- sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.
- Returns:
- float: Standard deviation to be passed to calculate kernel for gaussian blurring.
- """
- return torch.empty(1).uniform_(sigma_min, sigma_max).item()
- def forward(self, img: Tensor) -> Tensor:
- """
- Args:
- img (PIL Image or Tensor): image to be blurred.
- Returns:
- PIL Image or Tensor: Gaussian blurred image
- """
- sigma = self.get_params(self.sigma[0], self.sigma[1])
- return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])
- def __repr__(self) -> str:
- s = f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})"
- return s
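- # Usage sketch (illustrative only; kernel sizes must be positive odd integers):
- #     >>> blurrer = GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.0))
- #     >>> blurred = blurrer(img)  # sigma is drawn uniformly from [0.1, 5.0] on each call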
- def _setup_size(size, error_msg):
- if isinstance(size, numbers.Number):
- return int(size), int(size)
- if isinstance(size, Sequence) and len(size) == 1:
- return size[0], size[0]
- if len(size) != 2:
- raise ValueError(error_msg)
- return size
- def _check_sequence_input(x, name, req_sizes):
- msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
- if not isinstance(x, Sequence):
- raise TypeError(f"{name} should be a sequence of length {msg}.")
- if len(x) not in req_sizes:
- raise ValueError(f"{name} should be a sequence of length {msg}.")
- def _setup_angle(x, name, req_sizes=(2,)):
- if isinstance(x, numbers.Number):
- if x < 0:
- raise ValueError(f"If {name} is a single number, it must be positive.")
- x = [-x, x]
- else:
- _check_sequence_input(x, name, req_sizes)
- return [float(d) for d in x]
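- # Behaviour sketch for the helpers above (illustrative values only):
- #     >>> _setup_size(224, "size should be an int or a pair")  # -> (224, 224)
- #     >>> _setup_angle(30, name="degrees")                     # -> [-30.0, 30.0]
- #     >>> _setup_angle((10, 20), name="shear", req_sizes=(2, 4))  # -> [10.0, 20.0]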
- class RandomInvert(torch.nn.Module):
- """Inverts the colors of the given image randomly with a given probability.
- If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
- where ... means it can have an arbitrary number of leading dimensions.
- If img is PIL Image, it is expected to be in mode "L" or "RGB".
- Args:
- p (float): probability of the image being color inverted. Default value is 0.5
- """
- def __init__(self, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be inverted.
- Returns:
- PIL Image or Tensor: Randomly color inverted image.
- """
- if torch.rand(1).item() < self.p:
- return F.invert(img)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
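- # Usage sketch (illustrative only; the input is returned unchanged with probability 1 - p):
- #     >>> inverter = RandomInvert(p=0.5)
- #     >>> maybe_inverted = inverter(img)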
- class RandomPosterize(torch.nn.Module):
- """Posterize the image randomly with a given probability by reducing the
- number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8,
- and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
- If img is PIL Image, it is expected to be in mode "L" or "RGB".
- Args:
- bits (int): number of bits to keep for each channel (0-8)
- p (float): probability of the image being posterized. Default value is 0.5
- """
- def __init__(self, bits, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.bits = bits
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be posterized.
- Returns:
- PIL Image or Tensor: Randomly posterized image.
- """
- if torch.rand(1).item() < self.p:
- return F.posterize(img, self.bits)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(bits={self.bits},p={self.p})"
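- # Usage sketch (illustrative only; tensor inputs must be of dtype torch.uint8):
- #     >>> posterizer = RandomPosterize(bits=2, p=0.5)
- #     >>> out = posterizer(img)  # with probability p, only the 2 most significant bits per channel are kept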
- class RandomSolarize(torch.nn.Module):
- """Solarize the image randomly with a given probability by inverting all pixel
- values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
- where ... means it can have an arbitrary number of leading dimensions.
- If img is PIL Image, it is expected to be in mode "L" or "RGB".
- Args:
- threshold (float): all pixels with values equal to or above this threshold are inverted.
- p (float): probability of the image being solarized. Default value is 0.5
- """
- def __init__(self, threshold, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.threshold = threshold
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be solarized.
- Returns:
- PIL Image or Tensor: Randomly solarized image.
- """
- if torch.rand(1).item() < self.p:
- return F.solarize(img, self.threshold)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(threshold={self.threshold},p={self.p})"
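- # Usage sketch (illustrative only; the threshold lives in the input's value range, e.g. 0-255 for uint8):
- #     >>> solarizer = RandomSolarize(threshold=192.0, p=0.5)
- #     >>> out = solarizer(img)  # with probability p, pixels >= 192 are inverted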
- class RandomAdjustSharpness(torch.nn.Module):
- """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor,
- it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
- Args:
- sharpness_factor (float): How much to adjust the sharpness. Can be
- any non-negative number. 0 gives a blurred image, 1 gives the
- original image while 2 increases the sharpness by a factor of 2.
- p (float): probability of the image being sharpened. Default value is 0.5
- """
- def __init__(self, sharpness_factor, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.sharpness_factor = sharpness_factor
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be sharpened.
- Returns:
- PIL Image or Tensor: Randomly sharpened image.
- """
- if torch.rand(1).item() < self.p:
- return F.adjust_sharpness(img, self.sharpness_factor)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(sharpness_factor={self.sharpness_factor},p={self.p})"
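- # Usage sketch (illustrative only; factor 0 blurs, 1 is a no-op, 2 doubles the sharpness):
- #     >>> sharpener = RandomAdjustSharpness(sharpness_factor=2, p=0.5)
- #     >>> out = sharpener(img)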
- class RandomAutocontrast(torch.nn.Module):
- """Autocontrast the pixels of the given image randomly with a given probability.
- If the image is torch Tensor, it is expected
- to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
- If img is PIL Image, it is expected to be in mode "L" or "RGB".
- Args:
- p (float): probability of the image being autocontrasted. Default value is 0.5
- """
- def __init__(self, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be autocontrasted.
- Returns:
- PIL Image or Tensor: Randomly autocontrasted image.
- """
- if torch.rand(1).item() < self.p:
- return F.autocontrast(img)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
- class RandomEqualize(torch.nn.Module):
- """Equalize the histogram of the given image randomly with a given probability.
- If the image is torch Tensor, it is expected
- to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
- If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
- Args:
- p (float): probability of the image being equalized. Default value is 0.5
- """
- def __init__(self, p=0.5):
- super().__init__()
- _log_api_usage_once(self)
- self.p = p
- def forward(self, img):
- """
- Args:
- img (PIL Image or Tensor): Image to be equalized.
- Returns:
- PIL Image or Tensor: Randomly equalized image.
- """
- if torch.rand(1).item() < self.p:
- return F.equalize(img)
- return img
- def __repr__(self) -> str:
- return f"{self.__class__.__name__}(p={self.p})"
- class ElasticTransform(torch.nn.Module):
- """Transform a tensor image with elastic transformations.
- Given alpha and sigma, it will generate displacement
- vectors for all pixels based on random offsets. Alpha controls the strength
- and sigma controls the smoothness of the displacements.
- The displacements are added to an identity grid and the resulting grid is
- used to grid_sample from the image.
- Applications:
- Randomly transforms the morphology of objects in images and produces a
- see-through-water-like effect.
- Args:
- alpha (float or sequence of floats): Magnitude of displacements. Default is 50.0.
- sigma (float or sequence of floats): Smoothness of displacements. Default is 5.0.
- interpolation (InterpolationMode): Desired interpolation enum defined by
- :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
- If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
- The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR``, are accepted as well.
- fill (sequence or number): Pixel fill value for the area outside the transformed
- image. Default is ``0``. If given a number, the value is used for all bands.
- """
- def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINEAR, fill=0):
- super().__init__()
- _log_api_usage_once(self)
- if not isinstance(alpha, (float, Sequence)):
- raise TypeError(f"alpha should be float or a sequence of floats. Got {type(alpha)}")
- if isinstance(alpha, Sequence) and len(alpha) != 2:
- raise ValueError(f"If alpha is a sequence its length should be 2. Got {len(alpha)}")
- if isinstance(alpha, Sequence):
- for element in alpha:
- if not isinstance(element, float):
- raise TypeError(f"alpha should be a sequence of floats. Got {type(element)}")
- if isinstance(alpha, float):
- alpha = [float(alpha), float(alpha)]
- if isinstance(alpha, (list, tuple)) and len(alpha) == 1:
- alpha = [alpha[0], alpha[0]]
- self.alpha = alpha
- if not isinstance(sigma, (float, Sequence)):
- raise TypeError(f"sigma should be float or a sequence of floats. Got {type(sigma)}")
- if isinstance(sigma, Sequence) and len(sigma) != 2:
- raise ValueError(f"If sigma is a sequence its length should be 2. Got {len(sigma)}")
- if isinstance(sigma, Sequence):
- for element in sigma:
- if not isinstance(element, float):
- raise TypeError(f"sigma should be a sequence of floats. Got {type(element)}")
- if isinstance(sigma, float):
- sigma = [float(sigma), float(sigma)]
- if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
- sigma = [sigma[0], sigma[0]]
- self.sigma = sigma
- if isinstance(interpolation, int):
- interpolation = _interpolation_modes_from_int(interpolation)
- self.interpolation = interpolation
- if isinstance(fill, (int, float)):
- fill = [float(fill)]
- elif isinstance(fill, (list, tuple)):
- fill = [float(f) for f in fill]
- else:
- raise TypeError(f"fill should be int or float or a list or tuple of them. Got {type(fill)}")
- self.fill = fill
- @staticmethod
- def get_params(alpha: List[float], sigma: List[float], size: List[int]) -> Tensor:
- dx = torch.rand([1, 1] + size) * 2 - 1
- if sigma[0] > 0.0:
- kx = int(8 * sigma[0] + 1)
- # if kernel size is even we have to make it odd
- if kx % 2 == 0:
- kx += 1
- dx = F.gaussian_blur(dx, [kx, kx], sigma)
- dx = dx * alpha[0] / size[0]
- dy = torch.rand([1, 1] + size) * 2 - 1
- if sigma[1] > 0.0:
- ky = int(8 * sigma[1] + 1)
- # if kernel size is even we have to make it odd
- if ky % 2 == 0:
- ky += 1
- dy = F.gaussian_blur(dy, [ky, ky], sigma)
- dy = dy * alpha[1] / size[1]
- return torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2
- def forward(self, tensor: Tensor) -> Tensor:
- """
- Args:
- tensor (PIL Image or Tensor): Image to be transformed.
- Returns:
- PIL Image or Tensor: Transformed image.
- """
- _, height, width = F.get_dimensions(tensor)
- displacement = self.get_params(self.alpha, self.sigma, [height, width])
- return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)
- def __repr__(self):
- format_string = self.__class__.__name__
- format_string += f"(alpha={self.alpha}"
- format_string += f", sigma={self.sigma}"
- format_string += f", interpolation={self.interpolation}"
- format_string += f", fill={self.fill})"
- return format_string
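- # Usage sketch (illustrative only; larger alpha means stronger displacements, larger sigma means smoother ones):
- #     >>> elastic = ElasticTransform(alpha=250.0, sigma=5.0)
- #     >>> out = elastic(img)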