streamlit_inference.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import io
  3. from typing import Any
  4. import cv2
  5. from ultralytics import YOLO
  6. from ultralytics.utils import LOGGER
  7. from ultralytics.utils.checks import check_requirements
  8. from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
  9. class Inference:
  10. """
  11. A class to perform object detection, image classification, image segmentation and pose estimation inference using
  12. Streamlit and Ultralytics YOLO models. It provides the functionalities such as loading models, configuring settings,
  13. uploading video files, and performing real-time inference.
  14. Attributes:
  15. st (module): Streamlit module for UI creation.
  16. temp_dict (dict): Temporary dictionary to store the model path.
  17. model_path (str): Path to the loaded model.
  18. model (YOLO): The YOLO model instance.
  19. source (str): Selected video source.
  20. enable_trk (str): Enable tracking option.
  21. conf (float): Confidence threshold.
  22. iou (float): IoU threshold for non-max suppression.
  23. vid_file_name (str): Name of the uploaded video file.
  24. selected_ind (list): List of selected class indices.
  25. Methods:
  26. web_ui: Sets up the Streamlit web interface with custom HTML elements.
  27. sidebar: Configures the Streamlit sidebar for model and inference settings.
  28. source_upload: Handles video file uploads through the Streamlit interface.
  29. configure: Configures the model and loads selected classes for inference.
  30. inference: Performs real-time object detection inference.
  31. Examples:
  32. >>> inf = solutions.Inference(model="path/to/model.pt") # Model is not necessary argument.
  33. >>> inf.inference()
  34. """
  35. def __init__(self, **kwargs: Any):
  36. """
  37. Initializes the Inference class, checking Streamlit requirements and setting up the model path.
  38. Args:
  39. **kwargs (Any): Additional keyword arguments for model configuration.
  40. """
  41. check_requirements("streamlit>=1.29.0") # scope imports for faster ultralytics package load speeds
  42. import streamlit as st
  43. self.st = st # Reference to the Streamlit class instance
  44. self.source = None # Placeholder for video or webcam source details
  45. self.enable_trk = False # Flag to toggle object tracking
  46. self.conf = 0.25 # Confidence threshold for detection
  47. self.iou = 0.45 # Intersection-over-Union (IoU) threshold for non-maximum suppression
  48. self.org_frame = None # Container for the original frame to be displayed
  49. self.ann_frame = None # Container for the annotated frame to be displayed
  50. self.vid_file_name = None # Holds the name of the video file
  51. self.selected_ind = [] # List of selected classes for detection or tracking
  52. self.model = None # Container for the loaded model instance
  53. self.temp_dict = {"model": None, **kwargs}
  54. self.model_path = None # Store model file name with path
  55. if self.temp_dict["model"] is not None:
  56. self.model_path = self.temp_dict["model"]
  57. LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")
  58. def web_ui(self):
  59. """Sets up the Streamlit web interface with custom HTML elements."""
  60. menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>""" # Hide main menu style
  61. # Main title of streamlit application
  62. main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px; margin-top:-50px;
  63. font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""
  64. # Subtitle of streamlit application
  65. sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
  66. margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam with the power
  67. of Ultralytics YOLO! 🚀</h4></div>"""
  68. # Set html page configuration and append custom HTML
  69. self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
  70. self.st.markdown(menu_style_cfg, unsafe_allow_html=True)
  71. self.st.markdown(main_title_cfg, unsafe_allow_html=True)
  72. self.st.markdown(sub_title_cfg, unsafe_allow_html=True)
  73. def sidebar(self):
  74. """Configures the Streamlit sidebar for model and inference settings."""
  75. with self.st.sidebar: # Add Ultralytics LOGO
  76. logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
  77. self.st.image(logo, width=250)
  78. self.st.sidebar.title("User Configuration") # Add elements to vertical setting menu
  79. self.source = self.st.sidebar.selectbox(
  80. "Video",
  81. ("webcam", "video"),
  82. ) # Add source selection dropdown
  83. self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) # Enable object tracking
  84. self.conf = float(
  85. self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
  86. ) # Slider for confidence
  87. self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01)) # Slider for NMS threshold
  88. col1, col2 = self.st.columns(2)
  89. self.org_frame = col1.empty()
  90. self.ann_frame = col2.empty()
  91. def source_upload(self):
  92. """Handles video file uploads through the Streamlit interface."""
  93. self.vid_file_name = ""
  94. if self.source == "video":
  95. vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
  96. if vid_file is not None:
  97. g = io.BytesIO(vid_file.read()) # BytesIO Object
  98. with open("ultralytics.mp4", "wb") as out: # Open temporary file as bytes
  99. out.write(g.read()) # Read bytes into file
  100. self.vid_file_name = "ultralytics.mp4"
  101. elif self.source == "webcam":
  102. self.vid_file_name = 0
  103. def configure(self):
  104. """Configures the model and loads selected classes for inference."""
  105. # Add dropdown menu for model selection
  106. available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
  107. if self.model_path: # If user provided the custom model, insert model without suffix as *.pt is added later
  108. available_models.insert(0, self.model_path.split(".pt")[0])
  109. selected_model = self.st.sidebar.selectbox("Model", available_models)
  110. with self.st.spinner("Model is downloading..."):
  111. self.model = YOLO(f"{selected_model.lower()}.pt") # Load the YOLO model
  112. class_names = list(self.model.names.values()) # Convert dictionary to list of class names
  113. self.st.success("Model loaded successfully!")
  114. # Multiselect box with class names and get indices of selected classes
  115. selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
  116. self.selected_ind = [class_names.index(option) for option in selected_classes]
  117. if not isinstance(self.selected_ind, list): # Ensure selected_options is a list
  118. self.selected_ind = list(self.selected_ind)
  119. def inference(self):
  120. """Performs real-time object detection inference."""
  121. self.web_ui() # Initialize the web interface
  122. self.sidebar() # Create the sidebar
  123. self.source_upload() # Upload the video source
  124. self.configure() # Configure the app
  125. if self.st.sidebar.button("Start"):
  126. stop_button = self.st.button("Stop") # Button to stop the inference
  127. cap = cv2.VideoCapture(self.vid_file_name) # Capture the video
  128. if not cap.isOpened():
  129. self.st.error("Could not open webcam.")
  130. while cap.isOpened():
  131. success, frame = cap.read()
  132. if not success:
  133. self.st.warning("Failed to read frame from webcam. Please verify the webcam is connected properly.")
  134. break
  135. # Store model predictions
  136. if self.enable_trk == "Yes":
  137. results = self.model.track(
  138. frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
  139. )
  140. else:
  141. results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)
  142. annotated_frame = results[0].plot() # Add annotations on frame
  143. if stop_button:
  144. cap.release() # Release the capture
  145. self.st.stop() # Stop streamlit app
  146. self.org_frame.image(frame, channels="BGR") # Display original frame
  147. self.ann_frame.image(annotated_frame, channels="BGR") # Display processed frame
  148. cap.release() # Release the capture
  149. cv2.destroyAllWindows() # Destroy window
  150. if __name__ == "__main__":
  151. import sys # Import the sys module for accessing command-line arguments
  152. # Check if a model name is provided as a command-line argument
  153. args = len(sys.argv)
  154. model = sys.argv[1] if args > 1 else None # assign first argument as the model name
  155. # Create an instance of the Inference class and run inference
  156. Inference(model=model).inference()