analytics.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. from itertools import cycle
  3. import cv2
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
  7. from matplotlib.figure import Figure
  8. from ultralytics.solutions.solutions import BaseSolution # Import a parent class
  9. class Analytics(BaseSolution):
  10. """
  11. A class for creating and updating various types of charts for visual analytics.
  12. This class extends BaseSolution to provide functionality for generating line, bar, pie, and area charts
  13. based on object detection and tracking data.
  14. Attributes:
  15. type (str): The type of analytics chart to generate ('line', 'bar', 'pie', or 'area').
  16. x_label (str): Label for the x-axis.
  17. y_label (str): Label for the y-axis.
  18. bg_color (str): Background color of the chart frame.
  19. fg_color (str): Foreground color of the chart frame.
  20. title (str): Title of the chart window.
  21. max_points (int): Maximum number of data points to display on the chart.
  22. fontsize (int): Font size for text display.
  23. color_cycle (cycle): Cyclic iterator for chart colors.
  24. total_counts (int): Total count of detected objects (used for line charts).
  25. clswise_count (Dict[str, int]): Dictionary for class-wise object counts.
  26. fig (Figure): Matplotlib figure object for the chart.
  27. ax (Axes): Matplotlib axes object for the chart.
  28. canvas (FigureCanvas): Canvas for rendering the chart.
  29. Methods:
  30. process_data: Processes image data and updates the chart.
  31. update_graph: Updates the chart with new data points.
  32. Examples:
  33. >>> analytics = Analytics(analytics_type="line")
  34. >>> frame = cv2.imread("image.jpg")
  35. >>> processed_frame = analytics.process_data(frame, frame_number=1)
  36. >>> cv2.imshow("Analytics", processed_frame)
  37. """
  38. def __init__(self, **kwargs):
  39. """Initialize Analytics class with various chart types for visual data representation."""
  40. super().__init__(**kwargs)
  41. self.type = self.CFG["analytics_type"] # extract type of analytics
  42. self.x_label = "Classes" if self.type in {"bar", "pie"} else "Frame#"
  43. self.y_label = "Total Counts"
  44. # Predefined data
  45. self.bg_color = "#F3F3F3" # background color of frame
  46. self.fg_color = "#111E68" # foreground color of frame
  47. self.title = "Ultralytics Solutions" # window name
  48. self.max_points = 45 # maximum points to be drawn on window
  49. self.fontsize = 25 # text font size for display
  50. figsize = (19.2, 10.8) # Set output image size 1920 * 1080
  51. self.color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"])
  52. self.total_counts = 0 # count variable for storing total counts i.e. for line
  53. self.clswise_count = {} # dictionary for class-wise counts
  54. # Ensure line and area chart
  55. if self.type in {"line", "area"}:
  56. self.lines = {}
  57. self.fig = Figure(facecolor=self.bg_color, figsize=figsize)
  58. self.canvas = FigureCanvas(self.fig) # Set common axis properties
  59. self.ax = self.fig.add_subplot(111, facecolor=self.bg_color)
  60. if self.type == "line":
  61. (self.line,) = self.ax.plot([], [], color="cyan", linewidth=self.line_width)
  62. elif self.type in {"bar", "pie"}:
  63. # Initialize bar or pie plot
  64. self.fig, self.ax = plt.subplots(figsize=figsize, facecolor=self.bg_color)
  65. self.canvas = FigureCanvas(self.fig) # Set common axis properties
  66. self.ax.set_facecolor(self.bg_color)
  67. self.color_mapping = {}
  68. if self.type == "pie": # Ensure pie chart is circular
  69. self.ax.axis("equal")
  70. def process_data(self, im0, frame_number):
  71. """
  72. Processes image data and runs object tracking to update analytics charts.
  73. Args:
  74. im0 (np.ndarray): Input image for processing.
  75. frame_number (int): Video frame number for plotting the data.
  76. Returns:
  77. (np.ndarray): Processed image with updated analytics chart.
  78. Raises:
  79. ModuleNotFoundError: If an unsupported chart type is specified.
  80. Examples:
  81. >>> analytics = Analytics(analytics_type="line")
  82. >>> frame = np.zeros((480, 640, 3), dtype=np.uint8)
  83. >>> processed_frame = analytics.process_data(frame, frame_number=1)
  84. """
  85. self.extract_tracks(im0) # Extract tracks
  86. if self.type == "line":
  87. for _ in self.boxes:
  88. self.total_counts += 1
  89. im0 = self.update_graph(frame_number=frame_number)
  90. self.total_counts = 0
  91. elif self.type in {"pie", "bar", "area"}:
  92. self.clswise_count = {}
  93. for box, cls in zip(self.boxes, self.clss):
  94. if self.names[int(cls)] in self.clswise_count:
  95. self.clswise_count[self.names[int(cls)]] += 1
  96. else:
  97. self.clswise_count[self.names[int(cls)]] = 1
  98. im0 = self.update_graph(frame_number=frame_number, count_dict=self.clswise_count, plot=self.type)
  99. else:
  100. raise ModuleNotFoundError(f"{self.type} chart is not supported ❌")
  101. return im0
  102. def update_graph(self, frame_number, count_dict=None, plot="line"):
  103. """
  104. Updates the graph with new data for single or multiple classes.
  105. Args:
  106. frame_number (int): The current frame number.
  107. count_dict (Dict[str, int] | None): Dictionary with class names as keys and counts as values for multiple
  108. classes. If None, updates a single line graph.
  109. plot (str): Type of the plot. Options are 'line', 'bar', 'pie', or 'area'.
  110. Returns:
  111. (np.ndarray): Updated image containing the graph.
  112. Examples:
  113. >>> analytics = Analytics()
  114. >>> frame_number = 10
  115. >>> count_dict = {"person": 5, "car": 3}
  116. >>> updated_image = analytics.update_graph(frame_number, count_dict, plot="bar")
  117. """
  118. if count_dict is None:
  119. # Single line update
  120. x_data = np.append(self.line.get_xdata(), float(frame_number))
  121. y_data = np.append(self.line.get_ydata(), float(self.total_counts))
  122. if len(x_data) > self.max_points:
  123. x_data, y_data = x_data[-self.max_points :], y_data[-self.max_points :]
  124. self.line.set_data(x_data, y_data)
  125. self.line.set_label("Counts")
  126. self.line.set_color("#7b0068") # Pink color
  127. self.line.set_marker("*")
  128. self.line.set_markersize(self.line_width * 5)
  129. else:
  130. labels = list(count_dict.keys())
  131. counts = list(count_dict.values())
  132. if plot == "area":
  133. color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"])
  134. # Multiple lines or area update
  135. x_data = self.ax.lines[0].get_xdata() if self.ax.lines else np.array([])
  136. y_data_dict = {key: np.array([]) for key in count_dict.keys()}
  137. if self.ax.lines:
  138. for line, key in zip(self.ax.lines, count_dict.keys()):
  139. y_data_dict[key] = line.get_ydata()
  140. x_data = np.append(x_data, float(frame_number))
  141. max_length = len(x_data)
  142. for key in count_dict.keys():
  143. y_data_dict[key] = np.append(y_data_dict[key], float(count_dict[key]))
  144. if len(y_data_dict[key]) < max_length:
  145. y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key])))
  146. if len(x_data) > self.max_points:
  147. x_data = x_data[1:]
  148. for key in count_dict.keys():
  149. y_data_dict[key] = y_data_dict[key][1:]
  150. self.ax.clear()
  151. for key, y_data in y_data_dict.items():
  152. color = next(color_cycle)
  153. self.ax.fill_between(x_data, y_data, color=color, alpha=0.7)
  154. self.ax.plot(
  155. x_data,
  156. y_data,
  157. color=color,
  158. linewidth=self.line_width,
  159. marker="o",
  160. markersize=self.line_width * 5,
  161. label=f"{key} Data Points",
  162. )
  163. if plot == "bar":
  164. self.ax.clear() # clear bar data
  165. for label in labels: # Map labels to colors
  166. if label not in self.color_mapping:
  167. self.color_mapping[label] = next(self.color_cycle)
  168. colors = [self.color_mapping[label] for label in labels]
  169. bars = self.ax.bar(labels, counts, color=colors)
  170. for bar, count in zip(bars, counts):
  171. self.ax.text(
  172. bar.get_x() + bar.get_width() / 2,
  173. bar.get_height(),
  174. str(count),
  175. ha="center",
  176. va="bottom",
  177. color=self.fg_color,
  178. )
  179. # Create the legend using labels from the bars
  180. for bar, label in zip(bars, labels):
  181. bar.set_label(label) # Assign label to each bar
  182. self.ax.legend(loc="upper left", fontsize=13, facecolor=self.fg_color, edgecolor=self.fg_color)
  183. if plot == "pie":
  184. total = sum(counts)
  185. percentages = [size / total * 100 for size in counts]
  186. start_angle = 90
  187. self.ax.clear()
  188. # Create pie chart and create legend labels with percentages
  189. wedges, autotexts = self.ax.pie(
  190. counts, labels=labels, startangle=start_angle, textprops={"color": self.fg_color}, autopct=None
  191. )
  192. legend_labels = [f"{label} ({percentage:.1f}%)" for label, percentage in zip(labels, percentages)]
  193. # Assign the legend using the wedges and manually created labels
  194. self.ax.legend(wedges, legend_labels, title="Classes", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1))
  195. self.fig.subplots_adjust(left=0.1, right=0.75) # Adjust layout to fit the legend
  196. # Common plot settings
  197. self.ax.set_facecolor("#f0f0f0") # Set to light gray or any other color you like
  198. self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize)
  199. self.ax.set_xlabel(self.x_label, color=self.fg_color, fontsize=self.fontsize - 3)
  200. self.ax.set_ylabel(self.y_label, color=self.fg_color, fontsize=self.fontsize - 3)
  201. # Add and format legend
  202. legend = self.ax.legend(loc="upper left", fontsize=13, facecolor=self.bg_color, edgecolor=self.bg_color)
  203. for text in legend.get_texts():
  204. text.set_color(self.fg_color)
  205. # Redraw graph, update view, capture, and display the updated plot
  206. self.ax.relim()
  207. self.ax.autoscale_view()
  208. self.canvas.draw()
  209. im0 = np.array(self.canvas.renderer.buffer_rgba())
  210. im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)
  211. self.display_output(im0)
  212. return im0 # Return the image