from typing import Dict, List, Optional, Tuple, Union

import torch
import torchvision
from torch import nn, Tensor
from torchvision import ops
from torchvision.transforms import functional as F, InterpolationMode, transforms as T
 
def _flip_coco_person_keypoints(kps, width):
    # COCO person keypoints come in a fixed order (nose, eyes, ears, shoulders,
    # elbows, wrists, hips, knees, ankles); flip_inds swaps each left/right pair
    # so a horizontally flipped image keeps semantically correct keypoints.
    flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
    flipped_data = kps[:, flip_inds]
    flipped_data[..., 0] = width - flipped_data[..., 0]
    # Maintain COCO convention that if visibility == 0, then x, y = 0
    inds = flipped_data[..., 2] == 0
    flipped_data[inds] = 0
    return flipped_data
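
# Worked example (hypothetical numbers): for an image of width 100, a visible
# left-eye keypoint at x == 30 maps to x == 100 - 30 == 70 and is relabelled as
# the right eye via flip_inds; keypoints with visibility == 0 stay all-zero.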
 
class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target
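
# Usage sketch (illustrative only; this particular pipeline is an assumption,
# not a prescribed preset): unlike torchvision.transforms.Compose, this Compose
# threads an (image, target) pair through every transform.
#
#     transforms = Compose([
#         PILToTensor(),
#         RandomHorizontalFlip(p=0.5),
#         ToDtype(torch.float, scale=True),
#     ])
#     image, target = transforms(pil_image, target)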
 
class RandomHorizontalFlip(T.RandomHorizontalFlip):
    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if torch.rand(1) < self.p:
            image = F.hflip(image)
            if target is not None:
                _, _, width = F.get_dimensions(image)
                target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
                if "masks" in target:
                    target["masks"] = target["masks"].flip(-1)
                if "keypoints" in target:
                    keypoints = target["keypoints"]
                    keypoints = _flip_coco_person_keypoints(keypoints, width)
                    target["keypoints"] = keypoints
        return image, target
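
# Box flip math (worked example with hypothetical numbers): with width == 100,
# a box (x1, y1, x2, y2) == (10, 5, 40, 50) becomes
# (100 - 40, 5, 100 - 10, 50) == (60, 5, 90, 50); the x-coordinates are swapped
# so that x1 < x2 still holds after flipping.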
 
class PILToTensor(nn.Module):
    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        image = F.pil_to_tensor(image)
        return image, target
 
class ToDtype(nn.Module):
    def __init__(self, dtype: torch.dtype, scale: bool = False) -> None:
        super().__init__()
        self.dtype = dtype
        self.scale = scale

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if not self.scale:
            return image.to(dtype=self.dtype), target
        image = F.convert_image_dtype(image, self.dtype)
        return image, target
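
# Scaling semantics (illustrative): ToDtype(torch.float, scale=True) goes
# through convert_image_dtype, so a uint8 pixel value of 255 maps to 1.0 and 0
# maps to 0.0; with scale=False values are merely cast, e.g. 255 -> 255.0.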
 
class RandomIoUCrop(nn.Module):
    def __init__(
        self,
        min_scale: float = 0.3,
        max_scale: float = 1.0,
        min_aspect_ratio: float = 0.5,
        max_aspect_ratio: float = 2.0,
        sampler_options: Optional[List[float]] = None,
        trials: int = 40,
    ):
        super().__init__()
        # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        if sampler_options is None:
            sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
        self.options = sampler_options
        self.trials = trials

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if target is None:
            raise ValueError("The targets can't be None for this transform.")

        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
            elif image.ndimension() == 2:
                image = image.unsqueeze(0)

        _, orig_h, orig_w = F.get_dimensions(image)

        while True:
            # sample an option
            idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
            min_jaccard_overlap = self.options[idx]
            if min_jaccard_overlap >= 1.0:  # a value of 1.0 (or larger) encodes the leave-as-is option
                return image, target

            for _ in range(self.trials):
                # sample a crop size within the scale and aspect-ratio limits
                r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
                new_w = int(orig_w * r[0])
                new_h = int(orig_h * r[1])
                aspect_ratio = new_w / new_h
                if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
                    continue

                # check for zero-area crops
                r = torch.rand(2)
                left = int((orig_w - new_w) * r[0])
                top = int((orig_h - new_h) * r[1])
                right = left + new_w
                bottom = top + new_h
                if left == right or top == bottom:
                    continue

                # check for any valid boxes with centers within the crop area
                cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2])
                cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3])
                is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
                if not is_within_crop_area.any():
                    continue

                # check that at least one box satisfies the jaccard-overlap limit
                boxes = target["boxes"][is_within_crop_area]
                ious = torchvision.ops.boxes.box_iou(
                    boxes, torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device)
                )
                if ious.max() < min_jaccard_overlap:
                    continue

                # keep only valid boxes and perform cropping
                target["boxes"] = boxes
                target["labels"] = target["labels"][is_within_crop_area]
                target["boxes"][:, 0::2] -= left
                target["boxes"][:, 1::2] -= top
                target["boxes"][:, 0::2].clamp_(min=0, max=new_w)
                target["boxes"][:, 1::2].clamp_(min=0, max=new_h)
                image = F.crop(image, top, left, new_h, new_w)

                return image, target
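
# Sampler semantics (a summary of the loop above, not extra behaviour): each
# outer iteration draws one IoU threshold from `options`; 1.0 returns the
# input unchanged, 0.0 accepts any geometrically valid crop, and intermediate
# values additionally require at least one candidate box to overlap the crop
# region with IoU >= the threshold. Up to `trials` crops are tried per draw.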
 
class RandomZoomOut(nn.Module):
    def __init__(
        self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
    ):
        super().__init__()
        if fill is None:
            fill = [0.0, 0.0, 0.0]
        self.fill = fill
        self.side_range = side_range
        if side_range[0] < 1.0 or side_range[0] > side_range[1]:
            raise ValueError(f"Invalid canvas side range provided {side_range}.")
        self.p = p

    @torch.jit.unused
    def _get_fill_value(self, is_pil):
        # type: (bool) -> int
        # We fake the type to make it work on JIT
        return tuple(int(x) for x in self.fill) if is_pil else 0

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
            elif image.ndimension() == 2:
                image = image.unsqueeze(0)

        if torch.rand(1) >= self.p:
            return image, target

        _, orig_h, orig_w = F.get_dimensions(image)

        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
        canvas_width = int(orig_w * r)
        canvas_height = int(orig_h * r)

        r = torch.rand(2)
        left = int((canvas_width - orig_w) * r[0])
        top = int((canvas_height - orig_h) * r[1])
        right = canvas_width - (left + orig_w)
        bottom = canvas_height - (top + orig_h)

        if torch.jit.is_scripting():
            fill = 0
        else:
            fill = self._get_fill_value(F._is_pil_image(image))

        image = F.pad(image, [left, top, right, bottom], fill=fill)
        if isinstance(image, torch.Tensor):
            # The tensor path of F.pad accepts only a single number as fill,
            # so we overwrite the padded border with the per-channel colour here.
            v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
            image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[
                ..., :, (left + orig_w) :
            ] = v

        if target is not None:
            target["boxes"][:, 0::2] += left
            target["boxes"][:, 1::2] += top

        return image, target
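
# Worked example (hypothetical numbers): with side_range == (1.0, 4.0) and a
# drawn ratio r == 2.0, a 300x400 (HxW) image is placed at a random offset on a
# 600x800 canvas filled with `fill`; box coordinates are shifted by the same
# (left, top) offset so they keep pointing at the original pixels.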
 
class RandomPhotometricDistort(nn.Module):
    def __init__(
        self,
        contrast: Tuple[float, float] = (0.5, 1.5),
        saturation: Tuple[float, float] = (0.5, 1.5),
        hue: Tuple[float, float] = (-0.05, 0.05),
        brightness: Tuple[float, float] = (0.875, 1.125),
        p: float = 0.5,
    ):
        super().__init__()
        self._brightness = T.ColorJitter(brightness=brightness)
        self._contrast = T.ColorJitter(contrast=contrast)
        self._hue = T.ColorJitter(hue=hue)
        self._saturation = T.ColorJitter(saturation=saturation)
        self.p = p

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
            elif image.ndimension() == 2:
                image = image.unsqueeze(0)

        r = torch.rand(7)

        if r[0] < self.p:
            image = self._brightness(image)

        contrast_before = r[1] < 0.5
        if contrast_before:
            if r[2] < self.p:
                image = self._contrast(image)

        if r[3] < self.p:
            image = self._saturation(image)

        if r[4] < self.p:
            image = self._hue(image)

        if not contrast_before:
            if r[5] < self.p:
                image = self._contrast(image)

        if r[6] < self.p:
            channels, _, _ = F.get_dimensions(image)
            permutation = torch.randperm(channels)

            is_pil = F._is_pil_image(image)
            if is_pil:
                image = F.pil_to_tensor(image)
                image = F.convert_image_dtype(image)
            image = image[..., permutation, :, :]
            if is_pil:
                image = F.to_pil_image(image)

        return image, target
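
# Order note (describing the logic above): brightness is always considered
# first and the channel permutation last; contrast is applied either before or
# after saturation/hue with probability 0.5 each, mirroring the SSD-style
# photometric distortion recipe.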
 
class ScaleJitter(nn.Module):
    """Randomly resizes the image and its bounding boxes within the specified scale range.

    The class implements the Scale Jitter augmentation as described in the paper
    `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.

    Args:
        target_size (tuple of ints): The target size for the transform provided in (height, width) format.
        scale_range (tuple of floats): scaling factor interval, e.g. (a, b), then scale is randomly sampled from the
            range a <= scale <= b.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
    """

    def __init__(
        self,
        target_size: Tuple[int, int],
        scale_range: Tuple[float, float] = (0.1, 2.0),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias=True,
    ):
        super().__init__()
        self.target_size = target_size
        self.scale_range = scale_range
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
            elif image.ndimension() == 2:
                image = image.unsqueeze(0)

        _, orig_height, orig_width = F.get_dimensions(image)

        scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
        # target_size is (height, width), so pair height with height and width
        # with width when computing the fit ratio.
        r = min(self.target_size[0] / orig_height, self.target_size[1] / orig_width) * scale
        new_width = int(orig_width * r)
        new_height = int(orig_height * r)

        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation, antialias=self.antialias)

        if target is not None:
            target["boxes"][:, 0::2] *= new_width / orig_width
            target["boxes"][:, 1::2] *= new_height / orig_height
            if "masks" in target:
                target["masks"] = F.resize(
                    target["masks"],
                    [new_height, new_width],
                    interpolation=InterpolationMode.NEAREST,
                    antialias=self.antialias,
                )

        return image, target
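
# Usage sketch (assumed LSJ-style settings, not a prescribed preset): with
# target_size == (1024, 1024) and scale_range == (0.1, 2.0), every image is
# resized by a factor drawn uniformly from [0.1, 2.0] relative to the size at
# which it would exactly fit the target canvas, i.e. the Large Scale Jitter
# recipe from the Simple Copy-Paste paper.
#
#     scale_jitter = ScaleJitter(target_size=(1024, 1024))
#     image, target = scale_jitter(image, target)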
 
class FixedSizeCrop(nn.Module):
    def __init__(self, size, fill=0, padding_mode="constant"):
        super().__init__()
        size = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
        self.crop_height = size[0]
        self.crop_width = size[1]
        self.fill = fill  # TODO: Fill is currently respected only on PIL. Apply tensor patch.
        self.padding_mode = padding_mode

    def _pad(self, img, target, padding):
        # Taken from the functional_tensor.py pad
        if isinstance(padding, int):
            pad_left = pad_right = pad_top = pad_bottom = padding
        elif len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        else:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]

        padding = [pad_left, pad_top, pad_right, pad_bottom]
        img = F.pad(img, padding, self.fill, self.padding_mode)
        if target is not None:
            target["boxes"][:, 0::2] += pad_left
            target["boxes"][:, 1::2] += pad_top
            if "masks" in target:
                target["masks"] = F.pad(target["masks"], padding, 0, "constant")

        return img, target

    def _crop(self, img, target, top, left, height, width):
        img = F.crop(img, top, left, height, width)
        if target is not None:
            boxes = target["boxes"]
            boxes[:, 0::2] -= left
            boxes[:, 1::2] -= top
            boxes[:, 0::2].clamp_(min=0, max=width)
            boxes[:, 1::2].clamp_(min=0, max=height)

            is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3])

            target["boxes"] = boxes[is_valid]
            target["labels"] = target["labels"][is_valid]
            if "masks" in target:
                target["masks"] = F.crop(target["masks"][is_valid], top, left, height, width)

        return img, target

    def forward(self, img, target=None):
        _, height, width = F.get_dimensions(img)
        new_height = min(height, self.crop_height)
        new_width = min(width, self.crop_width)

        if new_height != height or new_width != width:
            offset_height = max(height - self.crop_height, 0)
            offset_width = max(width - self.crop_width, 0)

            r = torch.rand(1)
            top = int(offset_height * r)
            left = int(offset_width * r)

            img, target = self._crop(img, target, top, left, new_height, new_width)

        pad_bottom = max(self.crop_height - new_height, 0)
        pad_right = max(self.crop_width - new_width, 0)
        if pad_bottom != 0 or pad_right != 0:
            img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])

        return img, target
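
# Behaviour note (summarising the code above): FixedSizeCrop first crops any
# dimension larger than (crop_height, crop_width) at a random offset, dropping
# boxes that end up empty, and then pads only on the bottom/right so that the
# output is always exactly crop_height x crop_width.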
 
class RandomShortestSize(nn.Module):
    def __init__(
        self,
        min_size: Union[List[int], Tuple[int], int],
        max_size: int,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ):
        super().__init__()
        self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
        self.max_size = max_size
        self.interpolation = interpolation

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        _, orig_height, orig_width = F.get_dimensions(image)

        min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
        r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))

        new_width = int(orig_width * r)
        new_height = int(orig_height * r)

        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)

        if target is not None:
            target["boxes"][:, 0::2] *= new_width / orig_width
            target["boxes"][:, 1::2] *= new_height / orig_height
            if "masks" in target:
                target["masks"] = F.resize(
                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
                )

        return image, target
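
# Worked example (hypothetical numbers): for a 480x640 image with
# min_size == 800 and max_size == 1333, r == min(800 / 480, 1333 / 640)
# == min(1.667, 2.083) == 1.667, so the image is resized to roughly 800x1066:
# the shorter side reaches min_size unless that would push the longer side
# past max_size.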
 
def _copy_paste(
    image: torch.Tensor,
    target: Dict[str, Tensor],
    paste_image: torch.Tensor,
    paste_target: Dict[str, Tensor],
    blending: bool = True,
    resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
) -> Tuple[torch.Tensor, Dict[str, Tensor]]:

    # Random paste targets selection:
    num_masks = len(paste_target["masks"])

    if num_masks < 1:
        # Such a degenerate case with num_masks=0 can happen with LSJ
        # Let's just return (image, target)
        return image, target

    # We have to please torch script by explicitly specifying dtype as torch.long
    random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
    random_selection = torch.unique(random_selection).to(torch.long)

    paste_masks = paste_target["masks"][random_selection]
    paste_boxes = paste_target["boxes"][random_selection]
    paste_labels = paste_target["labels"][random_selection]

    masks = target["masks"]

    # We resize source and paste data if they have different sizes
    # This is something we introduced here as originally the algorithm works
    # on equal-sized data (for example, coming from LSJ data augmentations)
    size1 = image.shape[-2:]
    size2 = paste_image.shape[-2:]
    if size1 != size2:
        paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation)
        paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST)
        # resize bboxes:
        ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device)
        paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape)

    paste_alpha_mask = paste_masks.sum(dim=0) > 0

    if blending:
        paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=(5, 5), sigma=[2.0])

    # Copy-paste images:
    image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)

    # Copy-paste masks:
    masks = masks * (~paste_alpha_mask)
    non_all_zero_masks = masks.sum((-1, -2)) > 0
    masks = masks[non_all_zero_masks]

    # Do a shallow copy of the target dict
    out_target = {k: v for k, v in target.items()}

    out_target["masks"] = torch.cat([masks, paste_masks])

    # Copy-paste boxes and labels
    boxes = ops.masks_to_boxes(masks)
    out_target["boxes"] = torch.cat([boxes, paste_boxes])

    labels = target["labels"][non_all_zero_masks]
    out_target["labels"] = torch.cat([labels, paste_labels])

    # Update additional optional keys: area and iscrowd if they exist
    if "area" in target:
        out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)

    if "iscrowd" in target and "iscrowd" in paste_target:
        # target['iscrowd'] size can differ from mask size (non_all_zero_masks)
        # For example, if previous transforms geometrically modify masks/boxes/labels but
        # do not update "iscrowd"
        if len(target["iscrowd"]) == len(non_all_zero_masks):
            iscrowd = target["iscrowd"][non_all_zero_masks]
            paste_iscrowd = paste_target["iscrowd"][random_selection]
            out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])

    # Check for degenerate boxes and remove them
    boxes = out_target["boxes"]
    degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
    if degenerate_boxes.any():
        valid_targets = ~degenerate_boxes.any(dim=1)

        out_target["boxes"] = boxes[valid_targets]
        out_target["masks"] = out_target["masks"][valid_targets]
        out_target["labels"] = out_target["labels"][valid_targets]

        if "area" in out_target:
            out_target["area"] = out_target["area"][valid_targets]
        if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets):
            out_target["iscrowd"] = out_target["iscrowd"][valid_targets]

    return image, out_target
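
# Selection note (describing the sampling above): drawing num_masks indices
# with replacement and de-duplicating them keeps a random subset of the paste
# instances; for large num_masks, roughly 1 - 1/e (about 63%) of the instances
# survive on average.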
 
class SimpleCopyPaste(torch.nn.Module):
    def __init__(self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR):
        super().__init__()
        self.resize_interpolation = resize_interpolation
        self.blending = blending

    def forward(
        self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]]
    ) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]:
        torch._assert(
            isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]),
            "images should be a list of tensors",
        )
        torch._assert(
            isinstance(targets, (list, tuple)) and len(images) == len(targets),
            "targets should be a list of the same size as images",
        )
        for target in targets:
            # Cannot check for dict instance type inside torch.jit.script
            # torch._assert(isinstance(target, dict), "targets item should be a dict")
            for k in ["masks", "boxes", "labels"]:
                torch._assert(k in target, f"Key {k} should be present in targets")
                torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor")

        # images = [t1, t2, ..., tN]
        # Let's define paste_images as a rotated list of input images
        # paste_images = [tN, t1, ..., tN-1], so every image pastes from its neighbour
        # FYI: in TF they mix data on the dataset level
        images_rolled = images[-1:] + images[:-1]
        targets_rolled = targets[-1:] + targets[:-1]

        output_images: List[torch.Tensor] = []
        output_targets: List[Dict[str, Tensor]] = []

        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
            output_image, output_data = _copy_paste(
                image,
                target,
                paste_image,
                paste_target,
                blending=self.blending,
                resize_interpolation=self.resize_interpolation,
            )
            output_images.append(output_image)
            output_targets.append(output_data)

        return output_images, output_targets

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
        return s
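
# Usage sketch (one possible wiring, an assumption rather than a prescribed
# API): since SimpleCopyPaste mixes instances *across* a batch, it is applied
# inside the DataLoader collate function rather than per sample, e.g.:
#
#     copypaste = SimpleCopyPaste(blending=True)
#
#     def collate_fn(batch):
#         images, targets = tuple(zip(*batch))
#         return copypaste(list(images), list(targets))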
 
 