plotting.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import math
  3. import warnings
  4. from pathlib import Path
  5. from typing import Callable, Dict, List, Optional, Union
  6. import cv2
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. import torch
  10. from PIL import Image, ImageDraw, ImageFont
  11. from PIL import __version__ as pil_version
  12. from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded
  13. from ultralytics.utils.checks import check_font, check_version, is_ascii
  14. from ultralytics.utils.files import increment_path
  15. class Colors:
  16. """
  17. Ultralytics color palette https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Colors.
  18. This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
  19. RGB values.
  20. Attributes:
  21. palette (list of tuple): List of RGB color values.
  22. n (int): The number of colors in the palette.
  23. pose_palette (np.ndarray): A specific color palette array with dtype np.uint8.
  24. ## Ultralytics Color Palette
  25. | Index | Color | HEX | RGB |
  26. |-------|-------------------------------------------------------------------|-----------|-------------------|
  27. | 0 | <i class="fa-solid fa-square fa-2xl" style="color: #042aff;"></i> | `#042aff` | (4, 42, 255) |
  28. | 1 | <i class="fa-solid fa-square fa-2xl" style="color: #0bdbeb;"></i> | `#0bdbeb` | (11, 219, 235) |
  29. | 2 | <i class="fa-solid fa-square fa-2xl" style="color: #f3f3f3;"></i> | `#f3f3f3` | (243, 243, 243) |
  30. | 3 | <i class="fa-solid fa-square fa-2xl" style="color: #00dfb7;"></i> | `#00dfb7` | (0, 223, 183) |
  31. | 4 | <i class="fa-solid fa-square fa-2xl" style="color: #111f68;"></i> | `#111f68` | (17, 31, 104) |
  32. | 5 | <i class="fa-solid fa-square fa-2xl" style="color: #ff6fdd;"></i> | `#ff6fdd` | (255, 111, 221) |
  33. | 6 | <i class="fa-solid fa-square fa-2xl" style="color: #ff444f;"></i> | `#ff444f` | (255, 68, 79) |
  34. | 7 | <i class="fa-solid fa-square fa-2xl" style="color: #cced00;"></i> | `#cced00` | (204, 237, 0) |
  35. | 8 | <i class="fa-solid fa-square fa-2xl" style="color: #00f344;"></i> | `#00f344` | (0, 243, 68) |
  36. | 9 | <i class="fa-solid fa-square fa-2xl" style="color: #bd00ff;"></i> | `#bd00ff` | (189, 0, 255) |
  37. | 10 | <i class="fa-solid fa-square fa-2xl" style="color: #00b4ff;"></i> | `#00b4ff` | (0, 180, 255) |
  38. | 11 | <i class="fa-solid fa-square fa-2xl" style="color: #dd00ba;"></i> | `#dd00ba` | (221, 0, 186) |
  39. | 12 | <i class="fa-solid fa-square fa-2xl" style="color: #00ffff;"></i> | `#00ffff` | (0, 255, 255) |
  40. | 13 | <i class="fa-solid fa-square fa-2xl" style="color: #26c000;"></i> | `#26c000` | (38, 192, 0) |
  41. | 14 | <i class="fa-solid fa-square fa-2xl" style="color: #01ffb3;"></i> | `#01ffb3` | (1, 255, 179) |
  42. | 15 | <i class="fa-solid fa-square fa-2xl" style="color: #7d24ff;"></i> | `#7d24ff` | (125, 36, 255) |
  43. | 16 | <i class="fa-solid fa-square fa-2xl" style="color: #7b0068;"></i> | `#7b0068` | (123, 0, 104) |
  44. | 17 | <i class="fa-solid fa-square fa-2xl" style="color: #ff1b6c;"></i> | `#ff1b6c` | (255, 27, 108) |
  45. | 18 | <i class="fa-solid fa-square fa-2xl" style="color: #fc6d2f;"></i> | `#fc6d2f` | (252, 109, 47) |
  46. | 19 | <i class="fa-solid fa-square fa-2xl" style="color: #a2ff0b;"></i> | `#a2ff0b` | (162, 255, 11) |
  47. ## Pose Color Palette
  48. | Index | Color | HEX | RGB |
  49. |-------|-------------------------------------------------------------------|-----------|-------------------|
  50. | 0 | <i class="fa-solid fa-square fa-2xl" style="color: #ff8000;"></i> | `#ff8000` | (255, 128, 0) |
  51. | 1 | <i class="fa-solid fa-square fa-2xl" style="color: #ff9933;"></i> | `#ff9933` | (255, 153, 51) |
  52. | 2 | <i class="fa-solid fa-square fa-2xl" style="color: #ffb266;"></i> | `#ffb266` | (255, 178, 102) |
  53. | 3 | <i class="fa-solid fa-square fa-2xl" style="color: #e6e600;"></i> | `#e6e600` | (230, 230, 0) |
  54. | 4 | <i class="fa-solid fa-square fa-2xl" style="color: #ff99ff;"></i> | `#ff99ff` | (255, 153, 255) |
  55. | 5 | <i class="fa-solid fa-square fa-2xl" style="color: #99ccff;"></i> | `#99ccff` | (153, 204, 255) |
  56. | 6 | <i class="fa-solid fa-square fa-2xl" style="color: #ff66ff;"></i> | `#ff66ff` | (255, 102, 255) |
  57. | 7 | <i class="fa-solid fa-square fa-2xl" style="color: #ff33ff;"></i> | `#ff33ff` | (255, 51, 255) |
  58. | 8 | <i class="fa-solid fa-square fa-2xl" style="color: #66b2ff;"></i> | `#66b2ff` | (102, 178, 255) |
  59. | 9 | <i class="fa-solid fa-square fa-2xl" style="color: #3399ff;"></i> | `#3399ff` | (51, 153, 255) |
  60. | 10 | <i class="fa-solid fa-square fa-2xl" style="color: #ff9999;"></i> | `#ff9999` | (255, 153, 153) |
  61. | 11 | <i class="fa-solid fa-square fa-2xl" style="color: #ff6666;"></i> | `#ff6666` | (255, 102, 102) |
  62. | 12 | <i class="fa-solid fa-square fa-2xl" style="color: #ff3333;"></i> | `#ff3333` | (255, 51, 51) |
  63. | 13 | <i class="fa-solid fa-square fa-2xl" style="color: #99ff99;"></i> | `#99ff99` | (153, 255, 153) |
  64. | 14 | <i class="fa-solid fa-square fa-2xl" style="color: #66ff66;"></i> | `#66ff66` | (102, 255, 102) |
  65. | 15 | <i class="fa-solid fa-square fa-2xl" style="color: #33ff33;"></i> | `#33ff33` | (51, 255, 51) |
  66. | 16 | <i class="fa-solid fa-square fa-2xl" style="color: #00ff00;"></i> | `#00ff00` | (0, 255, 0) |
  67. | 17 | <i class="fa-solid fa-square fa-2xl" style="color: #0000ff;"></i> | `#0000ff` | (0, 0, 255) |
  68. | 18 | <i class="fa-solid fa-square fa-2xl" style="color: #ff0000;"></i> | `#ff0000` | (255, 0, 0) |
  69. | 19 | <i class="fa-solid fa-square fa-2xl" style="color: #ffffff;"></i> | `#ffffff` | (255, 255, 255) |
  70. !!! note "Ultralytics Brand Colors"
  71. For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand). Please use the official Ultralytics colors for all marketing materials.
  72. """
  73. def __init__(self):
  74. """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
  75. hexs = (
  76. "042AFF",
  77. "0BDBEB",
  78. "F3F3F3",
  79. "00DFB7",
  80. "111F68",
  81. "FF6FDD",
  82. "FF444F",
  83. "CCED00",
  84. "00F344",
  85. "BD00FF",
  86. "00B4FF",
  87. "DD00BA",
  88. "00FFFF",
  89. "26C000",
  90. "01FFB3",
  91. "7D24FF",
  92. "7B0068",
  93. "FF1B6C",
  94. "FC6D2F",
  95. "A2FF0B",
  96. )
  97. self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
  98. self.n = len(self.palette)
  99. self.pose_palette = np.array(
  100. [
  101. [255, 128, 0],
  102. [255, 153, 51],
  103. [255, 178, 102],
  104. [230, 230, 0],
  105. [255, 153, 255],
  106. [153, 204, 255],
  107. [255, 102, 255],
  108. [255, 51, 255],
  109. [102, 178, 255],
  110. [51, 153, 255],
  111. [255, 153, 153],
  112. [255, 102, 102],
  113. [255, 51, 51],
  114. [153, 255, 153],
  115. [102, 255, 102],
  116. [51, 255, 51],
  117. [0, 255, 0],
  118. [0, 0, 255],
  119. [255, 0, 0],
  120. [255, 255, 255],
  121. ],
  122. dtype=np.uint8,
  123. )
  124. def __call__(self, i, bgr=False):
  125. """Converts hex color codes to RGB values."""
  126. c = self.palette[int(i) % self.n]
  127. return (c[2], c[1], c[0]) if bgr else c
  128. @staticmethod
  129. def hex2rgb(h):
  130. """Converts hex color codes to RGB values (i.e. default PIL order)."""
  131. return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
colors = Colors()  # shared module-level palette instance (for 'from utils.plots import colors'); Annotator reads colors.pose_palette
class Annotator:
    """
    Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.

    Drawing is done either through PIL (`self.draw`, `self.font`) or through cv2 directly on a numpy array,
    depending on `self.pil`.

    Attributes:
        im (Image.Image or numpy array): The image to annotate.
        pil (bool): Whether to use PIL or cv2 for drawing annotations.
        draw (ImageDraw.ImageDraw): PIL drawing handle bound to `im`; created when PIL is used.
        font (ImageFont.truetype or ImageFont.load_default): Font used for text annotations; created when PIL is used.
        lw (int | float): Line width for drawing; `line_width` if given, otherwise derived from the image size (min 2).
        tf (int): cv2 font thickness, `max(lw - 1, 1)`. NOTE(review): appears to be set only for the cv2 backend -
            confirm against the properly indented source.
        sf (float): cv2 font scale, `lw / 3`. Same caveat as `tf`.
        skeleton (List[List[int]]): Skeleton structure for keypoints, as pairs of 1-based keypoint indices.
        limb_color (np.ndarray): Rows of `colors.pose_palette`, one per skeleton limb.
        kpt_color (np.ndarray): Rows of `colors.pose_palette`, one per keypoint.
        dark_colors (set of tuple): Background colors for which `get_txt_color` returns the dark text color.
        light_colors (set of tuple): Background colors for which `get_txt_color` returns white text.
    """
  145. def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
  146. """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
  147. non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
  148. input_is_pil = isinstance(im, Image.Image)
  149. self.pil = pil or non_ascii or input_is_pil
  150. self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)
  151. if self.pil: # use PIL
  152. self.im = im if input_is_pil else Image.fromarray(im)
  153. self.draw = ImageDraw.Draw(self.im)
  154. try:
  155. font = check_font("Arial.Unicode.ttf" if non_ascii else font)
  156. size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
  157. self.font = ImageFont.truetype(str(font), size)
  158. except Exception:
  159. self.font = ImageFont.load_default()
  160. # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)
  161. if check_version(pil_version, "9.2.0"):
  162. self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height
  163. else: # use cv2
  164. assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images."
  165. self.im = im if im.flags.writeable else im.copy()
  166. self.tf = max(self.lw - 1, 1) # font thickness
  167. self.sf = self.lw / 3 # font scale
  168. # Pose
  169. self.skeleton = [
  170. [16, 14],
  171. [14, 12],
  172. [17, 15],
  173. [15, 13],
  174. [12, 13],
  175. [6, 12],
  176. [7, 13],
  177. [6, 7],
  178. [6, 8],
  179. [7, 9],
  180. [8, 10],
  181. [9, 11],
  182. [2, 3],
  183. [1, 2],
  184. [1, 3],
  185. [2, 4],
  186. [3, 5],
  187. [4, 6],
  188. [5, 7],
  189. ]
  190. self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
  191. self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
  192. self.dark_colors = {
  193. (235, 219, 11),
  194. (243, 243, 243),
  195. (183, 223, 0),
  196. (221, 111, 255),
  197. (0, 237, 204),
  198. (68, 243, 0),
  199. (255, 255, 0),
  200. (179, 255, 1),
  201. (11, 255, 162),
  202. }
  203. self.light_colors = {
  204. (255, 42, 4),
  205. (79, 68, 255),
  206. (255, 0, 189),
  207. (255, 180, 0),
  208. (186, 0, 221),
  209. (0, 192, 38),
  210. (255, 36, 125),
  211. (104, 0, 123),
  212. (108, 27, 255),
  213. (47, 109, 252),
  214. (104, 31, 17),
  215. }
  216. def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
  217. """
  218. Assign text color based on background color.
  219. Args:
  220. color (tuple, optional): The background color of the rectangle for text (B, G, R).
  221. txt_color (tuple, optional): The color of the text (R, G, B).
  222. Returns:
  223. txt_color (tuple): Text color for label
  224. """
  225. if color in self.dark_colors:
  226. return 104, 31, 17
  227. elif color in self.light_colors:
  228. return 255, 255, 255
  229. else:
  230. return txt_color
  231. def circle_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=2):
  232. """
  233. Draws a label with a background circle centered within a given bounding box.
  234. Args:
  235. box (tuple): The bounding box coordinates (x1, y1, x2, y2).
  236. label (str): The text label to be displayed.
  237. color (tuple, optional): The background color of the rectangle (B, G, R).
  238. txt_color (tuple, optional): The color of the text (R, G, B).
  239. margin (int, optional): The margin between the text and the rectangle border.
  240. """
  241. # If label have more than 3 characters, skip other characters, due to circle size
  242. if len(label) > 3:
  243. print(
  244. f"Length of label is {len(label)}, initial 3 label characters will be considered for circle annotation!"
  245. )
  246. label = label[:3]
  247. # Calculate the center of the box
  248. x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
  249. # Get the text size
  250. text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]
  251. # Calculate the required radius to fit the text with the margin
  252. required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin
  253. # Draw the circle with the required radius
  254. cv2.circle(self.im, (x_center, y_center), required_radius, color, -1)
  255. # Calculate the position for the text
  256. text_x = x_center - text_size[0] // 2
  257. text_y = y_center + text_size[1] // 2
  258. # Draw the text
  259. cv2.putText(
  260. self.im,
  261. str(label),
  262. (text_x, text_y),
  263. cv2.FONT_HERSHEY_SIMPLEX,
  264. self.sf - 0.15,
  265. self.get_txt_color(color, txt_color),
  266. self.tf,
  267. lineType=cv2.LINE_AA,
  268. )
  269. def text_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=5):
  270. """
  271. Draws a label with a background rectangle centered within a given bounding box.
  272. Args:
  273. box (tuple): The bounding box coordinates (x1, y1, x2, y2).
  274. label (str): The text label to be displayed.
  275. color (tuple, optional): The background color of the rectangle (B, G, R).
  276. txt_color (tuple, optional): The color of the text (R, G, B).
  277. margin (int, optional): The margin between the text and the rectangle border.
  278. """
  279. # Calculate the center of the bounding box
  280. x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
  281. # Get the size of the text
  282. text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0]
  283. # Calculate the top-left corner of the text (to center it)
  284. text_x = x_center - text_size[0] // 2
  285. text_y = y_center + text_size[1] // 2
  286. # Calculate the coordinates of the background rectangle
  287. rect_x1 = text_x - margin
  288. rect_y1 = text_y - text_size[1] - margin
  289. rect_x2 = text_x + text_size[0] + margin
  290. rect_y2 = text_y + margin
  291. # Draw the background rectangle
  292. cv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1)
  293. # Draw the text on top of the rectangle
  294. cv2.putText(
  295. self.im,
  296. label,
  297. (text_x, text_y),
  298. cv2.FONT_HERSHEY_SIMPLEX,
  299. self.sf - 0.1,
  300. self.get_txt_color(color, txt_color),
  301. self.tf,
  302. lineType=cv2.LINE_AA,
  303. )
  304. def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
  305. """
  306. Draws a bounding box to image with label.
  307. Args:
  308. box (tuple): The bounding box coordinates (x1, y1, x2, y2).
  309. label (str): The text label to be displayed.
  310. color (tuple, optional): The background color of the rectangle (B, G, R).
  311. txt_color (tuple, optional): The color of the text (R, G, B).
  312. rotated (bool, optional): Variable used to check if task is OBB
  313. """
  314. txt_color = self.get_txt_color(color, txt_color)
  315. if isinstance(box, torch.Tensor):
  316. box = box.tolist()
  317. if self.pil or not is_ascii(label):
  318. if rotated:
  319. p1 = box[0]
  320. self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) # PIL requires tuple box
  321. else:
  322. p1 = (box[0], box[1])
  323. self.draw.rectangle(box, width=self.lw, outline=color) # box
  324. if label:
  325. w, h = self.font.getsize(label) # text width, height
  326. outside = p1[1] >= h # label fits outside box
  327. if p1[0] > self.im.size[0] - w: # size is (w, h), check if label extend beyond right side of image
  328. p1 = self.im.size[0] - w, p1[1]
  329. self.draw.rectangle(
  330. (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1),
  331. fill=color,
  332. )
  333. # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
  334. self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)
  335. else: # cv2
  336. if rotated:
  337. p1 = [int(b) for b in box[0]]
  338. cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw) # cv2 requires nparray box
  339. else:
  340. p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
  341. cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
  342. if label:
  343. w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
  344. h += 3 # add pixels to pad text
  345. outside = p1[1] >= h # label fits outside box
  346. if p1[0] > self.im.shape[1] - w: # shape is (h, w), check if label extend beyond right side of image
  347. p1 = self.im.shape[1] - w, p1[1]
  348. p2 = p1[0] + w, p1[1] - h if outside else p1[1] + h
  349. cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
  350. cv2.putText(
  351. self.im,
  352. label,
  353. (p1[0], p1[1] - 2 if outside else p1[1] + h - 1),
  354. 0,
  355. self.sf,
  356. txt_color,
  357. thickness=self.tf,
  358. lineType=cv2.LINE_AA,
  359. )
  360. def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
  361. """
  362. Plot masks on image.
  363. Args:
  364. masks (tensor): Predicted masks on cuda, shape: [n, h, w]
  365. colors (List[List[Int]]): Colors for predicted masks, [[r, g, b] * n]
  366. im_gpu (tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]
  367. alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaque
  368. retina_masks (bool): Whether to use high resolution masks or not. Defaults to False.
  369. """
  370. if self.pil:
  371. # Convert to numpy first
  372. self.im = np.asarray(self.im).copy()
  373. if len(masks) == 0:
  374. self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
  375. if im_gpu.device != masks.device:
  376. im_gpu = im_gpu.to(masks.device)
  377. colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
  378. colors = colors[:, None, None] # shape(n,1,1,3)
  379. masks = masks.unsqueeze(3) # shape(n,h,w,1)
  380. masks_color = masks * (colors * alpha) # shape(n,h,w,3)
  381. inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
  382. mcs = masks_color.max(dim=0).values # shape(n,h,w,3)
  383. im_gpu = im_gpu.flip(dims=[0]) # flip channel
  384. im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
  385. im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
  386. im_mask = im_gpu * 255
  387. im_mask_np = im_mask.byte().cpu().numpy()
  388. self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
  389. if self.pil:
  390. # Convert im back to PIL and update draw
  391. self.fromarray(self.im)
  392. def kpts(self, kpts, shape=(640, 640), radius=None, kpt_line=True, conf_thres=0.25, kpt_color=None):
  393. """
  394. Plot keypoints on the image.
  395. Args:
  396. kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
  397. shape (tuple, optional): Image shape (h, w). Defaults to (640, 640).
  398. radius (int, optional): Keypoint radius. Defaults to 5.
  399. kpt_line (bool, optional): Draw lines between keypoints. Defaults to True.
  400. conf_thres (float, optional): Confidence threshold. Defaults to 0.25.
  401. kpt_color (tuple, optional): Keypoint color (B, G, R). Defaults to None.
  402. Note:
  403. - `kpt_line=True` currently only supports human pose plotting.
  404. - Modifies self.im in-place.
  405. - If self.pil is True, converts image to numpy array and back to PIL.
  406. """
  407. radius = radius if radius is not None else self.lw
  408. if self.pil:
  409. # Convert to numpy first
  410. self.im = np.asarray(self.im).copy()
  411. nkpt, ndim = kpts.shape
  412. is_pose = nkpt == 17 and ndim in {2, 3}
  413. kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting
  414. for i, k in enumerate(kpts):
  415. color_k = kpt_color or (self.kpt_color[i].tolist() if is_pose else colors(i))
  416. x_coord, y_coord = k[0], k[1]
  417. if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
  418. if len(k) == 3:
  419. conf = k[2]
  420. if conf < conf_thres:
  421. continue
  422. cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)
  423. if kpt_line:
  424. ndim = kpts.shape[-1]
  425. for i, sk in enumerate(self.skeleton):
  426. pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1]))
  427. pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1]))
  428. if ndim == 3:
  429. conf1 = kpts[(sk[0] - 1), 2]
  430. conf2 = kpts[(sk[1] - 1), 2]
  431. if conf1 < conf_thres or conf2 < conf_thres:
  432. continue
  433. if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:
  434. continue
  435. if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:
  436. continue
  437. cv2.line(
  438. self.im,
  439. pos1,
  440. pos2,
  441. kpt_color or self.limb_color[i].tolist(),
  442. thickness=int(np.ceil(self.lw / 2)),
  443. lineType=cv2.LINE_AA,
  444. )
  445. if self.pil:
  446. # Convert im back to PIL and update draw
  447. self.fromarray(self.im)
  448. def rectangle(self, xy, fill=None, outline=None, width=1):
  449. """Add rectangle to image (PIL-only)."""
  450. self.draw.rectangle(xy, fill, outline, width)
  451. def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False):
  452. """Adds text to an image using PIL or cv2."""
  453. if anchor == "bottom": # start y from font bottom
  454. w, h = self.font.getsize(text) # text width, height
  455. xy[1] += 1 - h
  456. if self.pil:
  457. if box_style:
  458. w, h = self.font.getsize(text)
  459. self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
  460. # Using `txt_color` for background and draw fg with white color
  461. txt_color = (255, 255, 255)
  462. if "\n" in text:
  463. lines = text.split("\n")
  464. _, h = self.font.getsize(text)
  465. for line in lines:
  466. self.draw.text(xy, line, fill=txt_color, font=self.font)
  467. xy[1] += h
  468. else:
  469. self.draw.text(xy, text, fill=txt_color, font=self.font)
  470. else:
  471. if box_style:
  472. w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
  473. h += 3 # add pixels to pad text
  474. outside = xy[1] >= h # label fits outside box
  475. p2 = xy[0] + w, xy[1] - h if outside else xy[1] + h
  476. cv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA) # filled
  477. # Using `txt_color` for background and draw fg with white color
  478. txt_color = (255, 255, 255)
  479. cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)
  480. def fromarray(self, im):
  481. """Update self.im from a numpy array."""
  482. self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
  483. self.draw = ImageDraw.Draw(self.im)
  484. def result(self):
  485. """Return annotated image as array."""
  486. return np.asarray(self.im)
  487. def show(self, title=None):
  488. """Show the annotated image."""
  489. im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR
  490. if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments
  491. try:
  492. display(im) # noqa - display() function only available in ipython environments
  493. except ImportError as e:
  494. LOGGER.warning(f"Unable to display image in Jupyter notebooks: {e}")
  495. else:
  496. im.show(title=title)
  497. def save(self, filename="image.jpg"):
  498. """Save the annotated image to 'filename'."""
  499. cv2.imwrite(filename, np.asarray(self.im))
  500. @staticmethod
  501. def get_bbox_dimension(bbox=None):
  502. """
  503. Calculate the area of a bounding box.
  504. Args:
  505. bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
  506. Returns:
  507. width (float): Width of the bounding box.
  508. height (float): Height of the bounding box.
  509. area (float): Area enclosed by the bounding box.
  510. """
  511. x_min, y_min, x_max, y_max = bbox
  512. width = x_max - x_min
  513. height = y_max - y_min
  514. return width, height, width * height
  515. def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):
  516. """
  517. Draw region line.
  518. Args:
  519. reg_pts (list): Region Points (for line 2 points, for region 4 points)
  520. color (tuple): Region Color value
  521. thickness (int): Region area thickness value
  522. """
  523. cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)
  524. # Draw small circles at the corner points
  525. for point in reg_pts:
  526. cv2.circle(self.im, (point[0], point[1]), thickness * 2, color, -1) # -1 fills the circle
  527. def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):
  528. """
  529. Draw centroid point and track trails.
  530. Args:
  531. track (list): object tracking points for trails display
  532. color (tuple): tracks line color
  533. track_thickness (int): track line thickness value
  534. """
  535. points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
  536. cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)
  537. cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)
  538. def queue_counts_display(self, label, points=None, region_color=(255, 255, 255), txt_color=(0, 0, 0)):
  539. """
  540. Displays queue counts on an image centered at the points with customizable font size and colors.
  541. Args:
  542. label (str): Queue counts label.
  543. points (tuple): Region points for center point calculation to display text.
  544. region_color (tuple): RGB queue region color.
  545. txt_color (tuple): RGB text display color.
  546. """
  547. x_values = [point[0] for point in points]
  548. y_values = [point[1] for point in points]
  549. center_x = sum(x_values) // len(points)
  550. center_y = sum(y_values) // len(points)
  551. text_size = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]
  552. text_width = text_size[0]
  553. text_height = text_size[1]
  554. rect_width = text_width + 20
  555. rect_height = text_height + 20
  556. rect_top_left = (center_x - rect_width // 2, center_y - rect_height // 2)
  557. rect_bottom_right = (center_x + rect_width // 2, center_y + rect_height // 2)
  558. cv2.rectangle(self.im, rect_top_left, rect_bottom_right, region_color, -1)
  559. text_x = center_x - text_width // 2
  560. text_y = center_y + text_height // 2
  561. # Draw text
  562. cv2.putText(
  563. self.im,
  564. label,
  565. (text_x, text_y),
  566. 0,
  567. fontScale=self.sf,
  568. color=txt_color,
  569. thickness=self.tf,
  570. lineType=cv2.LINE_AA,
  571. )
  572. def display_objects_labels(self, im0, text, txt_color, bg_color, x_center, y_center, margin):
  573. """
  574. Display the bounding boxes labels in parking management app.
  575. Args:
  576. im0 (ndarray): Inference image.
  577. text (str): Object/class name.
  578. txt_color (tuple): Display color for text foreground.
  579. bg_color (tuple): Display color for text background.
  580. x_center (float): The x position center point for bounding box.
  581. y_center (float): The y position center point for bounding box.
  582. margin (int): The gap between text and rectangle for better display.
  583. """
  584. text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
  585. text_x = x_center - text_size[0] // 2
  586. text_y = y_center + text_size[1] // 2
  587. rect_x1 = text_x - margin
  588. rect_y1 = text_y - text_size[1] - margin
  589. rect_x2 = text_x + text_size[0] + margin
  590. rect_y2 = text_y + margin
  591. cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)
  592. cv2.putText(im0, text, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)
  593. def display_analytics(self, im0, text, txt_color, bg_color, margin):
  594. """
  595. Display the overall statistics for parking lots.
  596. Args:
  597. im0 (ndarray): Inference image.
  598. text (dict): Labels dictionary.
  599. txt_color (tuple): Display color for text foreground.
  600. bg_color (tuple): Display color for text background.
  601. margin (int): Gap between text and rectangle for better display.
  602. """
  603. horizontal_gap = int(im0.shape[1] * 0.02)
  604. vertical_gap = int(im0.shape[0] * 0.01)
  605. text_y_offset = 0
  606. for label, value in text.items():
  607. txt = f"{label}: {value}"
  608. text_size = cv2.getTextSize(txt, 0, self.sf, self.tf)[0]
  609. if text_size[0] < 5 or text_size[1] < 5:
  610. text_size = (5, 5)
  611. text_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gap
  612. text_y = text_y_offset + text_size[1] + margin * 2 + vertical_gap
  613. rect_x1 = text_x - margin * 2
  614. rect_y1 = text_y - text_size[1] - margin * 2
  615. rect_x2 = text_x + text_size[0] + margin * 2
  616. rect_y2 = text_y + margin * 2
  617. cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)
  618. cv2.putText(im0, txt, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)
  619. text_y_offset = rect_y2
  620. @staticmethod
  621. def estimate_pose_angle(a, b, c):
  622. """
  623. Calculate the pose angle for object.
  624. Args:
  625. a (float) : The value of pose point a
  626. b (float): The value of pose point b
  627. c (float): The value o pose point c
  628. Returns:
  629. angle (degree): Degree value of angle between three points
  630. """
  631. a, b, c = np.array(a), np.array(b), np.array(c)
  632. radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
  633. angle = np.abs(radians * 180.0 / np.pi)
  634. if angle > 180.0:
  635. angle = 360 - angle
  636. return angle
  637. def draw_specific_points(self, keypoints, indices=None, radius=2, conf_thres=0.25):
  638. """
  639. Draw specific keypoints for gym steps counting.
  640. Args:
  641. keypoints (list): Keypoints data to be plotted.
  642. indices (list, optional): Keypoint indices to be plotted. Defaults to [2, 5, 7].
  643. radius (int, optional): Keypoint radius. Defaults to 2.
  644. conf_thres (float, optional): Confidence threshold for keypoints. Defaults to 0.25.
  645. Returns:
  646. (numpy.ndarray): Image with drawn keypoints.
  647. Note:
  648. Keypoint format: [x, y] or [x, y, confidence].
  649. Modifies self.im in-place.
  650. """
  651. indices = indices or [2, 5, 7]
  652. points = [(int(k[0]), int(k[1])) for i, k in enumerate(keypoints) if i in indices and k[2] >= conf_thres]
  653. # Draw lines between consecutive points
  654. for start, end in zip(points[:-1], points[1:]):
  655. cv2.line(self.im, start, end, (0, 255, 0), 2, lineType=cv2.LINE_AA)
  656. # Draw circles for keypoints
  657. for pt in points:
  658. cv2.circle(self.im, pt, radius, (0, 0, 255), -1, lineType=cv2.LINE_AA)
  659. return self.im
  660. def plot_workout_information(self, display_text, position, color=(104, 31, 17), txt_color=(255, 255, 255)):
  661. """
  662. Draw text with a background on the image.
  663. Args:
  664. display_text (str): The text to be displayed.
  665. position (tuple): Coordinates (x, y) on the image where the text will be placed.
  666. color (tuple, optional): Text background color
  667. txt_color (tuple, optional): Text foreground color
  668. """
  669. (text_width, text_height), _ = cv2.getTextSize(display_text, 0, self.sf, self.tf)
  670. # Draw background rectangle
  671. cv2.rectangle(
  672. self.im,
  673. (position[0], position[1] - text_height - 5),
  674. (position[0] + text_width + 10, position[1] - text_height - 5 + text_height + 10 + self.tf),
  675. color,
  676. -1,
  677. )
  678. # Draw text
  679. cv2.putText(self.im, display_text, position, 0, self.sf, txt_color, self.tf)
  680. return text_height
  681. def plot_angle_and_count_and_stage(
  682. self, angle_text, count_text, stage_text, center_kpt, color=(104, 31, 17), txt_color=(255, 255, 255)
  683. ):
  684. """
  685. Plot the pose angle, count value, and step stage.
  686. Args:
  687. angle_text (str): Angle value for workout monitoring
  688. count_text (str): Counts value for workout monitoring
  689. stage_text (str): Stage decision for workout monitoring
  690. center_kpt (list): Centroid pose index for workout monitoring
  691. color (tuple, optional): Text background color
  692. txt_color (tuple, optional): Text foreground color
  693. """
  694. # Format text
  695. angle_text, count_text, stage_text = f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}"
  696. # Draw angle, count and stage text
  697. angle_height = self.plot_workout_information(
  698. angle_text, (int(center_kpt[0]), int(center_kpt[1])), color, txt_color
  699. )
  700. count_height = self.plot_workout_information(
  701. count_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + 20), color, txt_color
  702. )
  703. self.plot_workout_information(
  704. stage_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + count_height + 40), color, txt_color
  705. )
  706. def seg_bbox(self, mask, mask_color=(255, 0, 255), label=None, txt_color=(255, 255, 255)):
  707. """
  708. Function for drawing segmented object in bounding box shape.
  709. Args:
  710. mask (np.ndarray): A 2D array of shape (N, 2) containing the contour points of the segmented object.
  711. mask_color (tuple): RGB color for the contour and label background.
  712. label (str, optional): Text label for the object. If None, no label is drawn.
  713. txt_color (tuple): RGB color for the label text.
  714. """
  715. if mask.size == 0: # no masks to plot
  716. return
  717. cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
  718. text_size, _ = cv2.getTextSize(label, 0, self.sf, self.tf)
  719. if label:
  720. cv2.rectangle(
  721. self.im,
  722. (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
  723. (int(mask[0][0]) + text_size[0] // 2 + 10, int(mask[0][1] + 10)),
  724. mask_color,
  725. -1,
  726. )
  727. cv2.putText(
  728. self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1])), 0, self.sf, txt_color, self.tf
  729. )
  730. def sweep_annotator(self, line_x=0, line_y=0, label=None, color=(221, 0, 186), txt_color=(255, 255, 255)):
  731. """
  732. Function for drawing a sweep annotation line and an optional label.
  733. Args:
  734. line_x (int): The x-coordinate of the sweep line.
  735. line_y (int): The y-coordinate limit of the sweep line.
  736. label (str, optional): Text label to be drawn in center of sweep line. If None, no label is drawn.
  737. color (tuple): RGB color for the line and label background.
  738. txt_color (tuple): RGB color for the label text.
  739. """
  740. # Draw the sweep line
  741. cv2.line(self.im, (line_x, 0), (line_x, line_y), color, self.tf * 2)
  742. # Draw label, if provided
  743. if label:
  744. (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf, self.tf)
  745. cv2.rectangle(
  746. self.im,
  747. (line_x - text_width // 2 - 10, line_y // 2 - text_height // 2 - 10),
  748. (line_x + text_width // 2 + 10, line_y // 2 + text_height // 2 + 10),
  749. color,
  750. -1,
  751. )
  752. cv2.putText(
  753. self.im,
  754. label,
  755. (line_x - text_width // 2, line_y // 2 + text_height // 2),
  756. cv2.FONT_HERSHEY_SIMPLEX,
  757. self.sf,
  758. txt_color,
  759. self.tf,
  760. )
  761. def plot_distance_and_line(
  762. self, pixels_distance, centroids, line_color=(104, 31, 17), centroid_color=(255, 0, 255)
  763. ):
  764. """
  765. Plot the distance and line on frame.
  766. Args:
  767. pixels_distance (float): Pixels distance between two bbox centroids.
  768. centroids (list): Bounding box centroids data.
  769. line_color (tuple, optional): Distance line color.
  770. centroid_color (tuple, optional): Bounding box centroid color.
  771. """
  772. # Get the text size
  773. text = f"Pixels Distance: {pixels_distance:.2f}"
  774. (text_width_m, text_height_m), _ = cv2.getTextSize(text, 0, self.sf, self.tf)
  775. # Define corners with 10-pixel margin and draw rectangle
  776. cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 20, 25 + text_height_m + 20), line_color, -1)
  777. # Calculate the position for the text with a 10-pixel margin and draw text
  778. text_position = (25, 25 + text_height_m + 10)
  779. cv2.putText(
  780. self.im,
  781. text,
  782. text_position,
  783. 0,
  784. self.sf,
  785. (255, 255, 255),
  786. self.tf,
  787. cv2.LINE_AA,
  788. )
  789. cv2.line(self.im, centroids[0], centroids[1], line_color, 3)
  790. cv2.circle(self.im, centroids[0], 6, centroid_color, -1)
  791. cv2.circle(self.im, centroids[1], 6, centroid_color, -1)
  792. def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255)):
  793. """
  794. Function for pinpoint human-vision eye mapping and plotting.
  795. Args:
  796. box (list): Bounding box coordinates
  797. center_point (tuple): center point for vision eye view
  798. color (tuple): object centroid and line color value
  799. pin_color (tuple): visioneye point color value
  800. """
  801. center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
  802. cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)
  803. cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)
  804. cv2.line(self.im, center_point, center_bbox, color, self.tf)
  805. @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
  806. @plt_settings()
  807. def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
  808. """Plot training labels including class histograms and box statistics."""
  809. import pandas # scope for faster 'import ultralytics'
  810. import seaborn # scope for faster 'import ultralytics'
  811. # Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
  812. warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
  813. warnings.filterwarnings("ignore", category=FutureWarning)
  814. # Plot dataset labels
  815. LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
  816. nc = int(cls.max() + 1) # number of classes
  817. boxes = boxes[:1000000] # limit to 1M boxes
  818. x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])
  819. # Seaborn correlogram
  820. seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
  821. plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
  822. plt.close()
  823. # Matplotlib labels
  824. ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
  825. y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  826. for i in range(nc):
  827. y[2].patches[i].set_color([x / 255 for x in colors(i)])
  828. ax[0].set_ylabel("instances")
  829. if 0 < len(names) < 30:
  830. ax[0].set_xticks(range(len(names)))
  831. ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
  832. else:
  833. ax[0].set_xlabel("classes")
  834. seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
  835. seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
  836. # Rectangles
  837. boxes[:, 0:2] = 0.5 # center
  838. boxes = ops.xywh2xyxy(boxes) * 1000
  839. img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
  840. for cls, box in zip(cls[:500], boxes[:500]):
  841. ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
  842. ax[1].imshow(img)
  843. ax[1].axis("off")
  844. for a in [0, 1, 2, 3]:
  845. for s in ["top", "right", "left", "bottom"]:
  846. ax[a].spines[s].set_visible(False)
  847. fname = save_dir / "labels.jpg"
  848. plt.savefig(fname, dpi=200)
  849. plt.close()
  850. if on_plot:
  851. on_plot(fname)
  852. def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
  853. """
  854. Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
  855. This function takes a bounding box and an image, and then saves a cropped portion of the image according
  856. to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding
  857. adjustments to the bounding box.
  858. Args:
  859. xyxy (torch.Tensor or list): A tensor or list representing the bounding box in xyxy format.
  860. im (numpy.ndarray): The input image.
  861. file (Path, optional): The path where the cropped image will be saved. Defaults to 'im.jpg'.
  862. gain (float, optional): A multiplicative factor to increase the size of the bounding box. Defaults to 1.02.
  863. pad (int, optional): The number of pixels to add to the width and height of the bounding box. Defaults to 10.
  864. square (bool, optional): If True, the bounding box will be transformed into a square. Defaults to False.
  865. BGR (bool, optional): If True, the image will be saved in BGR format, otherwise in RGB. Defaults to False.
  866. save (bool, optional): If True, the cropped image will be saved to disk. Defaults to True.
  867. Returns:
  868. (numpy.ndarray): The cropped image.
  869. Example:
  870. ```python
  871. from ultralytics.utils.plotting import save_one_box
  872. xyxy = [50, 50, 150, 150]
  873. im = cv2.imread("image.jpg")
  874. cropped_im = save_one_box(xyxy, im, file="cropped.jpg", square=True)
  875. ```
  876. """
  877. if not isinstance(xyxy, torch.Tensor): # may be list
  878. xyxy = torch.stack(xyxy)
  879. b = ops.xyxy2xywh(xyxy.view(-1, 4)) # boxes
  880. if square:
  881. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
  882. b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
  883. xyxy = ops.xywh2xyxy(b).long()
  884. xyxy = ops.clip_boxes(xyxy, im.shape)
  885. crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]
  886. if save:
  887. file.parent.mkdir(parents=True, exist_ok=True) # make directory
  888. f = str(increment_path(file).with_suffix(".jpg"))
  889. # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
  890. Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
  891. return crop
