predict.py

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch
from PIL import Image

from ultralytics.models.yolo.segment import SegmentationPredictor
from ultralytics.utils import DEFAULT_CFG, checks
from ultralytics.utils.metrics import box_iou
from ultralytics.utils.ops import scale_masks

from .utils import adjust_bboxes_to_image_border


class FastSAMPredictor(SegmentationPredictor):
    """
    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in the
    Ultralytics YOLO framework.

    This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
    adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing for
    single-class segmentation.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initializes a FastSAMPredictor for fast SAM segmentation tasks in the Ultralytics YOLO framework."""
        super().__init__(cfg, overrides, _callbacks)
        self.prompts = {}

    def postprocess(self, preds, img, orig_imgs):
        """Applies box postprocessing for FastSAM predictions."""
        bboxes = self.prompts.pop("bboxes", None)
        points = self.prompts.pop("points", None)
        labels = self.prompts.pop("labels", None)
        texts = self.prompts.pop("texts", None)
        results = super().postprocess(preds, img, orig_imgs)
        for result in results:
            full_box = torch.tensor(
                [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
            )
            boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
            # Snap boxes that nearly cover the whole image (IoU with the full-image box > 0.9) to the exact image extent
            idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
            if idx.numel() != 0:
                result.boxes.xyxy[idx] = full_box
        return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
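
    # NOTE (editorial, illustrative values): `self.prompts` is expected to be a dict populated via `set_prompts()`
    # before inference, e.g. {"bboxes": [[50, 30, 200, 180]], "points": [[120, 90]], "labels": [1], "texts": "a dog"}.
    # Each key is popped above, so a prompt applies to a single prediction call only.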

    def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
        """
        Internal function for image segmentation inference based on cues like bounding boxes, points, and texts.

        Leverages SAM's specialized architecture for prompt-based, real-time segmentation.

        Args:
            results (Results | List[Results]): The original inference results from FastSAM models without any prompts.
            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
            texts (str | List[str], optional): A textual prompt or a list of textual prompts.

        Returns:
            (List[Results]): The output results determined by the prompts.
        """
        if bboxes is None and points is None and texts is None:
            return results
        prompt_results = []
        if not isinstance(results, list):
            results = [results]
        for result in results:
            if len(result) == 0:
                prompt_results.append(result)
                continue
            masks = result.masks.data
            if masks.shape[1:] != result.orig_shape:
                masks = scale_masks(masks[None], result.orig_shape)[0]
            # bboxes prompt: for each prompt box, select the mask with the highest IoU against the box
            # (intersection = mask area inside the box, union = box area + full mask area - intersection)
            idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
            if bboxes is not None:
                bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
                bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
                bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
                mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
                full_mask_areas = torch.sum(masks, dim=(1, 2))
                union = bbox_areas[:, None] + full_mask_areas - mask_areas
                idx[torch.argmax(mask_areas / union, dim=1)] = True
            # points prompt: a point marks the masks containing it as foreground (label 1) or background (label 0)
            if points is not None:
                points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
                points = points[None] if points.ndim == 1 else points
                if labels is None:
                    labels = torch.ones(points.shape[0])
                labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
                assert len(labels) == len(points), (
                    f"Expected `labels` to have the same size as `points`, but got {len(labels)} and {len(points)}"
                )
                point_idx = (
                    torch.ones(len(result), dtype=torch.bool, device=self.device)
                    if labels.sum() == 0  # all negative points
                    else torch.zeros(len(result), dtype=torch.bool, device=self.device)
                )
                for point, label in zip(points, labels):
                    point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
                idx |= point_idx
            # texts prompt: crop each predicted region, score the crops against the texts with CLIP, keep the best match
            if texts is not None:
                if isinstance(texts, str):
                    texts = [texts]
                crop_ims, filter_idx = [], []
                for i, b in enumerate(result.boxes.xyxy.tolist()):
                    x1, y1, x2, y2 = (int(x) for x in b)
                    if masks[i].sum() <= 100:  # skip tiny masks
                        filter_idx.append(i)
                        continue
                    crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))  # BGR -> RGB crop
                similarity = self._clip_inference(crop_ims, texts)
                text_idx = torch.argmax(similarity, dim=-1)  # (M, )
                if len(filter_idx):
                    # Map crop indices back to the original result indexing, accounting for the crops skipped above
                    text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
                idx[text_idx] = True

            prompt_results.append(result[idx])

        return prompt_results

    def _clip_inference(self, images, texts):
        """
        CLIP inference process.

        Args:
            images (List[PIL.Image]): A list of source images; each should be a PIL.Image with RGB channel order.
            texts (List[str]): A list of prompt texts; each should be a string object.

        Returns:
            (torch.Tensor): The similarity between the given images and texts.
        """
        try:
            import clip
        except ImportError:
            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
            import clip
        if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
        images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])
        tokenized_text = clip.tokenize(texts).to(self.device)
        image_features = self.clip_model.encode_image(images)
        text_features = self.clip_model.encode_text(tokenized_text)
        image_features /= image_features.norm(dim=-1, keepdim=True)  # (N, 512)
        text_features /= text_features.norm(dim=-1, keepdim=True)  # (M, 512)
        return (image_features * text_features[:, None]).sum(-1)  # (M, N)
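
    # NOTE (editorial): with N cropped regions and M text prompts, `_clip_inference` returns an (M, N) cosine
    # similarity matrix, so `similarity.argmax(dim=-1)` in `prompt()` picks the best-matching crop for each text.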

    def set_prompts(self, prompts):
        """Set prompts in advance."""
        self.prompts = prompts
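

# Usage sketch (editorial addition, not part of the upstream file). It assumes the high-level `FastSAM` model class
# and a local "FastSAM-s.pt" checkpoint; the keyword arguments mirror the prompt keys consumed by `postprocess()`:
#
#   from ultralytics import FastSAM
#
#   model = FastSAM("FastSAM-s.pt")
#
#   # Segment everything (no prompts).
#   results = model("image.jpg")
#
#   # Prompted segmentation: box, point (+ label), and text prompts are stored via FastSAMPredictor.set_prompts()
#   # and resolved in FastSAMPredictor.prompt().
#   results = model("image.jpg", bboxes=[50, 30, 200, 180])
#   results = model("image.jpg", points=[[120, 90]], labels=[1])
#   results = model("image.jpg", texts="a photo of a dog")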