val.py

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from pathlib import Path

import torch

from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.metrics import OBBMetrics, batch_probiou
from ultralytics.utils.plotting import output_to_rotated_target, plot_images


class OBBValidator(DetectionValidator):
    """
    A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.

    Example:
        ```python
        from ultralytics.models.yolo.obb import OBBValidator

        args = dict(model="yolov8n-obb.pt", data="dota8.yaml")
        validator = OBBValidator(args=args)
        validator(model=args["model"])
        ```
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.args.task = "obb"
        self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)

    def init_metrics(self, model):
        """Initialize evaluation metrics for YOLO."""
        super().init_metrics(model)
        val = self.data.get(self.args.split, "")  # validation path
        self.is_dota = isinstance(val, str) and "DOTA" in val  # is DOTA dataset

    def postprocess(self, preds):
        """Apply Non-maximum suppression to prediction outputs."""
        return ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            labels=self.lb,
            nc=self.nc,
            multi_label=True,
            agnostic=self.args.single_cls or self.args.agnostic_nms,
            max_det=self.args.max_det,
            rotated=True,
        )
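
    # Note: with rotated=True, NMS operates on xywhr boxes, so each row of the returned
    # per-image tensors is laid out as (x, y, w, h, conf, cls, angle); the slicing in
    # _process_batch and pred_to_json below relies on this column order.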

    def _process_batch(self, detections, gt_bboxes, gt_cls):
        """
        Perform computation of the correct prediction matrix for a batch of detections and ground truth bounding boxes.

        Args:
            detections (torch.Tensor): A tensor of shape (N, 7) representing the detected bounding boxes and associated
                data. Each detection is represented as (x, y, w, h, conf, class, angle).
            gt_bboxes (torch.Tensor): A tensor of shape (M, 5) representing the ground truth bounding boxes. Each box is
                represented as (x, y, w, h, angle).
            gt_cls (torch.Tensor): A tensor of shape (M,) representing class labels for the ground truth bounding boxes.

        Returns:
            (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU (Intersection over
                Union) levels for each detection, indicating the accuracy of predictions compared to the ground truth.

        Example:
            ```python
            detections = torch.rand(100, 7)  # 100 sample detections
            gt_bboxes = torch.rand(50, 5)  # 50 sample ground truth boxes
            gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
            correct_matrix = OBBValidator._process_batch(detections, gt_bboxes, gt_cls)
            ```

        Note:
            This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes.
        """
        iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
        return self.match_predictions(detections[:, 5], gt_cls, iou)
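
    # The (N, 10) correct matrix comes from match_predictions, which thresholds the
    # probabilistic IoU matrix at the 10 IoU levels (0.50:0.95 in steps of 0.05) held in
    # self.iouv, inherited from the DetectionValidator metrics setup.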

    def _prepare_batch(self, si, batch):
        """Prepares and returns a batch for OBB validation."""
        idx = batch["batch_idx"] == si
        cls = batch["cls"][idx].squeeze(-1)
        bbox = batch["bboxes"][idx]
        ori_shape = batch["ori_shape"][si]
        imgsz = batch["img"].shape[2:]
        ratio_pad = batch["ratio_pad"][si]
        if len(cls):
            bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
            ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
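
    # Dataset labels arrive as normalized xywhr, so the xywh part is first scaled up to the
    # letterboxed image size (width/height ordering via the [1, 0, 1, 0] index) and then
    # mapped back to the original image shape with ops.scale_boxes(..., xywh=True).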

    def _prepare_pred(self, pred, pbatch):
        """Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
        predn = pred.clone()
        ops.scale_boxes(
            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
        )  # native-space pred
        return predn
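
    # Only the xywh columns of the cloned predictions are rescaled here; the conf, cls and
    # angle columns are left untouched, preserving the (x, y, w, h, conf, cls, angle) layout.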

    def plot_predictions(self, batch, preds, ni):
        """Plots predicted bounding boxes on input images and saves the result."""
        plot_images(
            batch["img"],
            *output_to_rotated_target(preds, max_det=self.args.max_det),
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )  # pred

    def pred_to_json(self, predn, filename):
        """Serialize YOLO predictions to COCO json format."""
        stem = Path(filename).stem
        image_id = int(stem) if stem.isnumeric() else stem
        rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
        poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
        for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
            self.jdict.append(
                {
                    "image_id": image_id,
                    "category_id": self.class_map[int(predn[i, 5].item())],
                    "score": round(predn[i, 4].item(), 5),
                    "rbox": [round(x, 3) for x in r],
                    "poly": [round(x, 3) for x in b],
                }
            )
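
    # Illustrative shape of one serialized entry (values made up; the image_id follows the
    # DOTA split naming that eval_json parses below):
    # {"image_id": "P0006__1024__0___0", "category_id": 15, "score": 0.91234,
    #  "rbox": [cx, cy, w, h, angle], "poly": [x1, y1, x2, y2, x3, y3, x4, y4]}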

    def save_one_txt(self, predn, save_conf, shape, file):
        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
        import numpy as np

        from ultralytics.engine.results import Results

        rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
        # xywh, r, conf, cls
        obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1)
        Results(
            np.zeros((shape[0], shape[1]), dtype=np.uint8),
            path=None,
            names=self.names,
            obb=obb,
        ).save_txt(file, save_conf=save_conf)
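
    # Results.save_txt then emits one line per box; for OBB results this should be the class
    # index followed by the normalized 4-corner polygon, with the confidence appended only
    # when save_conf is True (behavior of the Results API, not defined in this file).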

    def eval_json(self, stats):
        """Evaluates YOLO output in JSON format and returns performance statistics."""
        if self.args.save_json and self.is_dota and len(self.jdict):
            import json
            import re
            from collections import defaultdict

            pred_json = self.save_dir / "predictions.json"  # predictions
            pred_txt = self.save_dir / "predictions_txt"  # predictions
            pred_txt.mkdir(parents=True, exist_ok=True)
            data = json.load(open(pred_json))
            # Save split results
            LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...")
            for d in data:
                image_id = d["image_id"]
                score = d["score"]
                classname = self.names[d["category_id"] - 1].replace(" ", "-")
                p = d["poly"]
                with open(f"{pred_txt / f'Task1_{classname}'}.txt", "a") as f:
                    f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
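
            # Each Task1_{classname}.txt line follows the DOTA Task 1 submission format:
            # "image_id score x1 y1 x2 y2 x3 y3 x4 y4", with polygon corners in image pixels.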

            # Save merged results. This could give a slightly lower mAP than the official
            # merging script because of the probiou calculation.
            pred_merged_txt = self.save_dir / "predictions_merged_txt"  # predictions
            pred_merged_txt.mkdir(parents=True, exist_ok=True)
            merged_results = defaultdict(list)
            LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
            for d in data:
                image_id = d["image_id"].split("__")[0]
                pattern = re.compile(r"\d+___\d+")
                x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
                bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
                bbox[0] += x
                bbox[1] += y
                bbox.extend([score, cls])
                merged_results[image_id].append(bbox)
            for image_id, bbox in merged_results.items():
                bbox = torch.tensor(bbox)
                max_wh = torch.max(bbox[:, :2]).item() * 2
                c = bbox[:, 6:7] * max_wh  # classes
                scores = bbox[:, 5]  # scores
                b = bbox[:, :5].clone()
                b[:, :2] += c
                # An IoU threshold of 0.3 could give results close to the official merging script, even slightly better.
                i = ops.nms_rotated(b, scores, 0.3)
                bbox = bbox[i]
                b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
                for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
                    classname = self.names[int(x[-1])].replace(" ", "-")
                    p = [round(i, 3) for i in x[:-2]]  # poly
                    score = round(x[-2], 3)
                    with open(f"{pred_merged_txt / f'Task1_{classname}'}.txt", "a") as f:
                        f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
        return stats
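

if __name__ == "__main__":
    # Minimal standalone sketch: run OBB validation directly from this module. The checkpoint
    # and dataset names below are assumptions for illustration; substitute your own trained
    # OBB model and a local dataset YAML (e.g. a DOTA split) before running.
    args = dict(model="yolov8n-obb.pt", data="dota8.yaml", save_json=True)
    validator = OBBValidator(args=args)
    validator(model=args["model"])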