diff --git a/.github/workflows/integration_e2e_tests_inference_sdk_x86.yml b/.github/workflows/integration_e2e_tests_inference_sdk_x86.yml
new file mode 100644
index 0000000000..351133fddd
--- /dev/null
+++ b/.github/workflows/integration_e2e_tests_inference_sdk_x86.yml
@@ -0,0 +1,98 @@
+name: INTEGRATION & E2E TESTS - inference SDK
+permissions:
+  contents: read
+on:
+  pull_request:
+    branches: [main]
+  push:
+    branches: [main]
+  workflow_dispatch:
+
+jobs:
+  call_is_mergeable:
+    uses: ./.github/workflows/check_if_branch_is_mergeable.yml
+    secrets: inherit
+
+  integration-tests:
+    needs: call_is_mergeable
+    if: ${{ github.event_name != 'pull_request' || needs.call_is_mergeable.outputs.mergeable_state != 'not_clean' }}
+    runs-on:
+      labels: depot-ubuntu-22.04-4
+      group: public-depot
+    timeout-minutes: 20
+    strategy:
+      matrix:
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
+    steps:
+      - name: 🛎️ Checkout
+        uses: actions/checkout@v4
+
+      - name: 📦 Cache Python packages
+        uses: actions/cache@v3
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-${{ matrix.python-version }}-integration-${{ hashFiles('requirements/**') }}
+          restore-keys: |
+            ${{ runner.os }}-pip-${{ matrix.python-version }}-
+
+      - name: 🐍 Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          check-latest: true
+
+      - name: 📦 Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install --upgrade setuptools
+          pip install -r requirements/_requirements.txt \
+            -r requirements/requirements.sdk.http.txt \
+            -r requirements/requirements.sdk.webrtc.txt \
+            -r requirements/requirements.test.unit.txt
+
+      - name: 🧪 Integration Tests - WebRTC SDK (mocked)
+        timeout-minutes: 10
+        run: python -m pytest tests/inference_sdk/integration_tests -v --timeout=300
+
+  e2e-tests:
+    needs: call_is_mergeable
+    if: ${{ github.event_name != 'pull_request' || needs.call_is_mergeable.outputs.mergeable_state != 'not_clean' }}
+    runs-on:
+      labels: depot-ubuntu-22.04-4
+      group: public-depot
+    timeout-minutes: 30
+    strategy:
+      matrix:
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
+    steps:
+      - name: 🛎️ Checkout
+        uses: actions/checkout@v4
+
+      - name: 📦 Cache Python packages
+        uses: actions/cache@v3
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-${{ matrix.python-version }}-e2e-${{ hashFiles('requirements/**') }}
+          restore-keys: |
+            ${{ runner.os }}-pip-${{ matrix.python-version }}-
+
+      - name: 🐍 Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          check-latest: true
+
+      - name: 📦 Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install --upgrade setuptools
+          pip install -r requirements/_requirements.txt \
+            -r requirements/requirements.cpu.txt \
+            -r requirements/requirements.http.txt \
+            -r requirements/requirements.sdk.http.txt \
+            -r requirements/requirements.sdk.webrtc.txt \
+            -r requirements/requirements.test.unit.txt
+
+      - name: 🧪 E2E Tests - WebRTC SDK (real server)
+        timeout-minutes: 15
+        run: python -m pytest tests/inference_sdk/e2e_tests -m slow -v --timeout=300
diff --git a/.github/workflows/unit_tests_inference_sdk_x86.yml b/.github/workflows/unit_tests_inference_sdk_x86.yml
index ca555d0660..683e0fdc81
--- a/.github/workflows/unit_tests_inference_sdk_x86.yml
+++ b/.github/workflows/unit_tests_inference_sdk_x86.yml
@@ -41,7 +41,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install --upgrade setuptools
-        pip install -r requirements/requirements.sdk.http.txt -r requirements/requirements.test.unit.txt
+        pip install -r requirements/requirements.sdk.http.txt -r requirements/requirements.sdk.webrtc.txt -r requirements/requirements.test.unit.txt
     - name: 🧪 Unit Tests of Inference SDK
       timeout-minutes: 30
       run: python -m pytest tests/inference_sdk/unit_tests
diff --git a/.release/pypi/inference.sdk.setup.py b/.release/pypi/inference.sdk.setup.py
index 9d44f6b0e7..fcc88cbd23
--- a/.release/pypi/inference.sdk.setup.py
+++ b/.release/pypi/inference.sdk.setup.py
@@ -57,6 +57,9 @@ def read_requirements(path):
         ),
     ),
     install_requires=read_requirements(["requirements/requirements.sdk.http.txt"]),
+    extras_require={
+        "webrtc": read_requirements(["requirements/requirements.sdk.webrtc.txt"]),
+    },
     classifiers=[
         "Development Status :: 5 - Production/Stable",
         "Intended Audience :: Developers",
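Reviewer note: the extras entry above makes the WebRTC stack an opt-in install. A minimal sketch of the intended flow, using the `inference-sdk[webrtc]` extra name introduced in this diff (the workflow and workspace identifiers below are placeholders):

    pip install "inference-sdk[webrtc]"

Then, in Python:

    from inference_sdk import InferenceHTTPClient
    from inference_sdk.webrtc import StreamConfig, WebcamSource

    client = InferenceHTTPClient.init(api_url="http://localhost:9001", api_key=None)
    session = client.webrtc.stream(
        source=WebcamSource(),
        workflow="<workflow-id>",
        workspace="<workspace-name>",
        config=StreamConfig(stream_output=["image_output"]),
    )
    session.run()  # blocks until the stream ends or session.close() is called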
diff --git a/examples/webrtc_sdk/rtsp_basic.py b/examples/webrtc_sdk/rtsp_basic.py
new file mode 100644
index 0000000000..940fc15bd7
--- /dev/null
+++ b/examples/webrtc_sdk/rtsp_basic.py
@@ -0,0 +1,99 @@
+"""
+Minimal sample using the SDK's WebRTC namespace to stream RTSP frames
+to a running inference server with WebRTC worker enabled.
+
+Usage:
+    python examples/webrtc_sdk/rtsp_basic.py \\
+        --rtsp-url rtsp://camera.local/stream \\
+        --workspace-name <workspace-name> \\
+        --workflow-id <workflow-id> \\
+        [--api-url http://localhost:9001] \\
+        [--api-key <api-key>] \\
+        [--stream-output <output-name>] \\
+        [--data-output <output-name>]
+
+Press 'q' in the preview window to exit.
+"""
+import argparse
+
+import av
+import cv2
+
+from inference_sdk import InferenceHTTPClient
+from inference_sdk.webrtc import RTSPSource, StreamConfig, VideoMetadata
+
+# Suppress FFmpeg warnings from PyAV
+av.logging.set_level(av.logging.ERROR)
+
+
+def parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser("WebRTC SDK rtsp_basic")
+    p.add_argument(
+        "--rtsp-url",
+        required=True,
+        help="RTSP stream URL (e.g., rtsp://camera.local/stream)",
+    )
+    p.add_argument("--api-url", default="https://serverless.roboflow.com")
+    p.add_argument("--workspace-name", required=True)
+    p.add_argument("--workflow-id", required=True)
+    p.add_argument("--image-input-name", default="image")
+    p.add_argument("--api-key", default=None)
+    p.add_argument(
+        "--stream-output",
+        default=None,
+        help="Name of the workflow output to stream (e.g., 'image_output')",
+    )
+    p.add_argument(
+        "--data-output",
+        default=None,
+        help="Name of the workflow output to receive via data channel",
+    )
+    return p.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+    client = InferenceHTTPClient.init(api_url=args.api_url, api_key=args.api_key)
+
+    # Prepare source
+    source = RTSPSource(args.rtsp_url)
+
+    # Prepare config
+    stream_output = [args.stream_output] if args.stream_output else []
+    data_output = [args.data_output] if args.data_output else []
+    config = StreamConfig(stream_output=stream_output, data_output=data_output)
+
+    # Create streaming session
+    session = client.webrtc.stream(
+        source=source,
+        workflow=args.workflow_id,
+        workspace=args.workspace_name,
+        image_input=args.image_input_name,
+        config=config,
+    )
+
+    # Register frame handler
+    @session.on_frame
+    def show_frame(frame, metadata):
+        cv2.imshow("WebRTC SDK - RTSP", frame)
+        if cv2.waitKey(1) & 0xFF == ord("q"):
+            session.close()  # Close session and cleanup resources
+
+    # Register data handlers
+    # Global handler (receives entire serialized_output_data dict + metadata)
+    @session.on_data()
+    def on_message(data: dict, metadata: VideoMetadata):
+        print(f"Frame {metadata.frame_id}: {data}")
+
+    # Field-specific handler example (uncomment and customize based on your workflow):
+    # @session.on_data("predictions")
+    # def on_predictions(predictions: dict, metadata: VideoMetadata):
+    #     print(f"Frame {metadata.frame_id} predictions: {predictions}")
+
+    # Run the session (auto-starts, blocks until close() is called or stream ends)
+    # Automatically closes on exception or when stream ends
+    session.run()
+
+
+if __name__ == "__main__":
+    main()
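Reviewer note: the example scripts cover RTSP and, below, video files and webcams; ManualSource ships without one. A minimal sketch under the APIs introduced in this diff (frame pushing only works once negotiation has created the local track, hence the retry loop; `example.jpg` and the workflow identifiers are placeholders):

    import threading
    import time

    import cv2

    from inference_sdk import InferenceHTTPClient
    from inference_sdk.webrtc import ManualSource, StreamConfig

    client = InferenceHTTPClient.init(api_url="http://localhost:9001")
    source = ManualSource()
    session = client.webrtc.stream(
        source=source,
        workflow="<workflow-id>",
        workspace="<workspace-name>",
        config=StreamConfig(data_output=["predictions"]),
    )

    @session.on_data("predictions")
    def on_predictions(predictions):
        print(predictions)

    # run() blocks, so pump it in a background thread and push frames from here
    threading.Thread(target=session.run, daemon=True).start()
    frame = cv2.imread("example.jpg")  # any BGR numpy array works
    for _ in range(100):
        try:
            source.send(frame)
        except RuntimeError:
            pass  # session still negotiating; the track is not ready yet
        time.sleep(0.1)
    session.close()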
diff --git a/examples/webrtc_sdk/video_file_basic.py b/examples/webrtc_sdk/video_file_basic.py
new file mode 100644
index 0000000000..3e9a7fb1c4
--- /dev/null
+++ b/examples/webrtc_sdk/video_file_basic.py
@@ -0,0 +1,97 @@
+"""
+Minimal sample using the SDK's WebRTC namespace to stream video file frames
+to a running inference server with WebRTC worker enabled.
+
+Usage:
+    python examples/webrtc_sdk/video_file_basic.py \\
+        --video-path /path/to/video.mp4 \\
+        --workspace-name <workspace-name> \\
+        --workflow-id <workflow-id> \\
+        [--api-url http://localhost:9001] \\
+        [--api-key <api-key>] \\
+        [--stream-output <output-name>] \\
+        [--data-output <output-name>]
+
+Press 'q' in the preview window to exit.
+"""
+import argparse
+
+import av
+import cv2
+
+from inference_sdk import InferenceHTTPClient
+from inference_sdk.webrtc import StreamConfig, VideoFileSource, VideoMetadata
+
+# Suppress FFmpeg warnings from PyAV
+av.logging.set_level(av.logging.ERROR)
+
+
+def parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser("WebRTC SDK video_file_basic")
+    p.add_argument(
+        "--video-path", required=True, help="Path to video file to process"
+    )
+    p.add_argument("--api-url", default="https://serverless.roboflow.com")
+    p.add_argument("--workspace-name", required=True)
+    p.add_argument("--workflow-id", required=True)
+    p.add_argument("--image-input-name", default="image")
+    p.add_argument("--api-key", default=None)
+    p.add_argument(
+        "--stream-output",
+        default=None,
+        help="Name of the workflow output to stream (e.g., 'image_output')",
+    )
+    p.add_argument(
+        "--data-output",
+        default=None,
+        help="Name of the workflow output to receive via data channel",
+    )
+    return p.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+    client = InferenceHTTPClient.init(api_url=args.api_url, api_key=args.api_key)
+
+    # Prepare source
+    source = VideoFileSource(args.video_path)
+
+    # Prepare config
+    stream_output = [args.stream_output] if args.stream_output else []
+    data_output = [args.data_output] if args.data_output else []
+    config = StreamConfig(stream_output=stream_output, data_output=data_output)
+
+    # Create streaming session
+    session = client.webrtc.stream(
+        source=source,
+        workflow=args.workflow_id,
+        workspace=args.workspace_name,
+        image_input=args.image_input_name,
+        config=config,
+    )
+
+    # Register frame handler
+    @session.on_frame
+    def show_frame(frame, metadata):
+        cv2.imshow("WebRTC SDK - Video File", frame)
+        if cv2.waitKey(1) & 0xFF == ord("q"):
+            session.close()  # Close session and cleanup resources
+
+    # Register data handlers
+    # Global handler (receives entire serialized_output_data dict + metadata)
+    @session.on_data()
+    def on_message(data: dict, metadata: VideoMetadata):
+        print(f"Frame {metadata.frame_id}: {data}")
+
+    # Field-specific handler example (uncomment and customize based on your workflow):
+    # @session.on_data("predictions")
+    # def on_predictions(predictions: dict, metadata: VideoMetadata):
+    #     print(f"Frame {metadata.frame_id} predictions: {predictions}")
+
+    # Run the session (auto-starts, blocks until close() is called or stream ends)
+    # Automatically closes on exception or when stream ends
+    session.run()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/webrtc_sdk/webcam_basic.py b/examples/webrtc_sdk/webcam_basic.py
new file mode 100644
index 0000000000..cedc1ccda2
--- /dev/null
+++ b/examples/webrtc_sdk/webcam_basic.py
@@ -0,0 +1,96 @@
+"""
+Minimal sample using the SDK's WebRTC namespace to stream webcam frames
+to a running inference server with WebRTC worker enabled.
+
+Usage:
+    python examples/webrtc_sdk/webcam_basic.py \\
+        --workspace-name <workspace-name> \\
+        --workflow-id <workflow-id> \\
+        [--api-url http://localhost:9001] \\
+        [--api-key <api-key>] \\
+        [--width 1920] \\
+        [--height 1080] \\
+        [--stream-output <output-name>] \\
+        [--data-output <output-name>]
+
+Press 'q' in the preview window to exit.
+"""
+import argparse
+
+import cv2
+
+from inference_sdk import InferenceHTTPClient
+from inference_sdk.webrtc import StreamConfig, VideoMetadata, WebcamSource
+
+
+def parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser("WebRTC SDK webcam_basic")
+    p.add_argument("--api-url", default="https://serverless.roboflow.com")
+    p.add_argument("--workspace-name", required=True)
+    p.add_argument("--workflow-id", required=True)
+    p.add_argument("--image-input-name", default="image")
+    p.add_argument("--api-key", default=None)
+    p.add_argument("--width", type=int, default=None)
+    p.add_argument("--height", type=int, default=None)
+    p.add_argument(
+        "--stream-output",
+        default=None,
+        help="Name of the workflow output to stream (e.g., 'image_output')",
+    )
+    p.add_argument(
+        "--data-output",
+        default=None,
+        help="Name of the workflow output to receive via data channel",
+    )
+    return p.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+    client = InferenceHTTPClient.init(api_url=args.api_url, api_key=args.api_key)
+
+    # Prepare source
+    resolution = None
+    if args.width and args.height:
+        resolution = (args.width, args.height)
+    source = WebcamSource(resolution=resolution)
+
+    # Prepare config
+    stream_output = [args.stream_output] if args.stream_output else []
+    data_output = [args.data_output] if args.data_output else []
+    config = StreamConfig(stream_output=stream_output, data_output=data_output)
+
+    # Create streaming session
+    session = client.webrtc.stream(
+        source=source,
+        workflow=args.workflow_id,
+        workspace=args.workspace_name,
+        image_input=args.image_input_name,
+        config=config,
+    )
+
+    # Register frame handler
+    @session.on_frame
+    def show_frame(frame, metadata):
+        cv2.imshow("WebRTC SDK - Webcam", frame)
+        if cv2.waitKey(1) & 0xFF == ord("q"):
+            session.close()  # Close session and cleanup resources
+
+    # Register data handlers
+    # Global handler (receives entire serialized_output_data dict + metadata)
+    @session.on_data()
+    def on_message(data: dict, metadata: VideoMetadata):
+        print(f"Frame {metadata.frame_id}: {data}")
+
+    # Field-specific handler example (uncomment and customize based on your workflow):
+    # @session.on_data("predictions")
+    # def on_predictions(predictions: dict, metadata: VideoMetadata):
+    #     print(f"Frame {metadata.frame_id} predictions: {predictions}")
+
+    # Run the session (auto-starts, blocks until close() is called or stream ends)
+    # Automatically closes on exception or when stream ends
+    session.run()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/inference_sdk/config.py b/inference_sdk/config.py
index a802916888..6846e59cc4
--- a/inference_sdk/config.py
+++ b/inference_sdk/config.py
@@ -11,6 +11,13 @@ EXECUTION_ID_HEADER =
os.getenv("EXECUTION_ID_HEADER", "execution_id") PROCESSING_TIME_HEADER = os.getenv("PROCESSING_TIME_HEADER", "X-Processing-Time") +# WebRTC configuration +WEBRTC_INITIAL_FRAME_TIMEOUT = float(os.getenv("WEBRTC_INITIAL_FRAME_TIMEOUT", "90.0")) +WEBRTC_VIDEO_QUEUE_MAX_SIZE = int(os.getenv("WEBRTC_VIDEO_QUEUE_MAX_SIZE", "8")) +WEBRTC_EVENT_LOOP_SHUTDOWN_TIMEOUT = float( + os.getenv("WEBRTC_EVENT_LOOP_SHUTDOWN_TIMEOUT", "2.0") +) + class InferenceSDKDeprecationWarning(Warning): """Class used for warning of deprecated features in the Inference SDK""" diff --git a/inference_sdk/http/client.py b/inference_sdk/http/client.py index 868ff14914..43cb88783e 100644 --- a/inference_sdk/http/client.py +++ b/inference_sdk/http/client.py @@ -1,5 +1,15 @@ from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generator, + List, + Literal, + Optional, + Tuple, + Union, +) import aiohttp import numpy as np @@ -91,6 +101,9 @@ ] BufferConsumptionStrategy = Literal["LAZY", "EAGER"] +if TYPE_CHECKING: + from inference_sdk.webrtc.client import WebRTCClient + def wrap_errors(function: callable) -> callable: def decorate(*args, **kwargs) -> Any: @@ -216,6 +229,7 @@ def __init__( self.__inference_configuration = InferenceConfiguration.init_default() self.__client_mode = _determine_client_mode(api_url=api_url) self.__selected_model: Optional[str] = None + self.__webrtc_client: Optional["WebRTCClient"] = None @property def inference_configuration(self) -> InferenceConfiguration: @@ -244,6 +258,19 @@ def selected_model(self) -> Optional[str]: """ return self.__selected_model + @property + def webrtc(self) -> "WebRTCClient": + """Lazy accessor for the WebRTC client namespace. + + Returns: + WebRTCClient: Namespaced WebRTC API bound to this HTTP client. + """ + from inference_sdk.webrtc.client import WebRTCClient + + if self.__webrtc_client is None: + self.__webrtc_client = WebRTCClient(self.__api_url, self.__api_key) + return self.__webrtc_client + @contextmanager def use_configuration( self, inference_configuration: InferenceConfiguration diff --git a/inference_sdk/utils/logging.py b/inference_sdk/utils/logging.py new file mode 100644 index 0000000000..5a4891e632 --- /dev/null +++ b/inference_sdk/utils/logging.py @@ -0,0 +1,37 @@ +"""Centralized logging configuration for the Inference SDK.""" + +import logging +import sys + +# SDK-wide logger name +SDK_LOGGER_NAME = "inference_sdk" + +# Track if we've already configured +_configured = False + + +def get_logger(module_name: str) -> logging.Logger: + """Get a logger for the specified module. + + Automatically configures basic logging on first use if no handlers exist. + + Args: + module_name: Name of the module requesting the logger. + + Returns: + logging.Logger: Configured logger for the module. 
+ """ + global _configured + + sdk_logger = logging.getLogger(SDK_LOGGER_NAME) + + # Configure basic logging on first use if needed + if not _configured and not sdk_logger.handlers: + handler = logging.StreamHandler(sys.stderr) + handler.setFormatter(logging.Formatter("%(levelname)s [%(name)s] %(message)s")) + sdk_logger.addHandler(handler) + sdk_logger.setLevel(logging.INFO) + sdk_logger.propagate = False + _configured = True + + return logging.getLogger(f"{SDK_LOGGER_NAME}.{module_name}") diff --git a/inference_sdk/webrtc/__init__.py b/inference_sdk/webrtc/__init__.py new file mode 100644 index 0000000000..c9deb7baf8 --- /dev/null +++ b/inference_sdk/webrtc/__init__.py @@ -0,0 +1,26 @@ +"""WebRTC SDK for Inference - Unified streaming API.""" + +from .client import WebRTCClient # noqa: F401 +from .config import StreamConfig # noqa: F401 +from .session import VideoMetadata, WebRTCSession # noqa: F401 +from .sources import ( # noqa: F401 + ManualSource, + RTSPSource, + StreamSource, + VideoFileSource, + WebcamSource, +) + +__all__ = [ + # Core classes + "WebRTCClient", + "WebRTCSession", + "StreamConfig", + "VideoMetadata", + # Source classes + "StreamSource", + "WebcamSource", + "RTSPSource", + "VideoFileSource", + "ManualSource", +] diff --git a/inference_sdk/webrtc/client.py b/inference_sdk/webrtc/client.py new file mode 100644 index 0000000000..aa86ac237e --- /dev/null +++ b/inference_sdk/webrtc/client.py @@ -0,0 +1,143 @@ +"""WebRTC client for the Inference SDK.""" + +from __future__ import annotations + +from typing import Optional, Union + +from inference_sdk.http.errors import InvalidParameterError +from inference_sdk.utils.decorators import experimental +from inference_sdk.webrtc.config import StreamConfig +from inference_sdk.webrtc.session import WebRTCSession +from inference_sdk.webrtc.sources import StreamSource + + +class WebRTCClient: + """Namespaced WebRTC API bound to an InferenceHTTPClient instance. + + Provides a unified streaming interface for different video sources + (webcam, RTSP, video files, manual frames). + """ + + @experimental( + info="WebRTC SDK is experimental and under active development. " + "API may change in future releases. Please report issues at " + "https://github.com/roboflow/inference/issues" + ) + def __init__(self, api_url: str, api_key: Optional[str]) -> None: + """Initialize WebRTC client. + + Args: + api_url: Base URL for the inference API + api_key: API key for authentication (optional) + """ + self._api_url = api_url + self._api_key = api_key + + def stream( + self, + source: StreamSource, + *, + workflow: Union[str, dict], + image_input: str = "image", + workspace: Optional[str] = None, + config: Optional[StreamConfig] = None, + ) -> WebRTCSession: + """Create a WebRTC streaming session. + + Args: + source: Stream source (WebcamSource, RTSPSource, VideoFileSource, or ManualSource) + workflow: Either a workflow ID (str) or workflow specification (dict) + image_input: Name of the image input in the workflow + workspace: Workspace name (required if workflow is an ID string) + config: Stream configuration (output routing, FPS, TURN server, etc.) 
+ + Returns: + WebRTCSession context manager + + Raises: + InvalidParameterError: If workflow/workspace parameters are invalid + + Examples: + # Pattern 1: Using run() with decorators (recommended, auto-cleanup) + from inference_sdk.webrtc import WebcamSource + + session = client.webrtc.stream( + source=WebcamSource(resolution=(1920, 1080)), + workflow="object-detection", + workspace="my-workspace" + ) + + @session.on_frame + def process_frame(frame, metadata): + cv2.imshow("Frame", frame) + if cv2.waitKey(1) & 0xFF == ord('q'): + session.close() + + session.run() # Auto-closes on exception or stream end + + # Pattern 2: Using video() iterator (requires context manager or explicit close) + from inference_sdk.webrtc import RTSPSource + + # Option A: With context manager (recommended) + with client.webrtc.stream( + source=RTSPSource("rtsp://camera.local/stream"), + workflow=workflow_spec_dict + ) as session: + for frame, metadata in session.video(): + cv2.imshow("Frame", frame) + if cv2.waitKey(1) & 0xFF == ord('q'): + break + # Auto-cleanup on exit + + # Option B: Manual cleanup (not recommended) + session = client.webrtc.stream(source=RTSPSource("rtsp://..."), ...) + for frame, metadata in session.video(): + process(frame) + session.close() # Must call close() explicitly! + """ + # Validate workflow configuration + workflow_config = self._parse_workflow_config(workflow, workspace) + + # Use default config if not provided + if config is None: + config = StreamConfig() + + # Create session + return WebRTCSession( + api_url=self._api_url, + api_key=self._api_key, + source=source, + image_input_name=image_input, + workflow_config=workflow_config, + stream_config=config, + ) + + def _parse_workflow_config( + self, workflow: Union[str, dict], workspace: Optional[str] + ) -> dict: + """Parse workflow configuration from inputs. + + Args: + workflow: Either workflow ID (str) or specification (dict) + workspace: Workspace name (required for ID mode) + + Returns: + Dictionary with workflow configuration + + Raises: + InvalidParameterError: If configuration is invalid + """ + if isinstance(workflow, str): + # Workflow ID mode - requires workspace + if not workspace: + raise InvalidParameterError( + "workspace parameter required when workflow is an ID string" + ) + return {"workflow_id": workflow, "workspace_name": workspace} + elif isinstance(workflow, dict): + # Workflow specification mode + return {"workflow_specification": workflow} + else: + raise InvalidParameterError( + f"workflow must be a string (ID) or dict (specification), got {type(workflow)}" + ) diff --git a/inference_sdk/webrtc/config.py b/inference_sdk/webrtc/config.py new file mode 100644 index 0000000000..31a3bfd347 --- /dev/null +++ b/inference_sdk/webrtc/config.py @@ -0,0 +1,45 @@ +"""Configuration for WebRTC streaming sessions.""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class StreamConfig: + """Unified configuration for all WebRTC stream types. + + This configuration applies to all stream sources (webcam, RTSP, video file, manual) + and controls output routing, processing behavior, and network settings. 
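+
+    Example (illustrative):
+        config = StreamConfig(
+            stream_output=["image_output"],
+            data_output=["predictions"],
+        )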
+ """ + + # Output configuration + stream_output: List[str] = field(default_factory=list) + """List of workflow output names to stream as video""" + + data_output: List[str] = field(default_factory=list) + """List of workflow output names to receive via data channel""" + + # Processing configuration + realtime_processing: bool = True + """Whether to process frames in realtime (drop if can't keep up) or queue all frames""" + + declared_fps: Optional[float] = None + """Optional FPS declaration for the stream. + + Note: Some sources (like WebcamSource) auto-detect FPS from the video device and will + override this value. The source's detected FPS takes precedence over this configuration. + For sources without auto-detection (like ManualSource), this value will be used if provided. + """ + + # Network configuration + turn_server: Optional[Dict[str, str]] = None + """TURN server configuration: {"urls": "turn:...", "username": "...", "credential": "..."} + + Provide this configuration when your network requires a TURN server for WebRTC connectivity. + TURN is automatically skipped for localhost connections. If not provided, the connection + will attempt to establish directly without TURN relay. + """ + + # Workflow parameters + workflow_parameters: Dict[str, Any] = field(default_factory=dict) + """Parameters to pass to the workflow execution""" diff --git a/inference_sdk/webrtc/datachannel.py b/inference_sdk/webrtc/datachannel.py new file mode 100644 index 0000000000..ef54007d21 --- /dev/null +++ b/inference_sdk/webrtc/datachannel.py @@ -0,0 +1,68 @@ +"""WebRTC data channel binary chunking utilities.""" + +import struct +from typing import Dict, Optional, Tuple + + +def _parse_chunked_binary_message(message: bytes) -> Tuple[int, int, int, bytes]: + """Parse a binary message with standard 12-byte header. + + Format: [frame_id: 4][chunk_index: 4][total_chunks: 4][payload: N] + All integers are uint32 little-endian. + + Returns: (frame_id, chunk_index, total_chunks, payload) + """ + if len(message) < 12: + raise ValueError(f"Message too short: {len(message)} bytes (expected >= 12)") + + frame_id, chunk_index, total_chunks = struct.unpack(" Tuple[Optional[bytes], Optional[int]]: + """Parse and add a chunk, returning complete payload and frame_id if all chunks received. + + Args: + message: Raw binary message with 12-byte header + + Returns: + Tuple of (payload, frame_id) if complete, (None, None) otherwise + """ + # Parse the binary message + frame_id, chunk_index, total_chunks, chunk_data = _parse_chunked_binary_message( + message + ) + + # Initialize buffers for new frame + if frame_id not in self._chunks: + self._chunks[frame_id] = {} + self._total[frame_id] = total_chunks + + # Store chunk + self._chunks[frame_id][chunk_index] = chunk_data + + # Check if all chunks received + if len(self._chunks[frame_id]) >= total_chunks: + # Reassemble in order + complete_payload = b"".join( + self._chunks[frame_id][i] for i in range(total_chunks) + ) + + # Clean up buffers for completed frame - this is the key part! 
+ del self._chunks[frame_id] + del self._total[frame_id] + + return complete_payload, frame_id + + return None, None diff --git a/inference_sdk/webrtc/session.py b/inference_sdk/webrtc/session.py new file mode 100644 index 0000000000..119cb8f91b --- /dev/null +++ b/inference_sdk/webrtc/session.py @@ -0,0 +1,774 @@ +"""WebRTC session management.""" + +import asyncio +import inspect +import json +import queue +import struct +import threading +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from queue import Queue +from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional + +import numpy as np +import requests + +from inference_sdk.config import ( + WEBRTC_EVENT_LOOP_SHUTDOWN_TIMEOUT, + WEBRTC_INITIAL_FRAME_TIMEOUT, + WEBRTC_VIDEO_QUEUE_MAX_SIZE, +) +from inference_sdk.utils.logging import get_logger +from inference_sdk.webrtc.config import StreamConfig +from inference_sdk.webrtc.datachannel import ChunkReassembler +from inference_sdk.webrtc.sources import StreamSource + +if TYPE_CHECKING: + from aiortc import RTCDataChannel, RTCPeerConnection + + +def _check_webrtc_dependencies(): + """Check if WebRTC dependencies are installed and provide helpful error message.""" + try: + import aiortc # noqa: F401 + import av # noqa: F401 + except ImportError as e: + raise ImportError( + "WebRTC dependencies are not installed.\n" + "Install them with: pip install inference-sdk[webrtc]\n" + "Or if installing from source: pip install aiortc>=1.9.0" + ) from e + + +logger = get_logger("webrtc.session") + + +class SessionState(Enum): + """WebRTC session lifecycle states.""" + + NOT_STARTED = "not_started" + STARTED = "started" + CLOSED = "closed" + + +@dataclass +class VideoMetadata: + """Metadata about a video frame received from WebRTC stream. + + This metadata is attached to each frame processed by the server + and can be used to track frame timing, synchronization, and + processing information. + + Attributes: + frame_id: Unique identifier for this frame in the stream + received_at: Timestamp when the server received the frame + pts: Presentation timestamp from the video stream (optional) + time_base: Time base for interpreting pts values (optional) + declared_fps: Declared/expected frames per second (optional) + measured_fps: Measured actual frames per second (optional) + """ + + frame_id: int + received_at: datetime + pts: Optional[int] = None + time_base: Optional[float] = None + declared_fps: Optional[float] = None + measured_fps: Optional[float] = None + + +class _VideoStream: + """Wrapper for video frame queue providing iterator interface.""" + + def __init__( + self, + session: "WebRTCSession", + frames: "Queue[Optional[tuple[np.ndarray, VideoMetadata]]]", + initial_frame_timeout: float = WEBRTC_INITIAL_FRAME_TIMEOUT, + ): + self._session = session + self._frames = frames + self._initial_frame_timeout = initial_frame_timeout + self._first_frame_received = False + + def __call__(self) -> Iterator[tuple[np.ndarray, VideoMetadata]]: + """Iterate over video frames with metadata. + + Automatically starts the session if not already started. + Yields tuples of (BGR numpy array, VideoMetadata) until the stream ends (None received) + or session is closed. + The metadata is extracted directly from the video frame (pts, time_base, etc.). 
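+
+        Example (illustrative):
+            for frame, meta in session.video():
+                cv2.imshow("preview", frame)  # frame is a BGR np.ndarray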
+ + Raises: + TimeoutError: If first frame not received within timeout period + """ + self._session._ensure_started() + while True: + # Check if session was closed (e.g., from a handler) + if self._session._state == SessionState.CLOSED: + break + + # Use timeout only for first frame to detect server not sending + timeout = ( + self._initial_frame_timeout if not self._first_frame_received else None + ) + + try: + frame_data = self._frames.get(timeout=timeout) + except queue.Empty: + raise TimeoutError( + f"No video frames received within {self._initial_frame_timeout}s timeout.\n" + "This likely means the server is not sending video.\n" + "Troubleshooting:\n" + " - Check that stream_output is configured in your StreamConfig\n" + " - Verify the workflow outputs match your configuration\n" + " - Ensure the server has WebRTC enabled and is processing frames" + ) + + if frame_data is None: + break + + self._first_frame_received = True + yield frame_data + + +class WebRTCSession: + """WebRTC session for streaming video and receiving inference results. + + This class manages the WebRTC peer connection, video streaming, + and data channel communication with the inference server. + + The session automatically starts on first use (e.g., calling run() or video()). + Call close() to cleanup resources, or rely on __del__ for automatic cleanup. + + Example: + session = client.webrtc.stream(source=source, workflow=workflow) + + @session.on_frame + def process_frame(frame, metadata): + cv2.imshow("Frame", frame) + if cv2.waitKey(1) & 0xFF == ord('q'): + session.close() + + session.run() # Auto-starts, auto-closes on exception + """ + + def __init__( + self, + api_url: str, + api_key: Optional[str], + source: StreamSource, + image_input_name: str, + workflow_config: dict, + stream_config: StreamConfig, + ) -> None: + """Initialize WebRTC session. 
+ + Args: + api_url: Inference server API URL + api_key: API key for authentication + source: Stream source instance + image_input_name: Name of image input in workflow + workflow_config: Workflow configuration dict + stream_config: Stream configuration + """ + + self._state: SessionState = SessionState.NOT_STARTED + self._state_lock: threading.Lock = threading.Lock() + + self._api_url = api_url.rstrip("/") + self._api_key = api_key + self._source = source + self._image_input_name = image_input_name + self._workflow_config = workflow_config + self._config = stream_config + + # Internal state + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._loop_thread: Optional[threading.Thread] = None + self._pc: Optional["RTCPeerConnection"] = None + self._video_queue: "Queue[Optional[tuple[np.ndarray, VideoMetadata]]]" = Queue( + maxsize=WEBRTC_VIDEO_QUEUE_MAX_SIZE + ) + + # Callback handlers + self._frame_handlers: List[Callable] = [] + self._data_field_handlers: dict[str, List[Callable]] = {} + self._data_global_handler: Optional[Callable] = None + + # Chunk reassembly for binary messages + self._chunk_reassembler = ChunkReassembler() + + # Public APIs + self.video = _VideoStream(self, self._video_queue) + + def _init_connection(self) -> None: + """Initialize event loop, thread, and WebRTC connection.""" + # Start event loop in background thread + self._loop = asyncio.new_event_loop() + + def _run(loop: asyncio.AbstractEventLoop) -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + self._loop_thread = threading.Thread( + target=_run, args=(self._loop,), daemon=True + ) + self._loop_thread.start() + + # Initialize WebRTC connection + fut = asyncio.run_coroutine_threadsafe(self._init(), self._loop) + try: + fut.result() + except requests.exceptions.HTTPError as e: + if e.response.status_code == 404: + raise RuntimeError( + f"WebRTC endpoint not found at {self._api_url}/initialise_webrtc_worker.\n" + f"This API URL may not support WebRTC streaming.\n" + f"Troubleshooting:\n" + f" - For self-hosted inference, ensure the server is started with WebRTC enabled\n" + f" - For Roboflow Cloud, use a dedicated inference server URL (not serverless.roboflow.com)\n" + f" - Verify the --api-url parameter points to the correct server" + ) from e + else: + raise RuntimeError( + f"Failed to initialize WebRTC session (HTTP {e.response.status_code}).\n" + f"API URL: {self._api_url}\n" + f"Error: {e}" + ) from e + except Exception as e: + raise RuntimeError( + f"Failed to initialize WebRTC session: {e.__class__.__name__}: {e}\n" + f"API URL: {self._api_url}" + ) from e + + def _ensure_started(self) -> None: + """Ensure connection is started (thread-safe, idempotent).""" + with self._state_lock: + if self._state == SessionState.NOT_STARTED: + self._state = SessionState.STARTED + self._init_connection() + elif self._state == SessionState.CLOSED: + raise RuntimeError("Cannot use closed WebRTCSession") + + def _parse_video_metadata( + self, video_metadata_dict: Optional[dict] + ) -> Optional[VideoMetadata]: + """Parse video metadata from message dict. 
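+
+        Illustrative input:
+            {"frame_id": 42, "received_at": "2024-01-01T12:00:00", "pts": 3600}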
+ + Args: + video_metadata_dict: Dictionary containing video metadata fields + + Returns: + VideoMetadata instance or None if parsing fails or dict is None + """ + if not video_metadata_dict: + return None + + try: + return VideoMetadata( + frame_id=video_metadata_dict["frame_id"], + received_at=datetime.fromisoformat(video_metadata_dict["received_at"]), + pts=video_metadata_dict.get("pts"), + time_base=video_metadata_dict.get("time_base"), + declared_fps=video_metadata_dict.get("declared_fps"), + measured_fps=video_metadata_dict.get("measured_fps"), + ) + except (KeyError, ValueError, TypeError) as e: + logger.warning(f"Failed to parse video_metadata: {e}") + return None + + def close(self) -> None: + """Close session and cleanup all resources. Idempotent - safe to call multiple times. + + This method closes the WebRTC peer connection, releases source resources + (webcam, video files, etc.), stops the event loop, and joins the background thread. + + It's safe to call this multiple times - subsequent calls are no-ops. + + Example: + session = client.webrtc.stream(source=source, workflow=workflow) + session.run() # Auto-starts and auto-closes on exception + session.close() # Explicit cleanup (or let __del__ handle it) + """ + with self._state_lock: + if self._state == SessionState.CLOSED: + return # Already closed, nothing to do + self._state = SessionState.CLOSED + + # Signal video iterator to stop by putting None sentinel + try: + self._video_queue.put_nowait(None) + except Exception: + pass # Queue might be full, but that's okay + + # Cleanup resources (nested finally ensures all cleanup steps execute) + try: + # Close peer connection + if self._loop and self._pc: + asyncio.run_coroutine_threadsafe(self._pc.close(), self._loop).result() + finally: + try: + # Cleanup source (webcam, video file, etc.) + if self._loop and self._source: + asyncio.run_coroutine_threadsafe( + self._source.cleanup(), self._loop + ).result() + finally: + # Stop event loop and join thread + if self._loop: + self._loop.call_soon_threadsafe(self._loop.stop) + if self._loop_thread: + self._loop_thread.join(timeout=WEBRTC_EVENT_LOOP_SHUTDOWN_TIMEOUT) + + def __enter__(self) -> "WebRTCSession": + """Enter context manager - returns self. + + Returns: + WebRTCSession: The session instance for use in with statement. + """ + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit context manager - automatically closes the session. + + Args: + exc_type: Exception type if an exception occurred, None otherwise. + exc_val: Exception value if an exception occurred, None otherwise. + exc_tb: Exception traceback if an exception occurred, None otherwise. + """ + self.close() + + def __del__(self) -> None: + """Cleanup if user forgot to close. Not guaranteed to run immediately.""" + try: + if self._state == SessionState.STARTED: + logger.warning( + "WebRTCSession was not properly closed. " + "Consider calling session.close() explicitly for immediate cleanup." + ) + self.close() + except Exception: + pass # Never raise from __del__ + + def wait(self, timeout: Optional[float] = None) -> None: + """Wait for session to complete. + + Blocks until the video stream ends (None received) or timeout expires. + Automatically starts the session if not already started. 
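+
+        Example (illustrative):
+            session.wait(timeout=60.0)  # give the stream up to a minute to end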
+
+        Args:
+            timeout: Maximum time to wait in seconds (None for indefinite)
+
+        Raises:
+            TimeoutError: If timeout expires before stream ends
+        """
+        self._ensure_started()
+        try:
+            while True:
+                frame_data = self._video_queue.get(timeout=timeout)
+                if frame_data is None:
+                    break
+        except queue.Empty:
+            if timeout is not None:
+                raise TimeoutError(
+                    f"WebRTC session wait() timed out after {timeout}s.\n"
+                    "The video stream did not end within the timeout period."
+                )
+
+    def on_frame(self, callback: Callable) -> Callable:
+        """Decorator to register frame callback handlers.
+
+        The registered handlers will be called for each video frame received
+        when using the run() method. Handlers must accept two parameters:
+        - frame: BGR numpy array (np.ndarray)
+        - metadata: Video metadata (VideoMetadata) extracted from the video frame
+
+        Args:
+            callback: Callback function that accepts (frame, metadata)
+
+        Returns:
+            The callback itself
+
+        Examples:
+            @session.on_frame
+            def process_frame(frame: np.ndarray, metadata: VideoMetadata):
+                print(f"Frame {metadata.frame_id} - PTS: {metadata.pts}")
+                cv2.imshow("Frame", frame)
+                if cv2.waitKey(1) & 0xFF == ord('q'):
+                    session.close()
+        """
+        self._frame_handlers.append(callback)
+        return callback
+
+    def on_data(self, field_name: Optional[str] = None) -> Callable:
+        """Decorator to register data channel callback handlers.
+
+        Can be used with or without parentheses:
+            @session.on_data           # without parentheses (global handler)
+            @session.on_data()         # with parentheses (global handler)
+            @session.on_data("field")  # with field name (field-specific handler)
+
+        Args:
+            field_name: If provided, handler receives only that field's value.
+                If None, handler receives entire serialized_output_data dict.
+
+        Returns:
+            Decorator function or decorated function
+
+        Examples:
+            # Global handler without parentheses
+            @session.on_data
+            def handle_all(data: dict, metadata: VideoMetadata):
+                print(f"All data: {data}")
+
+            # Field-specific handler
+            @session.on_data("predictions")
+            def handle_predictions(data: dict, metadata: VideoMetadata):
+                print(f"Frame {metadata.frame_id}: {data}")
+
+            # Field-specific handler (no metadata)
+            @session.on_data("predictions")
+            def handle_predictions(data: dict):
+                print(data)
+
+            # Global handler with parentheses
+            @session.on_data()
+            def handle_all(data: dict, metadata: VideoMetadata):
+                print(f"All data: {data}")
+        """
+        # Check if being used without parentheses: @session.on_data
+        # In this case, field_name is actually the function being decorated
+        if callable(field_name):
+            fn = field_name
+            self._data_global_handler = fn
+            return fn
+
+        # Being used with parentheses: @session.on_data() or @session.on_data("field")
+        def decorator(fn: Callable) -> Callable:
+            if field_name is None:
+                self._data_global_handler = fn
+            else:
+                if field_name not in self._data_field_handlers:
+                    self._data_field_handlers[field_name] = []
+                self._data_field_handlers[field_name].append(fn)
+            return fn
+
+        return decorator
+
+    def run(self) -> None:
+        """Block and process frames until close() is called or stream ends.
+
+        This method iterates over incoming video frames and invokes all
+        registered frame handlers for each frame. Automatically starts
+        the session if not already started.
+
+        The session automatically closes when this method exits, whether
+        normally or due to an exception, ensuring resources are always
+        cleaned up.
+ + Blocks until either: + - close() is called (e.g., from a callback) + - The video stream ends naturally + - An exception occurs (session auto-closes, exception re-raised) + - KeyboardInterrupt (Ctrl+C) is received (session auto-closes) + + Data channel handlers are invoked automatically when data arrives, + independent of this method. + + Example: + session = client.webrtc.stream(source=source, workflow=workflow) + + @session.on_frame + def process(frame, metadata): + print(f"Frame {metadata.frame_id} - PTS: {metadata.pts}") + cv2.imshow("Frame", frame) + if cv2.waitKey(1) & 0xFF == ord('q'): + session.close() # Exits run() and cleans up + + session.run() # Auto-starts, auto-closes, blocks here + """ + with self: + for frame, metadata in self.video(): + # Invoke all registered frame handlers with both parameters + for handler in self._frame_handlers: + try: + handler(frame, metadata) + except Exception: + logger.warning("Error in frame handler", exc_info=True) + + def _invoke_data_handler( + self, handler: Callable, value: Any, metadata: Optional[VideoMetadata] + ) -> None: # noqa: ANN401 + """Invoke data handler with appropriate signature (auto-detect via introspection). + + Supports two signatures: + - handler(value, metadata) - receives both value and metadata + - handler(value) - receives only value + + Args: + handler: The handler callable to invoke + value: The data value to pass + metadata: Optional video metadata to pass + """ + try: + sig = inspect.signature(handler) + params = list(sig.parameters.values()) + + # Check number of parameters (excluding *args, **kwargs) + positional_params = [ + p + for p in params + if p.kind + in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + ] + + if len(positional_params) >= 2: + # Handler expects both value and metadata + handler(value, metadata) + else: + # Handler expects only value + handler(value) + except Exception: + # Fallback: try calling with just the value + try: + handler(value) + except Exception: + logger.exception(f"Failed to invoke handler {handler}") + raise + + async def _get_turn_config(self) -> Optional[dict]: + """Get TURN configuration from user-provided config. + + Priority order: + 1. User-provided config via StreamConfig.turn_server (highest priority) + 2. Skip TURN for localhost connections + 3. Return None if not provided + + Returns: + TURN configuration dict or None + """ + # 1. Use user-provided config if available + if self._config.turn_server: + logger.debug("Using user-provided TURN configuration") + return self._config.turn_server + + # 2. Skip TURN for localhost connections + if self._api_url.startswith(("http://localhost", "http://127.0.0.1")): + logger.debug("Skipping TURN for localhost connection") + return None + + # 3. No TURN config provided + logger.debug("No TURN configuration provided, proceeding without TURN server") + return None + + async def _init(self) -> None: + """Initialize WebRTC connection. + + Sets up peer connection, configures source, negotiates with server. 
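+
+        Flow (summary): build an RTCPeerConnection (with the optional TURN
+        server), let the source attach its track or transceiver, create an
+        SDP offer, wait for ICE gathering, POST the offer to
+        /initialise_webrtc_worker, then apply the server's SDP answer.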
+ """ + # Check dependencies and import them + _check_webrtc_dependencies() + from aiortc import ( + RTCConfiguration, + RTCIceServer, + RTCPeerConnection, + RTCSessionDescription, + ) + from aiortc.contrib.media import MediaRelay + from av import VideoFrame + + # Fetch TURN configuration (auto-fetch or user-provided) + turn_config = await self._get_turn_config() + + # Create peer connection with TURN config if available + configuration = None + if turn_config: + ice = RTCIceServer( + urls=[turn_config.get("urls")], + username=turn_config.get("username"), + credential=turn_config.get("credential"), + ) + configuration = RTCConfiguration(iceServers=[ice]) + + pc = RTCPeerConnection(configuration=configuration) + relay = MediaRelay() + + # Setup video receiver for frames from server + @pc.on("track") + def _on_track(track): # noqa: ANN001 + subscribed = relay.subscribe(track) + + async def _reader(): + from aiortc.mediastreams import MediaStreamError + + while True: + try: + f: VideoFrame = await subscribed.recv() + except MediaStreamError: + # Remote stream finished normally + logger.info("Remote stream finished") + try: + self._video_queue.put_nowait(None) + except Exception: + pass + break + except Exception as e: + # Connection closed or track ended unexpectedly + logger.error( + f"WebRTC video track ended: {e.__class__.__name__}: {e}", + exc_info=True, + ) + try: + self._video_queue.put_nowait(None) + except Exception: + pass + break + img = f.to_ndarray(format="bgr24") + current_metadata = VideoMetadata( + frame_id=f.pts, + received_at=datetime.now(), + pts=f.pts, + time_base=f.time_base, + declared_fps=None, + measured_fps=None, + ) + # Backpressure: drop oldest frame if queue full + if self._video_queue.full(): + try: + _ = self._video_queue.get_nowait() + except Exception: + pass + try: + self._video_queue.put_nowait((img, current_metadata)) + except Exception: + pass + + asyncio.ensure_future(_reader()) + + # Setup data channel + ch = pc.createDataChannel("inference") + + # Setup data channel message handler + @ch.on("message") + def _on_data_message(message: Any) -> None: # noqa: ANN401 + try: + # Handle both bytes and str messages + if isinstance(message, bytes): + # Check if it's a chunked binary message + if len(message) >= 12: + try: + # Try to reassemble chunks + complete_payload, _ = self._chunk_reassembler.add_chunk( + message + ) + if complete_payload is None: + # Not all chunks received yet + return + # Parse the complete JSON from reassembled payload + message = complete_payload.decode("utf-8") + except (struct.error, ValueError): + # Not a chunked message, try to decode as regular UTF-8 + message = message.decode("utf-8") + else: + # Too short to be chunked, decode as regular UTF-8 + message = message.decode("utf-8") + + parsed_message = json.loads(message) + + # Extract video metadata if present (for data handlers) + metadata = self._parse_video_metadata( + parsed_message.get("video_metadata") + ) + + # Get serialized output data + serialized_data = parsed_message.get("serialized_output_data") + + # Call global handler if registered + if self._data_global_handler: + try: + self._invoke_data_handler( + self._data_global_handler, serialized_data, metadata + ) + except Exception: + logger.warning( + "Error calling global data handler", exc_info=True + ) + + # Route to field-specific handlers + if isinstance(serialized_data, dict): + for field_name, field_value in serialized_data.items(): + if field_name in self._data_field_handlers: + for handler in 
list(self._data_field_handlers[field_name]): + try: + self._invoke_data_handler( + handler, field_value, metadata + ) + except Exception: + logger.warning( + f"Error calling handler for field '{field_name}'", + exc_info=True, + ) + except json.JSONDecodeError: + logger.warning("Failed to parse data channel message as JSON") + + # Let source configure the peer connection + # (adds tracks for webcam/video/manual, or recvonly transceiver for RTSP) + await self._source.configure_peer_connection(pc) + + # Create offer and wait for ICE gathering + offer = await pc.createOffer() + await pc.setLocalDescription(offer) + + # Wait for ICE gathering to complete + while pc.iceGatheringState != "complete": + await asyncio.sleep(0.1) + + # Build server initialization payload + wf_conf: dict[str, Any] = { + "type": "WorkflowConfiguration", + "image_input_name": self._image_input_name, + "workflows_parameters": self._config.workflow_parameters, + } + wf_conf.update(self._workflow_config) + + payload = { + "api_key": self._api_key, + "workflow_configuration": wf_conf, + "webrtc_offer": { + "type": pc.localDescription.type, + "sdp": pc.localDescription.sdp, + }, + "webrtc_realtime_processing": self._config.realtime_processing, + "stream_output": self._config.stream_output, + "data_output": self._config.data_output, + } + + # Add TURN config if available (auto-fetched or user-provided) + if turn_config: + payload["webrtc_turn_config"] = turn_config + + # Add FPS if provided + if self._config.declared_fps: + payload["declared_fps"] = self._config.declared_fps + + # Merge source-specific parameters + # (rtsp_url for RTSP, declared_fps for webcam, etc.) + payload.update(self._source.get_initialization_params()) + + # Call server to initialize worker + url = f"{self._api_url}/initialise_webrtc_worker" + headers = {"Content-Type": "application/json"} + resp = requests.post(url, json=payload, headers=headers, timeout=90) + resp.raise_for_status() + ans: dict[str, Any] = resp.json() + + # Set remote description + answer = RTCSessionDescription(sdp=ans["sdp"], type=ans["type"]) + await pc.setRemoteDescription(answer) + + self._pc = pc diff --git a/inference_sdk/webrtc/sources.py b/inference_sdk/webrtc/sources.py new file mode 100644 index 0000000000..515f838a67 --- /dev/null +++ b/inference_sdk/webrtc/sources.py @@ -0,0 +1,381 @@ +"""Stream source abstractions for WebRTC SDK. + +This module defines the StreamSource interface and concrete implementations +for different video streaming sources (webcam, RTSP, video files, manual frames). +""" + +import asyncio +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +import av +import cv2 +import numpy as np +from aiortc import RTCPeerConnection, VideoStreamTrack +from av import VideoFrame + +from inference_sdk.http.errors import InvalidParameterError + + +class StreamSource(ABC): + """Base interface for all stream sources. + + A StreamSource is responsible for: + 1. Configuring the RTCPeerConnection (adding tracks or transceivers) + 2. Providing initialization parameters for the server + 3. Cleaning up resources when done + """ + + @abstractmethod + async def configure_peer_connection(self, pc: RTCPeerConnection) -> None: + """Configure the peer connection for this source type. 
+ + This is where the source decides: + - Whether to add a local track (webcam, video file, manual) + - Whether to add a receive-only transceiver (RTSP) + - Any other peer connection configuration + + Args: + pc: The RTCPeerConnection to configure + """ + pass + + @abstractmethod + def get_initialization_params(self) -> Dict[str, Any]: + """Get parameters to send to server in /initialise_webrtc_worker payload. + + Returns: + Dictionary of parameters specific to this source type. + Examples: + - RTSP: {"rtsp_url": "rtsp://..."} + - Video file: {"video_path": "/path/to/file"} + - Webcam/Manual: {} (empty, no server-side source) + """ + pass + + async def cleanup(self) -> None: + """Cleanup resources when session ends. + + Default implementation does nothing. Override if cleanup is needed. + """ + pass + + +class _OpenCVVideoTrack(VideoStreamTrack): + """Base class for video tracks that use OpenCV capture. + + This consolidates common logic for webcam and video file tracks. + """ + + def __init__(self, source: Any, error_name: str): + """Initialize OpenCV video track. + + Args: + source: OpenCV VideoCapture source (int for webcam, str for file) + error_name: Human-readable name for error messages + """ + super().__init__() + self._cap = cv2.VideoCapture(source) + if not self._cap.isOpened(): + raise RuntimeError(f"Could not open {error_name}: {source}") + self._error_name = error_name + + async def recv(self) -> VideoFrame: # type: ignore[override] + """Read next frame from OpenCV capture.""" + ret, frame = self._cap.read() + if not ret or frame is None: + raise RuntimeError(f"Failed to read from {self._error_name}") + + return await self._frame_to_video(frame) + + async def _frame_to_video(self, frame: np.ndarray) -> VideoFrame: + """Convert numpy frame to VideoFrame with timestamp. + + Args: + frame: BGR numpy array (H, W, 3) uint8 + + Returns: + VideoFrame with proper timestamp + """ + vf = VideoFrame.from_ndarray(frame, format="bgr24") + vf.pts, vf.time_base = await self.next_timestamp() + return vf + + def get_declared_fps(self) -> Optional[float]: + """Get the declared FPS from the OpenCV capture.""" + fps = self._cap.get(cv2.CAP_PROP_FPS) + return float(fps) if fps and fps > 0 else None + + def release(self) -> None: + """Release the OpenCV capture.""" + try: + self._cap.release() + except Exception: + pass + + +class _WebcamVideoTrack(_OpenCVVideoTrack): + """aiortc VideoStreamTrack that reads frames from OpenCV webcam.""" + + def __init__(self, device_id: int, resolution: Optional[tuple[int, int]]): + super().__init__(device_id, "webcam device") + + if resolution: + self._cap.set(cv2.CAP_PROP_FRAME_WIDTH, resolution[0]) + self._cap.set(cv2.CAP_PROP_FRAME_HEIGHT, resolution[1]) + + +class WebcamSource(StreamSource): + """Stream source for local webcam/USB camera. + + This source creates a local video track that captures frames from + a webcam device using OpenCV and sends them to the server. + """ + + def __init__( + self, device_id: int = 0, resolution: Optional[tuple[int, int]] = None + ): + """Initialize webcam source. 
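+
+        Example (illustrative):
+            source = WebcamSource(device_id=0, resolution=(1280, 720))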
+ + Args: + device_id: Camera device index (0 for default camera) + resolution: Optional (width, height) tuple to set camera resolution + """ + self.device_id = device_id + self.resolution = resolution + self._track: Optional[_WebcamVideoTrack] = None + self._declared_fps: Optional[float] = None + + async def configure_peer_connection(self, pc: RTCPeerConnection) -> None: + """Create webcam video track and add it to the peer connection.""" + # Create local video track that reads from OpenCV + self._track = _WebcamVideoTrack(self.device_id, self.resolution) + + # Capture FPS for server + self._declared_fps = self._track.get_declared_fps() + + # Add track to send video + pc.addTrack(self._track) + + def get_initialization_params(self) -> Dict[str, Any]: + """Return FPS if available.""" + params = {} + if self._declared_fps: + params["declared_fps"] = self._declared_fps + return params + + async def cleanup(self) -> None: + """Release webcam resources.""" + if self._track: + self._track.release() + + +class RTSPSource(StreamSource): + """Stream source for RTSP camera streams. + + This source doesn't create a local track - instead, the server + captures the RTSP stream and sends processed video back to the client. + """ + + def __init__(self, url: str): + """Initialize RTSP source. + + Args: + url: RTSP URL (e.g., "rtsp://camera.local/stream") + Credentials can be included: "rtsp://user:pass@host/stream" + """ + self.url = url + self._validate_url() + + def _validate_url(self) -> None: + """Validate that the URL is a valid RTSP URL.""" + if not self.url.startswith(("rtsp://", "rtsps://")): + raise InvalidParameterError( + f"Invalid RTSP URL: {self.url}. Must start with rtsp:// or rtsps://" + ) + + async def configure_peer_connection(self, pc: RTCPeerConnection) -> None: + """Add receive-only video transceiver (server sends video to us).""" + # Don't create a local track - we're receiving video from server + # Add receive-only transceiver + pc.addTransceiver("video", direction="recvonly") + + def get_initialization_params(self) -> Dict[str, Any]: + """Return RTSP URL for server to capture.""" + # Server needs to know the RTSP URL to capture + return {"rtsp_url": self.url} + + +class _VideoFileTrack(VideoStreamTrack): + """aiortc VideoStreamTrack that reads frames from a video file using PyAV. + + Uses PyAV instead of OpenCV to preserve original video timestamps and time_base. 
+ """ + + def __init__(self, path: str): + super().__init__() + try: + self._container = av.open(path) + except Exception as e: + raise RuntimeError(f"Could not open video file: {path}") from e + + if not self._container.streams.video: + raise RuntimeError(f"No video stream found in: {path}") + + self._stream = self._container.streams.video[0] + self._stream.thread_type = "AUTO" # Enable multi-threaded decoding + self._decoder = self._container.decode(self._stream) + + async def recv(self) -> VideoFrame: # type: ignore[override] + """Read next frame from video file with aiortc pacing.""" + try: + frame = next(self._decoder) + # Call next_timestamp() for pacing (asyncio.sleep), but keep original timing + # This preserves the video's original pts/time_base while preventing frames + # from decoding too fast + await self.next_timestamp() + return frame + except StopIteration: + # End of file - use Exception (not RuntimeError) for EOF + raise Exception("End of video file") + + def get_declared_fps(self) -> Optional[float]: + """Get the FPS from the video stream.""" + if self._stream.average_rate: + return float(self._stream.average_rate) + return None + + def release(self) -> None: + """Release the PyAV container.""" + try: + if hasattr(self, "_container") and self._container: + self._container.close() + except Exception: + pass + + +class VideoFileSource(StreamSource): + """Stream source for video files. + + This source creates a local video track that reads frames from + a video file and sends them to the server. + """ + + def __init__(self, path: str): + """Initialize video file source. + + Args: + path: Path to video file (any format supported by PyAV/FFmpeg) + """ + self.path = path + self._track: Optional[_VideoFileTrack] = None + self._declared_fps: Optional[float] = None + + async def configure_peer_connection(self, pc: RTCPeerConnection) -> None: + """Create video file track and add it to the peer connection.""" + # Create track that reads from video file + self._track = _VideoFileTrack(self.path) + + # Capture FPS for server + self._declared_fps = self._track.get_declared_fps() + + # Add track to send video + pc.addTrack(self._track) + + def get_initialization_params(self) -> Dict[str, Any]: + """Return metadata about video source.""" + params = {"video_source": "file"} + if self._declared_fps: + params["declared_fps"] = self._declared_fps + return params + + async def cleanup(self) -> None: + """Release video file resources.""" + if self._track: + self._track.release() + + +# Configuration constants for manual source +MANUAL_SOURCE_QUEUE_MAX_SIZE = 10 # maximum number of queued frames for manual source + + +class ManualSource(StreamSource): + """Stream source for manually sent frames. + + This source allows the user to programmatically send frames + to be processed by the workflow using the send() method. + """ + + def __init__(self): + """Initialize manual source.""" + self._track: Optional[_ManualTrack] = None + + async def configure_peer_connection(self, pc: RTCPeerConnection) -> None: + """Create manual track and add it to the peer connection.""" + # Create special track that accepts programmatic frames + self._track = _ManualTrack() + pc.addTrack(self._track) + + def get_initialization_params(self) -> Dict[str, Any]: + """Return manual mode flag.""" + return {"manual_mode": True} + + def send(self, frame: np.ndarray) -> None: + """Send a frame to be processed by the workflow. 
+ + Args: + frame: BGR numpy array (H, W, 3) uint8 + + Raises: + RuntimeError: If session not started + """ + if not self._track: + raise RuntimeError("Session not started. Use within 'with' context.") + self._track.queue_frame(frame) + + +class _ManualTrack(VideoStreamTrack): + """aiortc VideoStreamTrack that accepts programmatically queued frames.""" + + def __init__(self): + super().__init__() + self._queue: asyncio.Queue[Optional[np.ndarray]] = asyncio.Queue( + maxsize=MANUAL_SOURCE_QUEUE_MAX_SIZE + ) + + async def recv(self) -> VideoFrame: # type: ignore[override] + """Wait for next frame to be queued.""" + # Wait for next frame to be queued + frame = await self._queue.get() + if frame is None: + raise Exception("Manual track stopped") + + return await self._frame_to_video(frame) + + async def _frame_to_video(self, frame: np.ndarray) -> VideoFrame: + """Convert numpy frame to VideoFrame with timestamp. + + Args: + frame: BGR numpy array (H, W, 3) uint8 + + Returns: + VideoFrame with proper timestamp + """ + vf = VideoFrame.from_ndarray(frame, format="bgr24") + vf.pts, vf.time_base = await self.next_timestamp() + return vf + + def queue_frame(self, frame: np.ndarray) -> None: + """Queue a frame to be sent (called from main thread). + + If the queue is full, the oldest frame is dropped. + """ + try: + self._queue.put_nowait(frame) + except asyncio.QueueFull: + # Drop oldest frame + try: + self._queue.get_nowait() + self._queue.put_nowait(frame) + except Exception: + pass diff --git a/requirements/requirements.sdk.webrtc.txt b/requirements/requirements.sdk.webrtc.txt new file mode 100644 index 0000000000..fe3a259499 --- /dev/null +++ b/requirements/requirements.sdk.webrtc.txt @@ -0,0 +1 @@ +aiortc>=1.9.0 diff --git a/tests/inference_sdk/e2e_tests/__init__.py b/tests/inference_sdk/e2e_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/inference_sdk/e2e_tests/webrtc/__init__.py b/tests/inference_sdk/e2e_tests/webrtc/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/inference_sdk/e2e_tests/webrtc/conftest.py b/tests/inference_sdk/e2e_tests/webrtc/conftest.py new file mode 100644 index 0000000000..d47eaa1a5b --- /dev/null +++ b/tests/inference_sdk/e2e_tests/webrtc/conftest.py @@ -0,0 +1,108 @@ +"""Fixtures for WebRTC end-to-end tests with real inference server. + +These fixtures support slow tests that validate the full WebRTC stack +with a real inference server. +""" + +import multiprocessing +import threading +import time +from functools import partial + +import pytest +import requests +import uvicorn + + +@pytest.fixture(scope="session") +def inference_server(): + """Start real inference server for end-to-end tests. + + This fixture starts the full inference server stack: + - HTTP API (uvicorn) + - Stream Manager (WebRTC worker) + + The server runs in separate processes and is cleaned up after tests. 
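+
+    Yields:
+        str: Base URL of the running server ("http://127.0.0.1:9001").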
+    """
+    # Import server components
+    from inference.core.cache import cache
+    from inference.core.env import MAX_ACTIVE_MODELS
+    from inference.core.interfaces.http.http_api import HttpInterface
+    from inference.core.interfaces.stream_manager.manager_app.app import start
+    from inference.core.managers.active_learning import (
+        BackgroundTaskActiveLearningManager,
+    )
+    from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
+    from inference.core.registries.roboflow import RoboflowModelRegistry
+    from inference.models.utils import ROBOFLOW_MODEL_TYPES
+
+    # Setup model manager (similar to debugrun.py)
+    model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
+    model_manager = BackgroundTaskActiveLearningManager(
+        model_registry=model_registry, cache=cache
+    )
+    model_manager = WithFixedSizeCache(model_manager, max_size=MAX_ACTIVE_MODELS)
+    model_manager.init_pingback()
+
+    # Create HTTP interface
+    interface = HttpInterface(model_manager)
+    app = interface.app
+
+    # Start stream manager process (needs separate process)
+    stream_manager_process = multiprocessing.Process(
+        target=partial(start, expected_warmed_up_pipelines=0),
+        daemon=True,
+    )
+    stream_manager_process.start()
+
+    # Start HTTP server in thread (avoids pickle issues)
+    config = uvicorn.Config(app, host="127.0.0.1", port=9001, log_level="error")
+    server = uvicorn.Server(config)
+
+    def run_server():
+        import asyncio
+        asyncio.run(server.serve())
+
+    server_thread = threading.Thread(target=run_server, daemon=True)
+    server_thread.start()
+
+    # Wait for server to be ready
+    server_url = "http://127.0.0.1:9001"
+    max_wait = 30  # seconds
+    start_time = time.time()
+    server_ready = False
+
+    while time.time() - start_time < max_wait:
+        try:
+            resp = requests.get(server_url, timeout=2)
+            if resp.status_code in [200, 404]:  # Server responding
+                server_ready = True
+                break
+        except (requests.ConnectionError, requests.Timeout):
+            time.sleep(0.5)
+
+    if not server_ready:
+        # The HTTP server runs in a daemon thread, so signal it to exit;
+        # only the stream manager has its own process to terminate
+        server.should_exit = True
+        stream_manager_process.terminate()
+        raise TimeoutError(
+            f"Inference server failed to start within {max_wait} seconds"
+        )
+
+    print(f"\n✓ Inference server ready at {server_url}")
+
+    # Yield server URL for tests
+    yield server_url
+
+    # Teardown: terminate processes
+    print("\n✓ Shutting down inference server...")
+
+    # Shutdown HTTP server
+    server.should_exit = True
+
+    # Terminate stream manager process
+    stream_manager_process.terminate()
+    stream_manager_process.join(timeout=5)
+
+    # Force kill if still alive
+    if stream_manager_process.is_alive():
+        stream_manager_process.kill()
diff --git a/tests/inference_sdk/e2e_tests/webrtc/test_video_file_e2e.py b/tests/inference_sdk/e2e_tests/webrtc/test_video_file_e2e.py
new file mode 100644
index 0000000000..b797b30301
--- /dev/null
+++ b/tests/inference_sdk/e2e_tests/webrtc/test_video_file_e2e.py
@@ -0,0 +1,185 @@
+"""End-to-end tests for WebRTC SDK with real inference server.
+
+These tests start a real inference server and validate the full WebRTC stack:
+- Real HTTP API
+- Real WebRTC signaling
+- Real workflow execution
+- Real data channel communication
+
+Tests are marked with @pytest.mark.slow and can be skipped:
+    pytest -m "not slow"  # Skip slow tests
+    pytest -m slow  # Run only slow tests
+"""
+
+import os
+import numpy as np
+import pytest
+
+from inference_sdk import InferenceHTTPClient
+from inference_sdk.webrtc import StreamConfig, VideoFileSource
+
+
+# Near-passthrough workflow - uses absolute_static_crop to wrap the image
+# This validates WebRTC connection and workflow execution without running a model
+# The crop step is necessary to wrap the input in WorkflowImageData format for video streaming
+PASSTHROUGH_WORKFLOW = {
+    "version": "1.0",
+    "inputs": [{"type": "InferenceImage", "name": "image"}],
+    "steps": [
+        {
+            "type": "roboflow_core/absolute_static_crop@v1",
+            "name": "absolute_static_crop",
+            "images": "$inputs.image",
+            "x_center": 40,
+            "y_center": 40,
+            "width": 80,
+            "height": 80,
+        }
+    ],
+    "outputs": [
+        {
+            "type": "JsonField",
+            "name": "absolute_static_crop",
+            "coordinates_system": "own",
+            "selector": "$steps.absolute_static_crop.crops",
+        }
+    ],
+}
+
+
+@pytest.mark.slow
+def test_video_file_e2e_with_passthrough_workflow(inference_server):
+    """Full end-to-end test with real server and passthrough workflow.
+
+    This test validates:
+    1. Real inference server starts and responds
+    2. WebRTC session establishes successfully
+    3. Video frames are sent from file to server
+    4. Workflow processes frames (static crop only - no model inference)
+    5. Processed frames are received back
+    6. Session cleanup works correctly
+
+    Uses:
+    - Real HTTP API
+    - Real WebRTC signaling (aiortc)
+    - Real workflow execution
+    - Real video file (no mocking needed)
+    """
+    # Create real client pointing to test server
+    client = InferenceHTTPClient(api_url=inference_server, api_key="test-key")
+
+    # Path to test video file
+    test_video_path = os.path.join(
+        os.path.dirname(__file__),
+        "../../../inference/unit_tests/core/interfaces/assets/example_video.mp4"
+    )
+    test_video_path = os.path.abspath(test_video_path)
+
+    # Create video file source
+    source = VideoFileSource(path=test_video_path)
+
+    # Configure to receive processed video stream
+    config = StreamConfig(
+        stream_output=["absolute_static_crop"],  # Receive cropped frames
+        data_output=[],  # No data channel needed for this test
+        realtime_processing=True,
+    )
+
+    # Start WebRTC session with inline workflow spec
+    with client.webrtc.stream(
+        source=source,
+        workflow=PASSTHROUGH_WORKFLOW,  # Inline spec - no workspace needed!
+        image_input="image",
+        config=config,
+    ) as session:
+        # Collect some processed frames (session will auto-start on first use)
+        frames_received = []
+        for i, (frame, metadata) in enumerate(session.video()):
+            frames_received.append((frame, metadata))
+            if i >= 2:  # Get 3 frames
+                break
+
+    # Validate we received processed frames
+    assert len(frames_received) == 3, "Should receive 3 processed frames"
+
+    for frame, metadata in frames_received:
+        # Frames should be numpy arrays
+        assert isinstance(frame, np.ndarray), "Frame should be numpy array"
+
+        # Metadata should be VideoMetadata
+        assert metadata is not None, "Should have metadata"
+        assert metadata.frame_id is not None, "Should have frame_id"
+
+        # Should be BGR format (H, W, 3) - cropped to 80x80
+        assert frame.shape == (
+            80,
+            80,
+            3,
+        ), "Frame should be 80x80x3 (cropped)"
+
+        # Should be uint8
+        assert frame.dtype == np.uint8, "Frame should be uint8"
+
+    # Session should cleanup successfully (context manager exit)
+    print("\n✓ E2E test passed: Real server + WebRTC + passthrough crop workflow")
+
+
+@pytest.mark.slow
+def test_video_file_e2e_with_data_channel(inference_server):
+    """Test data channel output with real server.
+
+    Validates that workflow outputs can be received via data channel
+    in addition to video stream.
+    """
+    client = InferenceHTTPClient(api_url=inference_server, api_key="test-key")
+
+    # Path to test video file
+    test_video_path = os.path.join(
+        os.path.dirname(__file__),
+        "../../../inference/unit_tests/core/interfaces/assets/example_video.mp4"
+    )
+    test_video_path = os.path.abspath(test_video_path)
+
+    source = VideoFileSource(path=test_video_path)
+
+    # Request both video and data output
+    config = StreamConfig(
+        stream_output=["absolute_static_crop"],
+        data_output=["absolute_static_crop"],  # Also receive via data channel
+        realtime_processing=True,
+    )
+
+    data_messages_received = []
+
+    with client.webrtc.stream(
+        source=source,
+        workflow=PASSTHROUGH_WORKFLOW,
+        image_input="image",
+        config=config,
+    ) as session:
+        # Register data channel handler for the configured output field
+        @session.on_data("absolute_static_crop")
+        def handle_crop_data(data):
+            data_messages_received.append(data)
+
+        # Receive a few frames to trigger data channel messages
+        frame_count = 0
+        for frame in session.video():
+            frame_count += 1
+            if frame_count >= 3:
+                break
+
+        # Give data channel time to deliver messages
+        import time
+
+        time.sleep(0.5)
+
+    # Validate the collection of data channel messages
+    # Note: depending on server timing, zero or more messages may arrive
+    assert isinstance(
+        data_messages_received, list
+    ), "Data message collection should be a list"
+
+    print(
+        f"\n✓ Data channel test passed: Received {len(data_messages_received)} messages"
+    )
diff --git a/tests/inference_sdk/integration_tests/__init__.py b/tests/inference_sdk/integration_tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/inference_sdk/integration_tests/webrtc/__init__.py b/tests/inference_sdk/integration_tests/webrtc/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/inference_sdk/integration_tests/webrtc/conftest.py b/tests/inference_sdk/integration_tests/webrtc/conftest.py
new file mode 100644
index 0000000000..127f0f42f0
--- /dev/null
+++ b/tests/inference_sdk/integration_tests/webrtc/conftest.py
@@ -0,0 +1,235 @@
+"""Shared fixtures for WebRTC integration tests."""
+
+import asyncio
+import json
+from fractions import Fraction
+from pathlib import Path
+from typing import Any, Dict, Optional
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
+
+import numpy as np
+import pytest
+from av import VideoFrame
+
+
+@pytest.fixture
+def mock_video_capture():
+    """Mock cv2.VideoCapture to return synthetic frames."""
+
+    class MockVideoCapture:
+        def __init__(self, device_id):
+            self.device_id = device_id
+            self._frame_count = 0
+            self._is_opened = True
+            self._width = 640
+            self._height = 480
+            self._fps = 30.0
+
+        def isOpened(self):
+            return self._is_opened
+
+        def read(self):
+            if not self._is_opened or self._frame_count >= 100:  # Limit frames
+                return False, None
+
+            # Generate synthetic BGR frame
+            frame = np.random.randint(0, 255, (self._height, self._width, 3), dtype=np.uint8)
+            self._frame_count += 1
+            return True, frame
+
+        def set(self, prop, value):
+            if prop == 3:  # CAP_PROP_FRAME_WIDTH
+                self._width = int(value)
+            elif prop == 4:  # CAP_PROP_FRAME_HEIGHT
+                self._height = int(value)
+            return True
+
+        def get(self, prop):
+            if prop == 3:  # CAP_PROP_FRAME_WIDTH
+                return self._width
+            elif prop == 4:  # CAP_PROP_FRAME_HEIGHT
+                return self._height
+            elif prop == 5:  # CAP_PROP_FPS
+                return self._fps
+            return 0
+
+        def release(self):
+            self._is_opened = False
+
+    with patch("cv2.VideoCapture", MockVideoCapture):
+        yield MockVideoCapture
+
+
+@pytest.fixture
+def mock_rtc_peer_connection():
+    """Mock RTCPeerConnection and related WebRTC components."""
+
+    class MockRTCPeerConnection:
+        def __init__(self, configuration=None):
+            self.configuration = configuration
+            self.iceGatheringState = "complete"
+            self.localDescription = Mock()
+            self.localDescription.type = "offer"
+            self.localDescription.sdp = "mock-sdp-offer"
+            self._tracks = []
+            self._transceivers = []
+            self._data_channels = {}
+            self._track_handlers = []
+
+        def addTrack(self, track):
+            self._tracks.append(track)
+
+        def addTransceiver(self, kind, direction=None):
+            transceiver = Mock()
+            transceiver.kind = kind
+            transceiver.direction = direction
+            self._transceivers.append(transceiver)
+            return transceiver
+
+        def createDataChannel(self, label):
+            channel = MockRTCDataChannel(label)
+            self._data_channels[label] = channel
+            return channel
+
+        async def createOffer(self):
+            return self.localDescription
+
+        async def setLocalDescription(self, desc):
+            self.localDescription = desc
+
+        async def setRemoteDescription(self, desc):
+            pass
+
+        async def close(self):
+            pass
+
+        def on(self, event):
+            """Decorator for event handlers."""
+            def decorator(func):
+                if event == "track":
+                    self._track_handlers.append(func)
+                return func
+            return decorator
+
+        def simulate_incoming_track(self, num_frames=5):
+            """Helper to simulate incoming video track for testing."""
+            mock_track = MockVideoTrack(num_frames=num_frames)
+            for handler in self._track_handlers:
+                handler(mock_track)
+            return mock_track
+
+    class MockRTCDataChannel:
+        def __init__(self, label):
+            self.label = label
+            self._message_handlers = []
+            self.readyState = "open"
+
+        def on(self, event):
+            """Decorator for event handlers."""
+            def decorator(func):
+                if event == "message":
+                    self._message_handlers.append(func)
+                return func
+            return decorator
+
+        def send_message(self, message):
+            """Helper to simulate incoming message for testing."""
+            for handler in self._message_handlers:
+                handler(message)
+
+    class MockVideoTrack:
+        """Mock video track that provides frames."""
+        def __init__(self, num_frames=5):
+            self.num_frames = num_frames
+            self._frame_idx = 0
+
+        async def recv(self):
+            if self._frame_idx >= self.num_frames:
+                raise Exception("Track ended")
+
+            # Create synthetic video frame
+            arr = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
+            frame = VideoFrame.from_ndarray(arr, format="bgr24")
+            frame.pts = self._frame_idx
+            frame.time_base = Fraction(1, 30)  # PyAV expects a Fraction, not a string
+            self._frame_idx += 1
+            return frame
+
+    with patch("aiortc.RTCPeerConnection", MockRTCPeerConnection), \
+         patch("aiortc.contrib.media.MediaRelay") as mock_relay:
+
+        # Mock MediaRelay to pass through the track
+        mock_relay_instance = Mock()
+        mock_relay_instance.subscribe = lambda track: track
+        mock_relay.return_value = mock_relay_instance
+
+        yield MockRTCPeerConnection
+
+
+@pytest.fixture
+def mock_server_endpoints():
+    """Mock server HTTP endpoints for worker initialization."""
+
+    with patch("requests.post") as mock_post:
+
+        # Mock worker initialization response
+        def post_side_effect(url, **kwargs):
+            response = Mock()
+            if "initialise_webrtc_worker" in url:
+                response.json.return_value = {
+                    "sdp": "mock-sdp-answer",
+                    "type": "answer"
+                }
+                response.raise_for_status = Mock()
+            else:
+                response.raise_for_status = Mock(side_effect=Exception("Not found"))
+            return response
+
+        mock_post.side_effect = post_side_effect
+
+        yield {"post": mock_post}
+
+
+@pytest.fixture
+def test_video_path():
+    """Path to the test video file."""
+    return Path(__file__).parent / "data" / "test_video.mp4"
+
+
+@pytest.fixture
+def sample_workflow_config():
+    """Sample workflow configuration for testing."""
+    return {
+        "workspace_name": "test-workspace",
+        "workflow_id": "test-workflow"
+    }
+
+
+@pytest.fixture
+def sample_stream_config():
+    """Sample stream configuration for testing."""
+    from inference_sdk.webrtc import StreamConfig
+
+    return StreamConfig(
+        stream_output=["image_output"],
+        data_output=["predictions"],
+        realtime_processing=True
+    )
+
+
+@pytest.fixture
+def mock_inference_client():
+    """Mock InferenceHTTPClient for testing."""
+    client = Mock()
+    client.api_url = "http://test-server.com"
+    client.api_key = "test-api-key"
+    return client
+
+
+@pytest.fixture
+def enable_all_mocks(mock_video_capture, mock_rtc_peer_connection, mock_server_endpoints):
+    """Convenience fixture to enable all common mocks at once."""
+    return {
+        "video_capture": mock_video_capture,
+        "peer_connection": mock_rtc_peer_connection,
+        "server_endpoints": mock_server_endpoints
+    }
diff --git a/tests/inference_sdk/integration_tests/webrtc/data/test_video.mp4 b/tests/inference_sdk/integration_tests/webrtc/data/test_video.mp4
new file mode 100644
index 0000000000..adba27180e
Binary files /dev/null and b/tests/inference_sdk/integration_tests/webrtc/data/test_video.mp4 differ
diff --git a/tests/inference_sdk/integration_tests/webrtc/test_video_file_integration.py b/tests/inference_sdk/integration_tests/webrtc/test_video_file_integration.py
new file mode 100644
index 0000000000..abfb3267a1
--- /dev/null
+++ b/tests/inference_sdk/integration_tests/webrtc/test_video_file_integration.py
@@ -0,0 +1,264 @@
+"""Integration tests for VideoFileSource.
+
+These tests use a real test video file but mock the server endpoints
+and WebRTC connection to enable testing without a running server.
+"""
+
+import time
+from pathlib import Path
+
+import cv2
+import numpy as np
+import pytest
+
+from inference_sdk.webrtc import StreamConfig, VideoFileSource, WebRTCSession
+
+
+def test_video_file_session_basic(
+    enable_all_mocks,
+    test_video_path,
+    sample_workflow_config,
+    sample_stream_config
+):
+    """Test basic video file session with real file.
+
+    Validates that:
+    1. Video file can be opened
+    2. Session starts successfully
+    3. Video properties are detected correctly
+    """
+    assert test_video_path.exists(), f"Test video not found: {test_video_path}"
+
+    source = VideoFileSource(str(test_video_path))
+
+    with WebRTCSession(
+        api_url="http://test-server.com",
+        api_key="test-key",
+        source=source,
+        image_input_name="image",
+        workflow_config=sample_workflow_config,
+        stream_config=sample_stream_config
+    ) as session:
+        session._ensure_started()
+        # Verify track was created
+        assert source._track is not None
+        assert source._track._container is not None
+        assert source._track._stream is not None
+
+
+def test_video_file_fps_detection(
+    enable_all_mocks,
+    test_video_path,
+    sample_workflow_config,
+    sample_stream_config
+):
+    """Test that video file FPS is detected correctly.
+
+    Validates that FPS from the video file is read by the track.
+    """
+    source = VideoFileSource(str(test_video_path))
+
+    with WebRTCSession(
+        api_url="http://test-server.com",
+        api_key="test-key",
+        source=source,
+        image_input_name="image",
+        workflow_config=sample_workflow_config,
+        stream_config=sample_stream_config
+    ) as session:
+        session._ensure_started()
+        # Verify track was created and FPS was detected
+        assert source._track is not None
+        fps = source._track.get_declared_fps()
+        assert fps is not None, "Should detect FPS from video file"
+        assert fps > 0, "FPS should be positive"
+
+
+def test_video_file_with_stream_config_variations(
+    enable_all_mocks,
+    test_video_path,
+    sample_workflow_config
+):
+    """Test video file with different StreamConfig options.
+
+    Validates that different configuration options are properly handled.
+    """
+    # Test with data output only (no stream output)
+    config_data_only = StreamConfig(
+        stream_output=[],
+        data_output=["results"],
+        realtime_processing=False
+    )
+
+    source = VideoFileSource(str(test_video_path))
+
+    with WebRTCSession(
+        api_url="http://test-server.com",
+        api_key="test-key",
+        source=source,
+        image_input_name="image",
+        workflow_config=sample_workflow_config,
+        stream_config=config_data_only
+    ) as session:
+        # Session should start successfully even without stream output
+        assert session._config.stream_output == []
+        assert session._config.data_output == ["results"]
+        assert session._config.realtime_processing is False
+
+
+def test_video_file_with_workflow_parameters(
+    enable_all_mocks,
+    test_video_path,
+    sample_workflow_config
+):
+    """Test video file with workflow parameters.
+
+    Validates that workflow parameters are included in the configuration.
+    """
+    config = StreamConfig(
+        stream_output=["image_output"],
+        data_output=["predictions"],
+        workflow_parameters={
+            "confidence_threshold": 0.5,
+            "iou_threshold": 0.3
+        }
+    )
+
+    source = VideoFileSource(str(test_video_path))
+
+    with WebRTCSession(
+        api_url="http://test-server.com",
+        api_key="test-key",
+        source=source,
+        image_input_name="image",
+        workflow_config=sample_workflow_config,
+        stream_config=config
+    ) as session:
+        # Verify workflow parameters are set
+        assert session._config.workflow_parameters == {
+            "confidence_threshold": 0.5,
+            "iou_threshold": 0.3
+        }
+
+
+def test_video_file_cleanup(
+    enable_all_mocks,
+    test_video_path,
+    sample_workflow_config,
+    sample_stream_config
+):
+    """Test that video file resources are properly cleaned up.
+
+    Validates that PyAV container is released on session exit.
+ """ + source = VideoFileSource(str(test_video_path)) + + with WebRTCSession( + api_url="http://test-server.com", + api_key="test-key", + source=source, + image_input_name="image", + workflow_config=sample_workflow_config, + stream_config=sample_stream_config + ) as session: + session._ensure_started() + track = source._track + assert track._container is not None, "Container should exist during session" + # PyAV containers don't have a simple "is_open" check, but we can verify it exists + + # After exiting context, container should be closed (we verify cleanup was called) + # Note: PyAV containers don't expose a simple "is_closed" property, + # but attempting to use them after close will raise an error + + +def test_video_file_real_properties(test_video_path): + """Test reading actual properties from the test video file. + + This test doesn't mock VideoCapture to verify the actual test video + has expected properties. + """ + assert test_video_path.exists(), f"Test video not found: {test_video_path}" + + cap = cv2.VideoCapture(str(test_video_path)) + assert cap.isOpened(), "Should be able to open test video" + + # Read properties + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = cap.get(cv2.CAP_PROP_FPS) + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + # Verify expected properties + assert width == 640, f"Expected width 640, got {width}" + assert height == 480, f"Expected height 480, got {height}" + assert fps == 30.0, f"Expected FPS 30, got {fps}" + assert frame_count == 10, f"Expected 10 frames, got {frame_count}" + + # Read all frames + frames_read = 0 + while True: + ret, frame = cap.read() + if not ret: + break + frames_read += 1 + assert frame.shape == (480, 640, 3), "Frame should be 480x640x3" + + assert frames_read == 10, f"Should read 10 frames, got {frames_read}" + + cap.release() + + +def test_video_file_with_data_channel( + enable_all_mocks, + test_video_path, + sample_workflow_config +): + """Test video file processing with data channel output. + + Simulates a batch processing scenario where results are collected + via data channel. 
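+
+    The simulated channel payload mirrors the envelope used by the server,
+    e.g. {"serialized_output_data": {"analysis_results": {...}}}.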
+ """ + config = StreamConfig( + data_output=["analysis_results"], + stream_output=[] # No video output needed for batch processing + ) + + source = VideoFileSource(str(test_video_path)) + results = [] + + session = WebRTCSession( + api_url="http://test-server.com", + api_key="test-key", + source=source, + image_input_name="image", + workflow_config=sample_workflow_config, + stream_config=config + ) + + # Register handler to collect results + @session.on_data("analysis_results") + def collect_results(data): + results.append(data) + + session._ensure_started() + + try: + + # Get the data channel + data_channel = session._pc._data_channels["inference"] + + # Simulate receiving results for each frame + for i in range(5): + data_channel.send_message( + f'{{"serialized_output_data": {{"analysis_results": {{"frame": {i}, "detections": []}}}}}}' + ) + + # Give handlers time to process + time.sleep(0.1) + + # Validate results were collected + assert len(results) == 5 + for i, result in enumerate(results): + assert result["frame"] == i + finally: + session.close() diff --git a/tests/inference_sdk/unit_tests/webrtc/__init__.py b/tests/inference_sdk/unit_tests/webrtc/__init__.py new file mode 100644 index 0000000000..a82323d802 --- /dev/null +++ b/tests/inference_sdk/unit_tests/webrtc/__init__.py @@ -0,0 +1 @@ +"""Unit tests for WebRTC SDK.""" diff --git a/tests/inference_sdk/unit_tests/webrtc/test_session_lifecycle.py b/tests/inference_sdk/unit_tests/webrtc/test_session_lifecycle.py new file mode 100644 index 0000000000..87029912ea --- /dev/null +++ b/tests/inference_sdk/unit_tests/webrtc/test_session_lifecycle.py @@ -0,0 +1,342 @@ +"""Unit tests for WebRTC session lifecycle management.""" + +import pytest +from datetime import datetime +from unittest.mock import MagicMock, patch +import numpy as np + +from inference_sdk.webrtc.session import VideoMetadata, WebRTCSession, SessionState + + +@pytest.fixture +def mock_session(): + """Create a mock WebRTCSession instance without actually initializing WebRTC.""" + with patch("inference_sdk.webrtc.session._check_webrtc_dependencies"): + session = WebRTCSession( + api_url="http://localhost:9001", + api_key="test_key", + source=MagicMock(), + image_input_name="image", + workflow_config={}, + stream_config=MagicMock(), + ) + return session + + +class TestSessionLifecycle: + """Tests for session lifecycle (creation, starting, closing).""" + + def test_session_starts_in_not_started_state(self, mock_session): + """Test that session is created in not_started state.""" + assert mock_session._state == SessionState.NOT_STARTED + + def test_close_is_idempotent(self, mock_session): + """Test that close() can be called multiple times safely.""" + mock_session._state = SessionState.STARTED # Simulate started state + mock_session.close() + assert mock_session._state == SessionState.CLOSED + + # Second call should be a no-op + mock_session.close() + assert mock_session._state == SessionState.CLOSED + + def test_ensure_started_changes_state(self, mock_session): + """Test that _ensure_started() transitions from not_started to started.""" + with patch.object(mock_session, "_init_connection"): + assert mock_session._state == SessionState.NOT_STARTED + mock_session._ensure_started() + assert mock_session._state == SessionState.STARTED + + def test_ensure_started_is_idempotent(self, mock_session): + """Test that _ensure_started() can be called multiple times.""" + with patch.object(mock_session, "_init_connection") as mock_init: + mock_session._ensure_started() + 
mock_session._ensure_started() + mock_session._ensure_started() + + # _init_connection should only be called once + assert mock_init.call_count == 1 + + def test_ensure_started_raises_on_closed_session(self, mock_session): + """Test that _ensure_started() raises error if session is closed.""" + mock_session._state = SessionState.CLOSED + + with pytest.raises(RuntimeError, match="Cannot use closed WebRTCSession"): + mock_session._ensure_started() + + +class TestRunMethod: + """Tests for run() method and exception handling.""" + + def test_run_auto_starts_session(self, mock_session): + """Test that run() automatically starts the session.""" + with patch.object(mock_session, "_ensure_started") as mock_ensure: + # Put a frame and immediately close + @mock_session.on_frame + def handler(frame, metadata): + mock_session.close() + + test_frame = np.zeros((100, 100, 3), dtype=np.uint8) + test_metadata = VideoMetadata(frame_id=1, received_at=datetime.now()) + mock_session._video_queue.put((test_frame, test_metadata)) + mock_session._state = SessionState.STARTED + + mock_session.run() + + # Should have called _ensure_started + mock_ensure.assert_called_once() + + def test_run_stops_when_close_called(self, mock_session): + """Test that run() stops when close() is called from handler.""" + frame_count = [] + + @mock_session.on_frame + def count_frames(frame, metadata): + frame_count.append(1) + if len(frame_count) >= 2: + mock_session.close() + + # Put multiple frames in queue (use put_nowait to avoid blocking on full queue) + for i in range(5): + test_frame = np.zeros((100, 100, 3), dtype=np.uint8) + test_metadata = VideoMetadata(frame_id=i, received_at=datetime.now()) + mock_session._video_queue.put_nowait((test_frame, test_metadata)) + + # Mock state as started + mock_session._state = SessionState.STARTED + + mock_session.run() + + # Should have stopped after 2 frames (when close() was called) + assert len(frame_count) == 2 + + def test_run_handles_handler_exceptions_gracefully(self, mock_session): + """Test that exceptions in handlers don't crash run().""" + handler1_calls = [] + handler2_calls = [] + + @mock_session.on_frame + def failing_handler(frame, metadata): + handler1_calls.append(True) + raise ValueError("Handler error") + + @mock_session.on_frame + def working_handler(frame, metadata): + handler2_calls.append(True) + mock_session.close() + + # Put a frame in queue + test_frame = np.zeros((100, 100, 3), dtype=np.uint8) + test_metadata = VideoMetadata(frame_id=1, received_at=datetime.now()) + mock_session._video_queue.put((test_frame, test_metadata)) + + mock_session._state = SessionState.STARTED + + # Run should not raise despite first handler failing + mock_session.run() + + # Both handlers should have been called + assert len(handler1_calls) == 1 + assert len(handler2_calls) == 1 + + def test_run_closes_session_on_exception(self, mock_session): + """Test that run() closes session if exception occurs.""" + mock_session._state = SessionState.STARTED + + # Mock video() to raise an exception + def raise_exception(): + raise RuntimeError("Test error") + yield # Never reached + + with patch.object(mock_session, "video", return_value=raise_exception()): + with patch.object(mock_session, "close") as mock_close: + with pytest.raises(RuntimeError, match="Test error"): + mock_session.run() + + # Should have called close() + mock_close.assert_called_once() + + def test_run_closes_session_on_keyboard_interrupt(self, mock_session): + """Test that run() closes session on Ctrl+C.""" + 
mock_session._state = SessionState.STARTED + + # Mock video() to raise KeyboardInterrupt + def raise_interrupt(): + raise KeyboardInterrupt() + yield # Never reached + + with patch.object(mock_session, "video", return_value=raise_interrupt()): + with patch.object(mock_session, "close") as mock_close: + with pytest.raises(KeyboardInterrupt): + mock_session.run() + + # Should have called close() + mock_close.assert_called_once() + + +class TestDecorators: + """Tests for decorator registration.""" + + def test_on_frame_registration(self, mock_session): + """Test that on_frame decorator registers handler.""" + handler_called = [] + + @mock_session.on_frame + def process_frame(frame, metadata): + handler_called.append((frame, metadata)) + + assert len(mock_session._frame_handlers) == 1 + assert mock_session._frame_handlers[0] == process_frame + + def test_on_frame_multiple_handlers(self, mock_session): + """Test registering multiple frame handlers.""" + + @mock_session.on_frame + def handler1(frame, metadata): + pass + + @mock_session.on_frame + def handler2(frame, metadata): + pass + + assert len(mock_session._frame_handlers) == 2 + assert handler1 in mock_session._frame_handlers + assert handler2 in mock_session._frame_handlers + + def test_on_data_global_handler(self, mock_session): + """Test registering global data handler.""" + + @mock_session.on_data() + def handle_data(data, metadata): + pass + + assert mock_session._data_global_handler == handle_data + + def test_on_data_field_specific_handler(self, mock_session): + """Test registering field-specific data handler.""" + + @mock_session.on_data("predictions") + def handle_predictions(value, metadata): + pass + + assert "predictions" in mock_session._data_field_handlers + assert handle_predictions in mock_session._data_field_handlers["predictions"] + + +class TestVideoStream: + """Tests for video stream iterator.""" + + def test_video_auto_starts_session(self, mock_session): + """Test that video() automatically starts the session.""" + with patch.object(mock_session, "_ensure_started") as mock_ensure: + # Put a frame and end signal + test_frame = np.zeros((100, 100, 3), dtype=np.uint8) + test_metadata = VideoMetadata(frame_id=1, received_at=datetime.now()) + mock_session._video_queue.put((test_frame, test_metadata)) + mock_session._video_queue.put(None) # End stream + + # Iterate (should auto-start) + list(mock_session.video()) + + # Should have called _ensure_started + mock_ensure.assert_called_once() + + def test_video_yields_frame_tuples(self, mock_session): + """Test that video() yields (frame, metadata) tuples.""" + # Put test frames in queue + test_frame1 = np.zeros((100, 100, 3), dtype=np.uint8) + test_metadata1 = VideoMetadata(frame_id=1, received_at=datetime.now()) + + test_frame2 = np.ones((100, 100, 3), dtype=np.uint8) + test_metadata2 = VideoMetadata(frame_id=2, received_at=datetime.now()) + + mock_session._video_queue.put((test_frame1, test_metadata1)) + mock_session._video_queue.put((test_frame2, test_metadata2)) + mock_session._video_queue.put(None) # End stream + + # Mock _ensure_started + with patch.object(mock_session, "_ensure_started"): + # Iterate and collect + frames = [] + metadatas = [] + for frame, metadata in mock_session.video(): + frames.append(frame) + metadatas.append(metadata) + + assert len(frames) == 2 + assert np.array_equal(frames[0], test_frame1) + assert np.array_equal(frames[1], test_frame2) + + assert len(metadatas) == 2 + assert metadatas[0].frame_id == 1 + assert metadatas[1].frame_id == 2 + + 
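+# For orientation, the consumer-facing pattern these queue-driven tests emulate
+# is roughly the sketch below (workflow spec and output names are placeholders):
+#
+#     config = StreamConfig(stream_output=["image_output"])
+#     with client.webrtc.stream(source=VideoFileSource("video.mp4"),
+#                               workflow=workflow_spec, image_input="image",
+#                               config=config) as session:
+#         for frame, metadata in session.video():
+#             ...  # iteration ends when the stream closes (None sentinel)
+
+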
+class TestWaitMethod: + """Tests for wait() method.""" + + def test_wait_auto_starts_session(self, mock_session): + """Test that wait() automatically starts the session.""" + with patch.object(mock_session, "_ensure_started") as mock_ensure: + # Put a frame and end signal + test_frame = np.zeros((100, 100, 3), dtype=np.uint8) + test_metadata = VideoMetadata(frame_id=1, received_at=datetime.now()) + + mock_session._video_queue.put((test_frame, test_metadata)) + mock_session._video_queue.put(None) # End stream + + # Should not raise and should consume all frames + mock_session.wait() + + # Should have called _ensure_started + mock_ensure.assert_called_once() + + def test_wait_blocks_until_stream_ends(self, mock_session): + """Test that wait() blocks until None is received.""" + with patch.object(mock_session, "_ensure_started"): + # Put frames in queue + test_frame = np.zeros((100, 100, 3), dtype=np.uint8) + test_metadata = VideoMetadata(frame_id=1, received_at=datetime.now()) + + mock_session._video_queue.put((test_frame, test_metadata)) + mock_session._video_queue.put(None) # End stream + + # Should not raise and should consume all frames + mock_session.wait() + + def test_wait_timeout(self, mock_session): + """Test that wait() raises TimeoutError on timeout.""" + with patch.object(mock_session, "_ensure_started"): + # Put a frame but no end signal + test_frame = np.zeros((100, 100, 3), dtype=np.uint8) + test_metadata = VideoMetadata(frame_id=1, received_at=datetime.now()) + mock_session._video_queue.put((test_frame, test_metadata)) + + with pytest.raises(TimeoutError, match="timed out"): + mock_session.wait(timeout=0.1) + + +class TestCloseMethod: + """Tests for close() method.""" + + def test_close_can_be_called_from_handler(self, mock_session): + """Test that close() can be called from within a frame handler.""" + calls = [] + + @mock_session.on_frame + def handler(frame, metadata): + calls.append(1) + mock_session.close() + + # Put frames in queue (use put_nowait to avoid blocking on full queue) + for i in range(5): + test_frame = np.zeros((100, 100, 3), dtype=np.uint8) + test_metadata = VideoMetadata(frame_id=i, received_at=datetime.now()) + mock_session._video_queue.put_nowait((test_frame, test_metadata)) + + mock_session._state = SessionState.STARTED + mock_session.run() + + # Should have stopped after first frame + assert len(calls) == 1 + assert mock_session._state == SessionState.CLOSED
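+
+
+class TestManualSourceSketch:
+    """Illustrative check for ManualSource (sketch; assumes ManualSource is
+    exported from inference_sdk.webrtc like the other sources in this diff)."""
+
+    def test_send_before_start_raises(self):
+        from inference_sdk.webrtc import ManualSource  # assumed export
+
+        source = ManualSource()
+        # send() requires a track, which only exists once the session has
+        # configured the peer connection
+        with pytest.raises(RuntimeError, match="Session not started"):
+            source.send(np.zeros((10, 10, 3), dtype=np.uint8))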