123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
- import io
- from typing import Any
- import cv2
- from ultralytics import YOLO
- from ultralytics.utils import LOGGER
- from ultralytics.utils.checks import check_requirements
- from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
class Inference:
    """
    A class to perform object detection, image classification, image segmentation and pose estimation inference using
    Streamlit and Ultralytics YOLO models. It provides the functionalities such as loading models, configuring settings,
    uploading video files, and performing real-time inference.

    Attributes:
        st (module): Streamlit module for UI creation.
        temp_dict (dict): Constructor keyword arguments, with "model" defaulting to None.
        model_path (str | None): User-provided model path or name, or None when not given.
        model (YOLO): The YOLO model instance.
        source (str): Selected video source ("webcam" or "video").
        enable_trk (str | bool): Tracking choice from the sidebar ("Yes" or "No"); initialized to False.
        conf (float): Confidence threshold.
        iou (float): IoU threshold for non-max suppression.
        org_frame: Streamlit placeholder (from `empty()`) that shows the original frame.
        ann_frame: Streamlit placeholder (from `empty()`) that shows the annotated frame.
        vid_file_name (str | int): Saved upload path, 0 for the webcam, or "" when no input is available.
        selected_ind (list): List of selected class indices.

    Methods:
        web_ui: Sets up the Streamlit web interface with custom HTML elements.
        sidebar: Configures the Streamlit sidebar for model and inference settings.
        source_upload: Handles video file uploads through the Streamlit interface.
        configure: Configures the model and loads selected classes for inference.
        inference: Performs real-time object detection inference.

    Examples:
        >>> inf = solutions.Inference(model="path/to/model.pt")  # Model is not necessary argument.
        >>> inf.inference()
    """

    def __init__(self, **kwargs: Any):
        """
        Initializes the Inference class, checking Streamlit requirements and setting up the model path.

        Args:
            **kwargs (Any): Additional keyword arguments for model configuration. The optional ``model`` entry is a
                custom model path or name offered first in the model dropdown.
        """
        check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
        import streamlit as st

        self.st = st  # Reference to the Streamlit module
        self.source = None  # Placeholder for video or webcam source details
        self.enable_trk = False  # Flag to toggle object tracking
        self.conf = 0.25  # Confidence threshold for detection
        self.iou = 0.45  # Intersection-over-Union (IoU) threshold for non-maximum suppression
        self.org_frame = None  # Container for the original frame to be displayed
        self.ann_frame = None  # Container for the annotated frame to be displayed
        self.vid_file_name = None  # Holds the name of the video file
        self.selected_ind = []  # List of selected classes for detection or tracking
        self.model = None  # Container for the loaded model instance

        self.temp_dict = {"model": None, **kwargs}
        self.model_path = self.temp_dict["model"]  # Custom model path/name, or None when not provided

        LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")

    def web_ui(self):
        """Sets up the Streamlit web interface with custom HTML elements."""
        menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""  # Hide main menu style

        # Main title of streamlit application
        main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px; margin-top:-50px;
        font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""

        # Subtitle of streamlit application
        sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
        margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam with the power
        of Ultralytics YOLO! 🚀</h4></div>"""

        # Set html page configuration and append custom HTML
        self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
        self.st.markdown(menu_style_cfg, unsafe_allow_html=True)
        self.st.markdown(main_title_cfg, unsafe_allow_html=True)
        self.st.markdown(sub_title_cfg, unsafe_allow_html=True)

    def sidebar(self):
        """Configures the Streamlit sidebar for model and inference settings."""
        with self.st.sidebar:  # Add Ultralytics LOGO
            logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
            self.st.image(logo, width=250)

        self.st.sidebar.title("User Configuration")  # Add elements to vertical setting menu
        self.source = self.st.sidebar.selectbox(
            "Video",
            ("webcam", "video"),
        )  # Add source selection dropdown
        self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No"))  # Enable object tracking
        self.conf = float(
            self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
        )  # Slider for confidence
        self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01))  # Slider for NMS threshold

        # Two side-by-side placeholders: original frame (left) and annotated frame (right)
        col1, col2 = self.st.columns(2)
        self.org_frame = col1.empty()
        self.ann_frame = col2.empty()

    def source_upload(self):
        """
        Handles video file uploads through the Streamlit interface.

        Sets ``vid_file_name`` as follows:
            - "video" source with an upload: the saved path "ultralytics.mp4".
            - "webcam" source: 0, the camera index passed to ``cv2.VideoCapture``.
            - Otherwise: "".
        """
        self.vid_file_name = ""
        if self.source == "video":
            vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
            if vid_file is not None:
                # Persist the upload to disk: cv2.VideoCapture is given a file path below
                with open("ultralytics.mp4", "wb") as out:
                    out.write(vid_file.read())
                self.vid_file_name = "ultralytics.mp4"
        elif self.source == "webcam":
            self.vid_file_name = 0

    def _resolve_weights(self, selected_model: str) -> str:
        """
        Map the entry chosen in the "Model" dropdown to the weights argument passed to YOLO.

        Built-in entries (e.g. "YOLO11n") are lower-cased and given a ".pt" suffix. A user-provided model that already
        has a file extension is returned verbatim, so its case and format (e.g. ".onnx") are preserved. A user model
        without an extension keeps the previous behavior of lower-casing and appending ".pt".

        Args:
            selected_model (str): Entry selected in the "Model" dropdown.

        Returns:
            (str): Weights path or name to load with YOLO.
        """
        from pathlib import Path

        custom = str(self.model_path) if self.model_path else None
        if selected_model == custom and Path(custom).suffix:
            return custom
        return f"{selected_model.lower()}.pt"

    def configure(self):
        """Configures the model and loads selected classes for inference."""
        # Built-in models are shown as e.g. "YOLO11n"; a user-provided model (if any) is offered first, verbatim
        available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
        if self.model_path:
            available_models.insert(0, str(self.model_path))
        selected_model = self.st.sidebar.selectbox("Model", available_models)

        with self.st.spinner("Model is downloading..."):
            self.model = YOLO(self._resolve_weights(selected_model))  # Load the YOLO model
            class_names = list(self.model.names.values())  # Convert dictionary to list of class names
        self.st.success("Model loaded successfully!")

        # Multiselect box with class names and get indices of selected classes
        selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
        self.selected_ind = [class_names.index(option) for option in selected_classes]

    def inference(self):
        """
        Performs real-time object detection inference.

        Builds the UI and, once "Start" is pressed, reads frames from the selected source. Each frame is passed to
        ``track`` (when tracking is enabled) or to plain prediction, using the configured thresholds and classes. The
        original and annotated frames are shown side by side. The capture is released even if inference raises.
        """
        self.web_ui()  # Initialize the web interface
        self.sidebar()  # Create the sidebar
        self.source_upload()  # Upload the video source
        self.configure()  # Configure the app

        if not self.st.sidebar.button("Start"):
            return

        if self.vid_file_name == "":
            # "video" source selected but nothing uploaded; cv2.VideoCapture("") could only fail to open
            self.st.warning("Please upload a video file before starting inference.")
            return

        is_webcam = self.source == "webcam"
        stop_button = self.st.button("Stop")  # Button to stop the inference
        cap = cv2.VideoCapture(self.vid_file_name)  # Capture the video
        try:
            if not cap.isOpened():
                self.st.error("Could not open webcam." if is_webcam else "Could not open video file.")

            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    if is_webcam:
                        self.st.warning(
                            "Failed to read frame from webcam. Please verify the webcam is connected properly."
                        )
                    break  # For a video file this usually means the end of the stream

                # Store model predictions
                if self.enable_trk == "Yes":
                    results = self.model.track(
                        frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
                    )
                else:
                    results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)
                annotated_frame = results[0].plot()  # Add annotations on frame

                if stop_button:
                    cap.release()  # Release the capture (also ends the loop condition)
                    self.st.stop()  # Stop streamlit app

                self.org_frame.image(frame, channels="BGR")  # Display original frame
                self.ann_frame.image(annotated_frame, channels="BGR")  # Display processed frame
        finally:
            cap.release()  # Release the capture on every exit path; releasing twice is harmless
            cv2.destroyAllWindows()  # Destroy window
if __name__ == "__main__":
    import sys  # scoped import: only the script entry point needs command-line access

    # The optional first command-line argument is the custom model path/name to offer in the dropdown
    cli_args = sys.argv[1:]
    if cli_args:
        model = cli_args[0]
    else:
        model = None

    Inference(model=model).inference()  # build the Streamlit UI and run inference
|