@threaded
def plot_images(
    images: Union[torch.Tensor, np.ndarray],
    batch_idx: Union[torch.Tensor, np.ndarray],
    cls: Union[torch.Tensor, np.ndarray],
    bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
    confs: Optional[Union[torch.Tensor, np.ndarray]] = None,
    masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),
    kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),
    paths: Optional[List[str]] = None,
    fname: str = "images.jpg",
    names: Optional[Dict[int, str]] = None,
    on_plot: Optional[Callable] = None,
    max_size: int = 1920,
    max_subplots: int = 16,
    save: bool = True,
    conf_thres: float = 0.25,
) -> Optional[np.ndarray]:
    """
    Plot image grid with labels, bounding boxes, masks, and keypoints.

    The first `max_subplots` images are tiled into a square mosaic (filled column by column), the mosaic
    is downscaled when it would exceed `max_size`, and the annotations belonging to each image are drawn
    on its tile. When `confs` is None the inputs are treated as ground-truth labels and everything is
    drawn; otherwise only detections whose confidence exceeds `conf_thres` are drawn.

    Args:
        images: Batch of images to plot. Shape: (batch_size, channels, height, width).
        batch_idx: Batch indices for each detection. Shape: (num_detections,).
        cls: Class labels for each detection. Shape: (num_detections,).
        bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes.
        confs: Confidence scores for each detection. Shape: (num_detections,).
        masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width).
        kpts: Keypoints for each detection. Shape: (num_detections, 51).
        paths: List of file paths for each image in the batch.
        fname: Output filename for the plotted image grid.
        names: Dictionary mapping class indices to class names.
        on_plot: Optional callback function to be called after saving the plot.
        max_size: Maximum size of the output image grid.
        max_subplots: Maximum number of subplots in the image grid.
        save: Whether to save the plotted image grid to a file.
        conf_thres: Confidence threshold for displaying detections.

    Returns:
        np.ndarray: Plotted image grid as a numpy array if save is False, None otherwise.

    Note:
        This function supports both tensor and numpy array inputs. It will automatically
        convert tensor inputs to numpy arrays for processing.
    """
    # Convert every tensor input to a numpy array on the CPU
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(cls, torch.Tensor):
        cls = cls.cpu().numpy()
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.cpu().numpy()
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy().astype(int)
    if isinstance(kpts, torch.Tensor):
        kpts = kpts.cpu().numpy()
    if isinstance(batch_idx, torch.Tensor):
        batch_idx = batch_idx.cpu().numpy()

    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs**0.5)  # number of subplots (square)
    if np.max(images[0]) <= 1:
        # NOTE(review): this multiplies in place; for ndarray input that is the caller's array, and it
        # presumably requires a float dtype — confirm both are acceptable
        images *= 255  # de-normalise (optional)

    # Build Image
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i in range(bs):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)

    # Resize (optional)
    scale = max_size / ns / max(h, w)
    if scale < 1:
        # From here on h and w are the per-tile size in the resized mosaic
        h = math.ceil(scale * h)
        w = math.ceil(scale * w)
        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))

    # Annotate
    fs = int((h + w) * ns * 0.01)  # font size
    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
    for i in range(bs):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
        if paths:
            annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
        if len(cls) > 0:
            idx = batch_idx == i  # boolean mask selecting the detections of image i
            classes = cls[idx].astype("int")
            labels = confs is None  # True when plotting ground-truth labels rather than predictions

            if len(bboxes):
                boxes = bboxes[idx]
                conf = confs[idx] if confs is not None else None  # check for confidence presence (label vs pred)
                if len(boxes):
                    if boxes[:, :4].max() <= 1.1:  # if normalized with tolerance 0.1
                        boxes[..., [0, 2]] *= w  # scale to pixels
                        boxes[..., [1, 3]] *= h
                    elif scale < 1:  # absolute coords need scale if image scales
                        boxes[..., :4] *= scale
                # Shift the box centers into this image's tile
                boxes[..., 0] += x
                boxes[..., 1] += y
                is_obb = boxes.shape[-1] == 5  # xywhr
                boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
                for j, box in enumerate(boxes.astype(np.int64).tolist()):
                    c = classes[j]
                    color = colors(c)
                    c = names.get(c, c) if names else c  # class name when known, else the index
                    if labels or conf[j] > conf_thres:
                        label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
                        annotator.box_label(box, label, color=color, rotated=is_obb)

            elif len(classes):
                # No boxes at all: write each class label at the tile origin
                for c in classes:
                    color = colors(c)
                    c = names.get(c, c) if names else c
                    annotator.text((x, y), f"{c}", txt_color=color, box_style=True)

            # Plot keypoints
            if len(kpts):
                kpts_ = kpts[idx].copy()
                if len(kpts_):
                    if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01:  # if normalized with tolerance .01
                        kpts_[..., 0] *= w  # scale to pixels
                        kpts_[..., 1] *= h
                    elif scale < 1:  # absolute coords need scale if image scales
                        kpts_ *= scale
                kpts_[..., 0] += x
                kpts_[..., 1] += y
                for j in range(len(kpts_)):
                    # NOTE(review): `conf` is only assigned in the bbox branch above; with `confs` given
                    # but empty `bboxes` this lookup looks like it would fail — confirm callers never do that
                    if labels or conf[j] > conf_thres:
                        annotator.kpts(kpts_[j], conf_thres=conf_thres)

            # Plot masks
            if len(masks):
                if idx.shape[0] == masks.shape[0]:  # overlap_masks=False
                    image_masks = masks[idx]
                else:  # overlap_masks=True
                    # One mask per image whose pixel values are compared against 1..nl to split it
                    # into nl binary masks
                    image_masks = masks[[i]]  # (1, 640, 640)
                    nl = idx.sum()
                    index = np.arange(nl).reshape((nl, 1, 1)) + 1
                    image_masks = np.repeat(image_masks, nl, axis=0)
                    image_masks = np.where(image_masks == index, 1.0, 0.0)

                im = np.asarray(annotator.im).copy()
                for j in range(len(image_masks)):
                    if labels or conf[j] > conf_thres:  # NOTE(review): same `conf` caveat as for keypoints
                        color = colors(classes[j])
                        mh, mw = image_masks[j].shape
                        if mh != h or mw != w:
                            mask = image_masks[j].astype(np.uint8)
                            mask = cv2.resize(mask, (w, h))
                            mask = mask.astype(bool)
                        else:
                            mask = image_masks[j].astype(bool)
                        try:
                            # Blend 40% of the image with 60% of the class color under the mask
                            im[y : y + h, x : x + w, :][mask] = (
                                im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
                            )
                        except Exception:
                            # Best effort: any failure while blending this mask is ignored
                            pass
                annotator.fromarray(im)
    if not save:
        return np.asarray(annotator.im)
    annotator.im.save(fname)  # save
    if on_plot:
        on_plot(fname)
  1046. @plt_settings()
  1047. def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
  1048. """
  1049. Plot training results from a results CSV file. The function supports various types of data including segmentation,
  1050. pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
  1051. Args:
  1052. file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'.
  1053. dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.
  1054. segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.
  1055. pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.
  1056. classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.
  1057. on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
  1058. Defaults to None.
  1059. Example:
  1060. ```python
  1061. from ultralytics.utils.plotting import plot_results
  1062. plot_results("path/to/results.csv", segment=True)
  1063. ```
  1064. """
  1065. import pandas as pd # scope for faster 'import ultralytics'
  1066. from scipy.ndimage import gaussian_filter1d
  1067. save_dir = Path(file).parent if file else Path(dir)
  1068. if classify:
  1069. fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
  1070. index = [2, 5, 3, 4]
  1071. elif segment:
  1072. fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
  1073. index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]
  1074. elif pose:
  1075. fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
  1076. index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]
  1077. else:
  1078. fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
  1079. index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]
  1080. ax = ax.ravel()
  1081. files = list(save_dir.glob("results*.csv"))
  1082. assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
  1083. for f in files:
  1084. try:
  1085. data = pd.read_csv(f)
  1086. s = [x.strip() for x in data.columns]
  1087. x = data.values[:, 0]
  1088. for i, j in enumerate(index):
  1089. y = data.values[:, j].astype("float")
  1090. # y[y == 0] = np.nan # don't show zero values
  1091. ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
  1092. ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
  1093. ax[i].set_title(s[j], fontsize=12)
  1094. # if j in {8, 9, 10}: # share train and val loss y axes
  1095. # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
  1096. except Exception as e:
  1097. LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
  1098. ax[1].legend()
  1099. fname = save_dir / "results.png"
  1100. fig.savefig(fname, dpi=200)
  1101. plt.close()
  1102. if on_plot:
  1103. on_plot(fname)
  1104. def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
  1105. """
  1106. Plots a scatter plot with points colored based on a 2D histogram.
  1107. Args:
  1108. v (array-like): Values for the x-axis.
  1109. f (array-like): Values for the y-axis.
  1110. bins (int, optional): Number of bins for the histogram. Defaults to 20.
  1111. cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'.
  1112. alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8.
  1113. edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'.
  1114. Examples:
  1115. >>> v = np.random.rand(100)
  1116. >>> f = np.random.rand(100)
  1117. >>> plt_color_scatter(v, f)
  1118. """
  1119. # Calculate 2D histogram and corresponding colors
  1120. hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
  1121. colors = [
  1122. hist[
  1123. min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
  1124. min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1),
  1125. ]
  1126. for i in range(len(v))
  1127. ]
  1128. # Scatter plot
  1129. plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
  1130. def plot_tune_results(csv_file="tune_results.csv"):
  1131. """
  1132. Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
  1133. in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
  1134. Args:
  1135. csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.
  1136. Examples:
  1137. >>> plot_tune_results("path/to/tune_results.csv")
  1138. """
  1139. import pandas as pd # scope for faster 'import ultralytics'
  1140. from scipy.ndimage import gaussian_filter1d
  1141. def _save_one_file(file):
  1142. """Save one matplotlib plot to 'file'."""
  1143. plt.savefig(file, dpi=200)
  1144. plt.close()
  1145. LOGGER.info(f"Saved {file}")
  1146. # Scatter plots for each hyperparameter
  1147. csv_file = Path(csv_file)
  1148. data = pd.read_csv(csv_file)
  1149. num_metrics_columns = 1
  1150. keys = [x.strip() for x in data.columns][num_metrics_columns:]
  1151. x = data.values
  1152. fitness = x[:, 0] # fitness
  1153. j = np.argmax(fitness) # max fitness index
  1154. n = math.ceil(len(keys) ** 0.5) # columns and rows in plot
  1155. plt.figure(figsize=(10, 10), tight_layout=True)
  1156. for i, k in enumerate(keys):
  1157. v = x[:, i + num_metrics_columns]
  1158. mu = v[j] # best single result
  1159. plt.subplot(n, n, i + 1)
  1160. plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none")
  1161. plt.plot(mu, fitness.max(), "k+", markersize=15)
  1162. plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9}) # limit to 40 characters
  1163. plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8
  1164. if i % n != 0:
  1165. plt.yticks([])
  1166. _save_one_file(csv_file.with_name("tune_scatter_plots.png"))
  1167. # Fitness vs iteration
  1168. x = range(1, len(fitness) + 1)
  1169. plt.figure(figsize=(10, 6), tight_layout=True)
  1170. plt.plot(x, fitness, marker="o", linestyle="none", label="fitness")
  1171. plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2) # smoothing line
  1172. plt.title("Fitness vs Iteration")
  1173. plt.xlabel("Iteration")
  1174. plt.ylabel("Fitness")
  1175. plt.grid(True)
  1176. plt.legend()
  1177. _save_one_file(csv_file.with_name("tune_fitness.png"))
  1178. def output_to_target(output, max_det=300):
  1179. """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
  1180. targets = []
  1181. for i, o in enumerate(output):
  1182. box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
  1183. j = torch.full((conf.shape[0], 1), i)
  1184. targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
  1185. targets = torch.cat(targets, 0).numpy()
  1186. return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
  1187. def output_to_rotated_target(output, max_det=300):
  1188. """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
  1189. targets = []
  1190. for i, o in enumerate(output):
  1191. box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1)
  1192. j = torch.full((conf.shape[0], 1), i)
  1193. targets.append(torch.cat((j, cls, box, angle, conf), 1))
  1194. targets = torch.cat(targets, 0).numpy()
  1195. return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
  1196. def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
  1197. """
  1198. Visualize feature maps of a given model module during inference.
  1199. Args:
  1200. x (torch.Tensor): Features to be visualized.
  1201. module_type (str): Module type.
  1202. stage (int): Module stage within the model.
  1203. n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
  1204. save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
  1205. """
  1206. for m in {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}: # all model heads
  1207. if m in module_type:
  1208. return
  1209. if isinstance(x, torch.Tensor):
  1210. _, channels, height, width = x.shape # batch, channels, height, width
  1211. if height > 1 and width > 1:
  1212. f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
  1213. blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
  1214. n = min(n, channels) # number of plots
  1215. _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
  1216. ax = ax.ravel()
  1217. plt.subplots_adjust(wspace=0.05, hspace=0.05)
  1218. for i in range(n):
  1219. ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
  1220. ax[i].axis("off")
  1221. LOGGER.info(f"Saving {f}... ({n}/{channels})")
  1222. plt.savefig(f, dpi=300, bbox_inches="tight")
  1223. plt.close()
  1224. np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save