diff --git a/examples/webrtc/webcam.py b/examples/webrtc/webcam.py
index 5b9f752a9a..c156c16336 100644
--- a/examples/webrtc/webcam.py
+++ b/examples/webrtc/webcam.py
@@ -423,7 +423,7 @@ def start_loop(loop: asyncio.AbstractEventLoop):
                 message = json.dumps(
                     WebRTCData(
                         stream_output=None,
-                        data_output="",
+                        data_output=[],
                     ).model_dump()
                 )
                 logger.info("Turning off data output via data channel")
@@ -436,7 +436,7 @@ def start_loop(loop: asyncio.AbstractEventLoop):
                 message = json.dumps(
                     WebRTCData(
                         stream_output=None,
-                        data_output=output_name,
+                        data_output=[output_name],
                     ).model_dump()
                 )
                 logger.info("Setting data output via data channel: %s", output_name)
diff --git a/examples/webrtc/webrtc_worker.py b/examples/webrtc/webrtc_worker.py
index 3b21804953..1823679e66 100644
--- a/examples/webrtc/webrtc_worker.py
+++ b/examples/webrtc/webrtc_worker.py
@@ -271,12 +271,51 @@ def parse_args() -> argparse.Namespace:
 
 def main():
     args = parse_args()
+    logger.info(f"Starting WebRTC worker with output_mode: {args.output_mode}")
+    if args.output_mode == "data_only":
+        logger.info(
+            "DATA_ONLY mode: Server will send JSON data via data channel only (no video track)"
+        )
+    elif args.output_mode == "video_only":
+        logger.info(
+            "VIDEO_ONLY mode: Server will send processed video only (no data channel messages)"
+        )
+    elif args.output_mode == "both":
+        logger.info("BOTH mode: Server will send both video and JSON data")
+
     workflow_specification = get_workflow_specification(
         api_key=args.api_key,
         workspace_id=args.workspace_id,
         workflow_id=args.workflow_id,
     )
+
+    # Find available outputs
+    workflow_outputs = workflow_specification.get("outputs", [])
+    available_output_names = [o.get("name") for o in workflow_outputs]
+
+    if not workflow_outputs:
+        logger.warning("⚠️ Workflow has no outputs defined")
+    else:
+        logger.info(f"Available workflow outputs: {available_output_names}")
+
+    # Determine data_output
+    data_output_to_use = None
+    if args.data_outputs:
+        if args.data_outputs.lower() == "all":
+            data_output_to_use = None  # None = all outputs
+            logger.info("data_output: ALL outputs (None)")
+        elif args.data_outputs.lower() == "none":
+            data_output_to_use = []  # Empty = no data
+            logger.info("data_output: NO outputs ([])")
+        else:
+            requested_fields = [f.strip() for f in args.data_outputs.split(",")]
+            data_output_to_use = requested_fields
+            logger.info(f"data_output: {data_output_to_use}")
+    else:
+        # Default: send all outputs
+        data_output_to_use = None
+        logger.info("data_output: ALL outputs (default)")
+
     webrtc_turn_config = None
     if args.turn_url:
         webrtc_turn_config = WebRTCTURNConfig(
@@ -316,8 +355,9 @@ def main():
             sdp=peer_connection.localDescription.sdp,
         ),
         webrtc_turn_config=webrtc_turn_config,
-        stream_output=["video"],
-        data_output=["preds"],
+        output_mode=args.output_mode,
+        stream_output=args.stream_output if args.stream_output else None,
+        data_output=data_output_to_use,
         webrtc_realtime_processing=args.realtime,
         rtsp_url=args.source if is_rtmp_url(args.source) else None,
         processing_timeout=args.processing_timeout,
@@ -340,6 +380,46 @@ def main():
     if response.status_code != 200:
         raise Exception(f"Failed to initialise WebRTC pipeline: {response.text}")
 
+    # Set up data channel listener for JSON data
+    data_channel_message_count = [0]  # Use list for closure
+    if peer_connection.data_channel:
+
+        @peer_connection.data_channel.on("message")
+        def on_data_message(message):
+            data_channel_message_count[0] += 1
+            try:
+                data = json.loads(message)
+                logger.info(
+                    f"=== Data Channel Message #{data_channel_message_count[0]} ==="
+                )
+                logger.info(
+                    f"Frame ID: {data.get('video_metadata', {}).get('frame_id')}"
+                )
+
+                if data.get("serialized_output_data"):
+                    logger.info(
+                        f"Output fields: {list(data['serialized_output_data'].keys())}"
+                    )
+                    for field, value in data["serialized_output_data"].items():
+                        if isinstance(value, str) and value.startswith("data:image"):
+                            logger.info(
+                                f"  {field}: <base64 image, {len(value)} chars>"
+                            )
+                        elif isinstance(value, dict) and "predictions" in value:
+                            logger.info(
+                                f"  {field}: {len(value.get('predictions', []))} detections"
+                            )
+                        else:
+                            logger.info(f"  {field}: {value}")
+                else:
+                    logger.info("  No output data")
+
+                if data.get("errors"):
+                    logger.warning(f"Errors: {data['errors']}")
+
+            except json.JSONDecodeError:
+                logger.error(f"Failed to parse message: {message[:200]}...")
+
     future = asyncio.run_coroutine_threadsafe(
         peer_connection.setRemoteDescription(
             RTCSessionDescription(sdp=webrtc_answer["sdp"], type=webrtc_answer["type"])
@@ -348,28 +428,201 @@ def main():
     )
     future.result()
 
+    # Track active data output mode: "all" (None), "none" ([]), or list of field names
+    active_data_fields = []  # Initialize for custom mode
+    if data_output_to_use is None:
+        active_data_mode = "all"  # "all" means None
+    elif data_output_to_use == []:
+        active_data_mode = "none"  # "none" means []
+    else:
+        active_data_mode = "custom"  # Custom list
+        active_data_fields = list(data_output_to_use)  # Copy of active fields
+
+    def draw_mode_indicator(frame, mode_text):
+        """Draw mode indicator overlay (top-left)"""
+        font = cv.FONT_HERSHEY_SIMPLEX
+        font_scale = 0.7
+        font_thickness = 2
+
+        # Get text size to draw proper background
+        (text_width, text_height), baseline = cv.getTextSize(
+            mode_text, font, font_scale, font_thickness
+        )
+
+        # Draw background rectangle
+        padding = 10
+        bg_x1, bg_y1 = 10, 10
+        bg_x2, bg_y2 = 10 + text_width + padding * 2, 10 + text_height + padding * 2
+
+        cv.rectangle(frame, (bg_x1, bg_y1), (bg_x2, bg_y2), (0, 0, 0), -1)
+
+        # Draw text
+        text_x = bg_x1 + padding
+        text_y = bg_y1 + padding + text_height
+        cv.putText(
+            frame,
+            mode_text,
+            (text_x, text_y),
+            font,
+            font_scale,
+            (100, 255, 100),  # Brighter green
+            font_thickness,
+            cv.LINE_AA,
+        )
+
+        return frame
+
+    def draw_controls_hint(frame, controls_text):
+        """Draw controls hint overlay (bottom)"""
+        font = cv.FONT_HERSHEY_SIMPLEX
+        controls_font_scale = 0.45
+        controls_thickness = 1
+
+        h = frame.shape[0]
+
+        (ctrl_width, ctrl_height), ctrl_baseline = cv.getTextSize(
+            controls_text, font, controls_font_scale, controls_thickness
+        )
+
+        # Draw background for controls
+        ctrl_padding = 8
+        ctrl_bg_x1 = 10
+        ctrl_bg_y1 = h - ctrl_height - ctrl_padding * 2 - 10
+        ctrl_bg_x2 = ctrl_bg_x1 + ctrl_width + ctrl_padding * 2
+        ctrl_bg_y2 = h - 10
+
+        cv.rectangle(
+            frame, (ctrl_bg_x1, ctrl_bg_y1), (ctrl_bg_x2, ctrl_bg_y2), (0, 0, 0), -1
+        )
+
+        # Draw controls text
+        ctrl_text_x = ctrl_bg_x1 + ctrl_padding
+        ctrl_text_y = ctrl_bg_y2 - ctrl_padding - ctrl_baseline
+        cv.putText(
+            frame,
+            controls_text,
+            (ctrl_text_x, ctrl_text_y),
+            font,
+            controls_font_scale,
+            (200, 200, 200),  # Light gray
+            controls_thickness,
+            cv.LINE_AA,
+        )
+
+        return frame
+
+    def draw_output_list(frame, available_outputs, current_mode, active_fields=None):
+        """Draw list of available outputs with active indicators"""
+        font = cv.FONT_HERSHEY_SIMPLEX
+        x_start = 10
+        y_start = 80
+        line_height = 22
+
+        # Title
+        if current_mode == "all":
+            title = "Data Outputs (ALL)"
+            title_color = (100, 255, 100)
+        elif current_mode == "none":
+            title = "Data Outputs (NONE)"
+            title_color = (100, 100, 100)
+        else:
+            title = f"Data Outputs ({len(active_fields)} active)"
+            title_color = (100, 200, 255)
+
+        cv.putText(
+            frame, title, (x_start, y_start), font, 0.5, title_color, 1, cv.LINE_AA
+        )
+        y_start += line_height + 5
+
+        # Draw each output
+        for i, output in enumerate(available_outputs):
+            key_letter = chr(ord("a") + i) if i < 26 else "?"
+            output_name = output.get("name", "unnamed")
+
+            # Determine if active
+            if current_mode == "all":
+                is_active = True
+            elif current_mode == "none":
+                is_active = False
+            else:
+                is_active = output_name in active_fields
+
+            # Format line with ASCII checkbox
+            indicator = "[X]" if is_active else "[ ]"
+            color = (100, 255, 100) if is_active else (100, 100, 100)
+            text = f" [{key_letter}] {indicator} {output_name}"
+
+            cv.putText(
+                frame,
+                text,
+                (x_start, y_start + i * line_height),
+                font,
+                0.45,
+                color,
+                1,
+                cv.LINE_AA,
+            )
+
+        # Controls
+        y_controls = y_start + len(available_outputs) * line_height + 10
+        cv.putText(
+            frame,
+            " [+] All  [-] None",
+            (x_start, y_controls),
+            font,
+            0.45,
+            (200, 200, 200),
+            1,
+            cv.LINE_AA,
+        )
+
+        return frame
+
+    def handle_keyboard_input(key: int) -> bool:
+        nonlocal active_data_mode
         if key == -1:
-            continue
+            return True
 
         if key == ord("q"):
             logger.info("Quitting")
-            break
+            return False
 
-        if chr(key) in "1234567890abcdefghijkz" and (
+        # Check data channel status for all commands except quit
+        if (
             not peer_connection.data_channel
             or peer_connection.data_channel.readyState != "open"
         ):
             logger.error("Data channel not open")
-            continue
+            return True
+
+        # Handle + key (all outputs)
+        if key == ord("+") or key == ord("="):
+            logger.info("Setting data_output to ALL (None)")
+            active_data_mode = "all"
+            message = json.dumps(
+                WebRTCData(
+                    stream_output=None,
+                    data_output=None,
+                ).model_dump()
+            )
+            peer_connection.data_channel.send(message)
+            return True
+
+        # Handle - key (no outputs)
+        if key == ord("-"):
+            logger.info("Setting data_output to NONE ([])")
+            active_data_mode = "none"
+            message = json.dumps(
+                WebRTCData(
+                    stream_output=None,
+                    data_output=[],
+                ).model_dump()
+            )
+            peer_connection.data_channel.send(message)
+            return True
 
+        # Handle 0-9 keys (stream output selection)
         if chr(key) in "1234567890":
             if chr(key) == "0":
                 message = json.dumps(
@@ -393,36 +646,132 @@ def main():
                 )
                 logger.info("Setting stream output via data channel: %s", output_name)
             peer_connection.data_channel.send(message)
-
-        if chr(key) in "abcdefghijkz":
-            if chr(key) == "z":
+            return True
+
+        # Handle a-z toggle (individual field toggle)
+        if chr(key).isalpha() and chr(key).lower() in "abcdefghijklmnopqrstuvwxyz":
+            key_index = ord(chr(key).lower()) - ord("a")
+            if key_index < len(workflow_outputs):
+                output_name = workflow_outputs[key_index].get("name", "")
+
+                # Toggle logic
+                if active_data_mode == "all":
+                    # Was "all", switch to custom with all except this one
+                    active_data_mode = "custom"
+                    active_data_fields.clear()
+                    active_data_fields.extend([o.get("name") for o in workflow_outputs])
+                    active_data_fields.remove(output_name)
+                    logger.info(f"Toggled OFF '{output_name}' (was ALL)")
+                elif active_data_mode == "none":
+                    # Was "none", enable only this field
+                    active_data_mode = "custom"
+                    active_data_fields.clear()
+                    active_data_fields.append(output_name)
+                    logger.info(f"Toggled ON '{output_name}' (was NONE)")
+                else:
+                    # Custom mode - toggle
+                    if output_name in active_data_fields:
+                        active_data_fields.remove(output_name)
+                        logger.info(f"Toggled OFF '{output_name}'")
+                    else:
+                        active_data_fields.append(output_name)
+                        logger.info(f"Toggled ON '{output_name}'")
+
+                # Send updated list directly as array
+                logger.info(f"Active fields: {active_data_fields}")
                 message = json.dumps(
                     WebRTCData(
                         stream_output=None,
-                        data_output="",
+                        data_output=active_data_fields if active_data_fields else [],
                     ).model_dump()
                 )
-                logger.info("Turning off data output via data channel")
-            else:
-                max_ind = max(0, len(workflow_specification.get("outputs", [])) - 1)
-                output_ind = min(key - ord("a"), max_ind)
-                output_name = workflow_specification.get("outputs")[output_ind].get(
-                    "name", ""
+                peer_connection.data_channel.send(message)
+            return True
+
+        return True
+
+    # For data_only mode, show blank window with data controls
+    if args.output_mode == "data_only":
+        logger.info("DATA_ONLY mode: Showing placeholder window with output controls")
+
+        try:
+            while not peer_connection.closed_event.is_set():
+                # Create black frame with overlays
+                frame = np.zeros((520, 700, 3), dtype=np.uint8)
+
+                mode_text = f"MODE: {args.output_mode.upper()}"
+                frame = draw_mode_indicator(frame, mode_text)
+
+                if active_data_mode == "custom":
+                    frame = draw_output_list(
+                        frame, workflow_outputs, active_data_mode, active_data_fields
+                    )
+                else:
+                    frame = draw_output_list(
+                        frame, workflow_outputs, active_data_mode, None
+                    )
+
+                controls_text = "+ = all | - = none | a-z = data | q = quit"
+                frame = draw_controls_hint(frame, controls_text)
+
+                cv.imshow("WebRTC Worker - Interactive Mode", frame)
+                key = cv.waitKey(100)  # 100ms delay to keep UI responsive
+
+                # Handle keyboard input
+                should_continue = handle_keyboard_input(key)
+                if not should_continue:
+                    break
+
+        except KeyboardInterrupt:
+            logger.info("Interrupted by user")
+    else:
+        # For modes with video, use the video frame loop
+        while not peer_connection.closed_event.is_set():
+            frame: Optional[VideoFrame] = (
+                peer_connection.stream_track.recv_queue.sync_get()
+            )
+            if frame is None:
+                logger.info("No more frames")
+                break
+
+            # Convert frame to numpy
+            np_frame = frame.to_ndarray(format="bgr24")
+
+            # Draw overlays
+            mode_text = f"MODE: {args.output_mode.upper()}"
+            np_frame = draw_mode_indicator(np_frame, mode_text)
+
+            if active_data_mode == "custom":
+                np_frame = draw_output_list(
+                    np_frame, workflow_outputs, active_data_mode, active_data_fields
                 )
-                message = json.dumps(
-                    WebRTCData(
-                        stream_output=None,
-                        data_output=output_name,
-                    ).model_dump()
+            else:
+                np_frame = draw_output_list(
+                    np_frame, workflow_outputs, active_data_mode, None
                 )
-                logger.info("Setting data output via data channel: %s", output_name)
-            peer_connection.data_channel.send(message)
-    cv.destroyAllWindows()
-    asyncio.run_coroutine_threadsafe(
-        peer_connection.stream_track.stop_recv_loop(),
-        asyncio_loop,
-    ).result()
+            controls_text = (
+                "+ = all data | - = no data | 0-9 = stream | a-z = data | q = quit"
+            )
+            np_frame = draw_controls_hint(np_frame, controls_text)
+
+            cv.imshow("WebRTC Worker - Interactive Mode", np_frame)
+            key = cv.waitKey(1)
+
+            # Handle keyboard input
+            should_continue = handle_keyboard_input(key)
+            if not should_continue:
+                break
+
+    # Cleanup
+    cv.destroyAllWindows()  # Close OpenCV windows (works for all modes now)
+
+    if args.output_mode != "data_only":
+        asyncio.run_coroutine_threadsafe(
+            peer_connection.stream_track.stop_recv_loop(),
+            asyncio_loop,
+        ).result()
+
     if peer_connection.connectionState != "closed":
         logger.info("Closing WebRTC connection")
         asyncio.run_coroutine_threadsafe(
diff --git a/examples/webrtc_sdk/data_only_example.py b/examples/webrtc_sdk/data_only_example.py
new file mode 100644
index 0000000000..b273081883
--- /dev/null
+++ b/examples/webrtc_sdk/data_only_example.py
@@ -0,0 +1,182 @@
+"""
+WebRTC SDK example demonstrating DATA_ONLY output mode.
+
+This example shows how to use the DATA_ONLY mode to receive only inference
+results without video feedback, which significantly reduces bandwidth usage.
+
+DATA_ONLY mode is ideal for:
+- Analytics and metrics collection
+- Headless inference servers
+- High-throughput object counting
+- Logging detections for later analysis
+- IoT devices with limited bandwidth
+
+Usage:
+    python examples/webrtc_sdk/data_only_example.py \
+        --workspace-name <workspace-name> \
+        --workflow-id <workflow-id> \
+        [--api-url http://localhost:9001] \
+        [--api-key <api-key>] \
+        [--duration 30]
+
+Press Ctrl+C to stop early.
+"""
+import argparse
+import time
+from collections import defaultdict
+from datetime import datetime
+
+from inference_sdk import InferenceHTTPClient
+from inference_sdk.webrtc import OutputMode, StreamConfig, VideoMetadata, WebcamSource
+
+
+def parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser("WebRTC SDK Data-Only Mode Example")
+    p.add_argument("--api-url", default="http://localhost:9001")
+    p.add_argument("--workspace-name", required=True)
+    p.add_argument("--workflow-id", required=True)
+    p.add_argument("--image-input-name", default="image")
+    p.add_argument("--api-key", default=None)
+    p.add_argument(
+        "--duration",
+        type=int,
+        default=30,
+        help="How long to run in seconds (default: 30)",
+    )
+    p.add_argument(
+        "--data-fields",
+        type=str,
+        default=None,
+        help="Comma-separated list of fields to receive (default: all outputs)",
+    )
+    return p.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+    client = InferenceHTTPClient.init(api_url=args.api_url, api_key=args.api_key)
+
+    # Prepare source
+    source = WebcamSource()
+
+    # Configure data output fields
+    if args.data_fields:
+        data_output = [f.strip() for f in args.data_fields.split(",")]
+    else:
+        data_output = None  # None means all workflow outputs
+
+    # Configure for DATA_ONLY mode - no video will be sent back
+    config = StreamConfig(
+        output_mode=OutputMode.DATA_ONLY,  # Only data, no video
+        data_output=data_output,  # What fields to receive
+        realtime_processing=True,  # Process frames in realtime
+    )
+
+    # Statistics tracking
+    stats = {
+        "frames_processed": 0,
+        "start_time": time.time(),
+        "detections_per_frame": [],
+        "field_counts": defaultdict(int),
+    }
+
+    print("\n" + "=" * 70)
+    print("WebRTC SDK - DATA_ONLY Mode Example")
+    print("=" * 70)
+    print(f"\nConfiguration:")
+    print(f"  Output Mode: DATA_ONLY (no video feedback)")
+    print(f"  Data Fields: {data_output if data_output else 'ALL workflow outputs'}")
+    print(f"  Duration: {args.duration} seconds")
+    print(f"  API URL: {args.api_url}")
+    print(f"\nStarting session... (Press Ctrl+C to stop early)")
+    print("-" * 70 + "\n")
+
+    # Start streaming session
+    with client.webrtc.stream(
+        source=source,
+        workflow=args.workflow_id,
+        workspace=args.workspace_name,
+        image_input=args.image_input_name,
+        config=config,
+    ) as session:
+
+        # Global data handler - receives all workflow outputs
+        @session.on_data()
+        def handle_all_data(data: dict, metadata: VideoMetadata):
+            stats["frames_processed"] += 1
+            frame_num = stats["frames_processed"]
+
+            # Track which fields we received
+            if data:
+                for field_name in data.keys():
+                    stats["field_counts"][field_name] += 1
+
+            # Print periodic updates with property_definition value
+            if frame_num % 10 == 0:
+                elapsed = time.time() - stats["start_time"]
+                fps = frame_num / elapsed if elapsed > 0 else 0
+
+                # Extract property_definition value if present
+                property_value = data.get("property_definition", "N/A") if data else "N/A"
+
+                print(
+                    f"Frame {frame_num:4d} | "
+                    f"FPS: {fps:5.1f} | "
+                    f"property_definition: {property_value} | "
+                    f"Fields: {list(data.keys()) if data else 'none'}"
+                )
+
+        # Field-specific handler for predictions (if available)
+        @session.on_data("predictions")
+        def handle_predictions(predictions: dict, metadata: VideoMetadata):
+            # Count detections
+            if isinstance(predictions, dict) and "predictions" in predictions:
+                num_detections = len(predictions["predictions"])
+                stats["detections_per_frame"].append(num_detections)
+
+                # Log significant events
+                if num_detections > 5:
+                    print(f"  → High activity: {num_detections} detections!")
+
+        # Run for specified duration
+        start_time = time.time()
+        try:
+            while time.time() - start_time < args.duration:
+                time.sleep(0.1)  # Small sleep to prevent busy loop
+        except KeyboardInterrupt:
+            print("\n\nStopped by user.")
+
+    # Print final statistics
+    elapsed = time.time() - stats["start_time"]
+    print("\n" + "=" * 70)
+    print("Session Statistics")
+    print("=" * 70)
+    print(f"\nDuration: {elapsed:.1f} seconds")
+    print(f"Frames Processed: {stats['frames_processed']}")
+    print(f"Average FPS: {stats['frames_processed'] / elapsed:.1f}")
+
+    if stats["field_counts"]:
+        print(f"\nFields Received:")
+        for field, count in sorted(stats["field_counts"].items()):
+            print(f"  {field}: {count} frames")
+
+    if stats["detections_per_frame"]:
+        total_detections = sum(stats["detections_per_frame"])
+        avg_detections = total_detections / len(stats["detections_per_frame"])
+        max_detections = max(stats["detections_per_frame"])
+        print(f"\nDetection Statistics:")
+        print(f"  Total Detections: {total_detections}")
+        print(f"  Average per Frame: {avg_detections:.1f}")
+        print(f"  Max in Single Frame: {max_detections}")
+
+    print("\n" + "=" * 70)
+    print("\n💡 Benefits of DATA_ONLY mode:")
+    print("  ✓ Significantly reduced bandwidth (no video sent back)")
+    print("  ✓ Lower latency for data processing")
+    print("  ✓ Ideal for headless/server deployments")
+    print("  ✓ Perfect for analytics and logging use cases")
+    print("\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/webrtc_sdk/dynamic_config_example.py b/examples/webrtc_sdk/dynamic_config_example.py
new file mode 100644
index 0000000000..a490e70505
--- /dev/null
+++ b/examples/webrtc_sdk/dynamic_config_example.py
@@ -0,0 +1,316 @@
+"""
+WebRTC SDK example demonstrating dynamic channel configuration.
+
+This example shows how to change stream and data outputs in real-time
+during an active WebRTC session without reconnecting. Uses a workflow
+specification directly (no need for workspace/workflow-id).
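+
+Reconfiguration is carried over the WebRTC data channel: the
+session.set_stream_output() and session.set_data_outputs() helpers used below
+send the new selection mid-session without renegotiating the connection.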
+
+Usage:
+    python examples/webrtc_sdk/dynamic_config_example.py \
+        [--api-url http://localhost:9001] \
+        [--api-key <api-key>] \
+        [--width 1920] \
+        [--height 1080]
+
+Controls:
+    q    - Quit
+    +    - Enable all data outputs
+    -    - Disable all data outputs
+    a-z  - Toggle individual data outputs
+    0    - Disable video output
+    1-9  - Switch video output
+
+The example uses a workflow specification defined in the code, so no need
+for workspace/workflow-id parameters. Press keys in the preview window to
+dynamically control which outputs are sent.
+"""
+import argparse
+import json
+
+import cv2
+
+from inference_sdk import InferenceHTTPClient
+from inference_sdk.webrtc import VideoMetadata, WebcamSource, StreamConfig
+
+# Example workflow specification
+# This is a simple workflow that runs image operations and provides outputs
+WORKFLOW_SPEC_JSON = """{
+    "version": "1.0",
+    "inputs": [
+        {
+            "type": "InferenceImage",
+            "name": "image"
+        }
+    ],
+    "steps": [
+        {
+            "type": "roboflow_core/relative_static_crop@v1",
+            "name": "relative_static_crop",
+            "images": "$inputs.image",
+            "x_center": 0.5,
+            "y_center": 0.5,
+            "width": 0.5,
+            "height": 0.5
+        },
+        {
+            "type": "roboflow_core/property_definition@v1",
+            "name": "property_definition",
+            "data": "$inputs.image",
+            "operations": [
+                {
+                    "type": "ExtractImageProperty",
+                    "property_name": "aspect_ratio"
+                }
+            ]
+        },
+        {
+            "type": "roboflow_core/image_blur@v1",
+            "name": "image_blur",
+            "image": "$inputs.image"
+        }
+    ],
+    "outputs": [
+        {
+            "type": "JsonField",
+            "name": "image_blur",
+            "coordinates_system": "own",
+            "selector": "$steps.image_blur.image"
+        },
+        {
+            "type": "JsonField",
+            "name": "image",
+            "coordinates_system": "own",
+            "selector": "$steps.relative_static_crop.crops"
+        },
+        {
+            "type": "JsonField",
+            "name": "original_ratio",
+            "coordinates_system": "own",
+            "selector": "$steps.property_definition.output"
+        }
+    ]
+}"""
+
+# Parse the JSON specification into a Python dict
+WORKFLOW_SPEC = json.loads(WORKFLOW_SPEC_JSON)
+
+
+def parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser("WebRTC SDK Dynamic Configuration Example")
+    p.add_argument("--api-url", default="http://localhost:9001")
+    p.add_argument("--image-input-name", default="image")
+    p.add_argument("--api-key", default=None)
+    p.add_argument("--width", type=int, default=None)
+    p.add_argument("--height", type=int, default=None)
+    return p.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+    client = InferenceHTTPClient.init(api_url=args.api_url, api_key=args.api_key)
+
+    # Extract available outputs from workflow specification
+    workflow_outputs = WORKFLOW_SPEC.get("outputs", [])
+    available_output_names = [o.get("name") for o in workflow_outputs]
+
+    if not workflow_outputs:
+        print("⚠️ Workflow has no outputs defined")
+        return
+
+    print(f"Available workflow outputs: {available_output_names}")
+
+    # Prepare source
+    resolution = None
+    if args.width and args.height:
+        resolution = (args.width, args.height)
+    source = WebcamSource(resolution=resolution)
+
+    # Start with some outputs configured
+    config = StreamConfig(
+        stream_output=[available_output_names[0]] if available_output_names else [],  # Use first output
+        data_output=[],  # Start with no data outputs
+    )
+
+    # Start streaming session with workflow specification
+    session = client.webrtc.stream(
+        source=source,
+        workflow=WORKFLOW_SPEC,  # Pass workflow spec directly
+        image_input=args.image_input_name,
+        config=config,
+    )
+
+    with session:
+        # Track current configuration state for display
+        current_data_mode = "none"
+        active_data_fields = []  # For custom mode
+
+        def draw_output_list(frame):
+            """Draw list of available outputs with active indicators"""
+            x_start = 10
+            y_start = 80
+            line_height = 22
+
+            # Title
+            if current_data_mode == "all":
+                title = "Data Outputs (ALL)"
+                title_color = (100, 255, 100)
+            elif current_data_mode == "none":
+                title = "Data Outputs (NONE)"
+                title_color = (100, 100, 100)
+            else:
+                title = f"Data Outputs ({len(active_data_fields)} active)"
+                title_color = (100, 200, 255)
+
+            cv2.putText(frame, title, (x_start, y_start), cv2.FONT_HERSHEY_SIMPLEX, 0.5, title_color, 1, cv2.LINE_AA)
+            y_start += line_height + 5
+
+            # Draw each output
+            for i, output in enumerate(workflow_outputs):
+                key_letter = chr(ord("a") + i) if i < 26 else "?"
+                output_name = output.get("name", "unnamed")
+
+                # Determine if active
+                if current_data_mode == "all":
+                    is_active = True
+                elif current_data_mode == "none":
+                    is_active = False
+                else:
+                    is_active = output_name in active_data_fields
+
+                # Format line with ASCII checkbox
+                indicator = "[X]" if is_active else "[ ]"
+                color = (100, 255, 100) if is_active else (100, 100, 100)
+                text = f" [{key_letter}] {indicator} {output_name}"
+
+                cv2.putText(
+                    frame,
+                    text,
+                    (x_start, y_start + i * line_height),
+                    cv2.FONT_HERSHEY_SIMPLEX,
+                    0.45,
+                    color,
+                    1,
+                    cv2.LINE_AA,
+                )
+
+            # Controls
+            y_controls = y_start + len(workflow_outputs) * line_height + 10
+            cv2.putText(
+                frame,
+                " [+] All  [-] None  [1-9] Video Output",
+                (x_start, y_controls),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.45,
+                (200, 200, 200),
+                1,
+                cv2.LINE_AA,
+            )
+
+        @session.on_frame
+        def show_frame(frame, metadata):
+            nonlocal current_data_mode, active_data_fields
+
+            # Draw output list overlay
+            draw_output_list(frame)
+
+            # Add controls hint at bottom
+            controls = "q=quit | +=all | -=none | a-z=toggle data | 0-9=video"
+            cv2.putText(
+                frame,
+                controls,
+                (10, frame.shape[0] - 10),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.4,
+                (200, 200, 200),
+                1,
+            )
+
+            cv2.imshow("WebRTC SDK - Dynamic Configuration", frame)
+
+            # Handle keyboard input
+            key = cv2.waitKey(1) & 0xFF
+
+            if key == ord("q"):
+                print("Quitting...")
+                session.stop()
+
+            elif key == ord("+") or key == ord("="):
+                print("Setting data output to ALL")
+                session.set_data_outputs(None)
+                current_data_mode = "all"
+
+            elif key == ord("-"):
+                print("Setting data output to NONE")
+                session.set_data_outputs([])
+                current_data_mode = "none"
+
+            elif key == ord("0"):
+                print("Disabling video output")
+                session.set_stream_output("")
+
+            # Handle 1-9 keys for video output selection
+            elif ord("1") <= key <= ord("9"):
+                output_index = key - ord("1")
+                if output_index < len(available_output_names):
+                    output_name = available_output_names[output_index]
+                    print(f"Switching video to '{output_name}'")
+                    session.set_stream_output(output_name)
+
+            # Handle a-z keys for data output toggling
+            elif chr(key).isalpha() and chr(key).lower() in "abcdefghijklmnopqrstuvwxyz":
+                key_index = ord(chr(key).lower()) - ord("a")
+                if key_index < len(workflow_outputs):
+                    output_name = workflow_outputs[key_index].get("name", "")
+
+                    # Toggle logic
+                    if current_data_mode == "all":
+                        # Was "all", switch to custom with all except this one
+                        current_data_mode = "custom"
+                        active_data_fields = list(available_output_names)
+                        active_data_fields.remove(output_name)
+                        print(f"Toggled OFF '{output_name}' (was ALL)")
+                    elif current_data_mode == "none":
+                        # Was "none", enable only this field
+                        current_data_mode = "custom"
+                        active_data_fields = [output_name]
+                        print(f"Toggled ON '{output_name}' (was NONE)")
+                    else:
+                        # Custom mode - toggle
+                        if output_name in active_data_fields:
+                            active_data_fields.remove(output_name)
+                            print(f"Toggled OFF '{output_name}'")
+                        else:
+                            active_data_fields.append(output_name)
+                            print(f"Toggled ON '{output_name}'")
+
+                    # Send updated list
+                    print(f"Active fields: {active_data_fields}")
+                    session.set_data_outputs(active_data_fields if active_data_fields else [])
+
+        # Global data handler to monitor what we're receiving
+        @session.on_data()
+        def handle_data(data: dict, metadata: VideoMetadata):
+            if data:
+                print(f"Frame {metadata.frame_id}: Received fields: {list(data.keys())}")
+            else:
+                print(f"Frame {metadata.frame_id}: No data (metadata only)")
+
+        # Run the session (blocks until stop() is called or stream ends)
+        print("\n=== WebRTC Dynamic Configuration Example ===")
+        print(f"Available outputs: {available_output_names}")
+        print("\nControls:")
+        print("  q - Quit")
+        print("  + - Enable all data outputs")
+        print("  - - Disable all data outputs (metadata only)")
+        for i, output in enumerate(workflow_outputs):
+            key_letter = chr(ord("a") + i) if i < 26 else "?"
+            print(f"  {key_letter} - Toggle '{output.get('name')}' data output")
+        print("  0 - Disable video output")
+        for i, name in enumerate(available_output_names[:9]):
+            print(f"  {i+1} - Switch video to '{name}'")
+        print("\nPress keys in the video window to control outputs dynamically.\n")
+
+        session.run()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/inference/core/exceptions.py b/inference/core/exceptions.py
index cd824a9706..c93b64f5de 100644
--- a/inference/core/exceptions.py
+++ b/inference/core/exceptions.py
@@ -216,3 +216,7 @@ def __init__(self, message: str, inner_error: Exception):
     @property
     def inner_error(self) -> Exception:
         return self._inner_error
+
+
+class WebRTCConfigurationError(Exception):
+    pass
diff --git a/inference/core/interfaces/http/error_handlers.py b/inference/core/interfaces/http/error_handlers.py
index e0d2287d89..b1d360ac9e 100644
--- a/inference/core/interfaces/http/error_handlers.py
+++ b/inference/core/interfaces/http/error_handlers.py
@@ -28,6 +28,7 @@
     RoboflowAPITimeoutError,
     RoboflowAPIUnsuccessfulRequestError,
     ServiceConfigurationError,
+    WebRTCConfigurationError,
     WorkspaceLoadError,
 )
 from inference.core.interfaces.stream_manager.api.errors import (
@@ -358,6 +359,15 @@ def wrapped_route(*args, **kwargs):
                     "inner_error_type": error.inner_error_type,
                 },
             )
+        except WebRTCConfigurationError as error:
+            logger.error("%s: %s", type(error).__name__, error)
+            resp = JSONResponse(
+                status_code=400,
+                content={
+                    "message": str(error),
+                    "error_type": "WebRTCConfigurationError",
+                },
+            )
         except Exception as error:
             logger.exception("%s: %s", type(error).__name__, error)
             resp = JSONResponse(status_code=500, content={"message": "Internal error."})
@@ -661,6 +671,15 @@ async def wrapped_route(*args, **kwargs):
                     "inner_error_type": error.inner_error_type,
                 },
             )
+        except WebRTCConfigurationError as error:
+            logger.error("%s: %s", type(error).__name__, error)
+            resp = JSONResponse(
+                status_code=400,
+                content={
+                    "message": str(error),
+                    "error_type": "WebRTCConfigurationError",
+                },
+            )
         except Exception as error:
             logger.exception("%s: %s", type(error).__name__, error)
             resp = JSONResponse(status_code=500, content={"message": "Internal error."})
diff --git a/inference/core/interfaces/http/http_api.py b/inference/core/interfaces/http/http_api.py
index 3f4a14e871..0fb3930f9b 100644
--- a/inference/core/interfaces/http/http_api.py
+++ b/inference/core/interfaces/http/http_api.py
@@ -171,6 +171,7 @@
     MissingServiceSecretError,
     RoboflowAPINotAuthorizedError,
     RoboflowAPINotNotFoundError,
+    WebRTCConfigurationError,
     WorkspaceLoadError,
 )
 from inference.core.interfaces.base import BaseInterface
@@ -1467,6 +1468,7 @@ async def initialise_webrtc_worker(
                 "RoboflowAPINotAuthorizedError": RoboflowAPINotAuthorizedError,
                 "RoboflowAPINotNotFoundError": RoboflowAPINotNotFoundError,
                 "ValidationError": ValidationError,
+                "WebRTCConfigurationError": WebRTCConfigurationError,
             }
             exc = expected_exceptions.get(
                 worker_result.exception_type, Exception
diff --git a/inference/core/interfaces/stream_manager/manager_app/entities.py b/inference/core/interfaces/stream_manager/manager_app/entities.py
index 578a105fbb..fb1e97950c 100644
--- a/inference/core/interfaces/stream_manager/manager_app/entities.py
+++ b/inference/core/interfaces/stream_manager/manager_app/entities.py
@@ -125,7 +125,7 @@ class InitialiseWebRTCPipelinePayload(InitialisePipelinePayload):
 
 class WebRTCData(BaseModel):
     stream_output: Optional[str] = None
-    data_output: Optional[str] = None
+    data_output: Optional[List[str]] = None
 
 
 class ConsumeResultsPayload(BaseModel):
diff --git a/inference/core/interfaces/stream_manager/manager_app/webrtc.py b/inference/core/interfaces/stream_manager/manager_app/webrtc.py
index f477f31c19..ae03157d69 100644
--- a/inference/core/interfaces/stream_manager/manager_app/webrtc.py
+++ b/inference/core/interfaces/stream_manager/manager_app/webrtc.py
@@ -338,7 +338,7 @@ def __init__(
         video_transform_track: VideoTransformTrack,
         asyncio_loop: asyncio.AbstractEventLoop,
         stream_output: Optional[str] = None,
-        data_output: Optional[str] = None,
+        data_output: Optional[List[str]] = None,
         *args,
         **kwargs,
     ):
@@ -347,7 +347,7 @@ def __init__(
         self.video_transform_track: VideoTransformTrack = video_transform_track
         self._consumers_signalled: bool = False
         self.stream_output: Optional[str] = stream_output
-        self.data_output: Optional[str] = data_output
+        self.data_output: Optional[List[str]] = data_output
         self.data_channel: Optional[RTCDataChannel] = None
 
@@ -384,7 +384,7 @@ async def init_rtc_peer_connection(
     webrtc_realtime_processing: bool = True,
     webcam_fps: Optional[float] = None,
     stream_output: Optional[str] = None,
-    data_output: Optional[str] = None,
+    data_output: Optional[List[str]] = None,
 ) -> RTCPeerConnectionWithFPS:
     relay = MediaRelay()
     video_transform_track = VideoTransformTrack(
diff --git a/inference/core/interfaces/webrtc_worker/entities.py b/inference/core/interfaces/webrtc_worker/entities.py
index 58699a28a9..5b8b15def8 100644
--- a/inference/core/interfaces/webrtc_worker/entities.py
+++ b/inference/core/interfaces/webrtc_worker/entities.py
@@ -10,6 +10,19 @@
 )
 
 
+class WebRTCOutputMode(str, Enum):
+    """Defines the output mode for WebRTC worker processing.
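+
+    Set on WebRTCWorkerRequest, e.g. output_mode=WebRTCOutputMode.DATA_ONLY
+    (the modes are described below).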
+
+    - DATA_ONLY: Only send JSON data via data channel (no video track sent back)
+    - VIDEO_ONLY: Only send processed video via video track (no data channel messages)
+    - BOTH: Send both video and data (default behavior)
+    """
+
+    DATA_ONLY = "data_only"
+    VIDEO_ONLY = "video_only"
+    BOTH = "both"
+
+
 class WebRTCWorkerRequest(BaseModel):
     api_key: Optional[str] = None
     workflow_configuration: WorkflowConfiguration
@@ -18,6 +31,7 @@ class WebRTCWorkerRequest(BaseModel):
     webrtc_realtime_processing: bool = (
         WEBRTC_REALTIME_PROCESSING  # when set to True, MediaRelay.subscribe will be called with buffered=False
     )
+    output_mode: WebRTCOutputMode = WebRTCOutputMode.BOTH
     stream_output: Optional[List[Optional[str]]] = Field(default_factory=list)
     data_output: Optional[List[Optional[str]]] = Field(default_factory=list)
     declared_fps: Optional[float] = None
@@ -47,8 +61,15 @@ class WebRTCVideoMetadata(BaseModel):
 
 class WebRTCOutput(BaseModel):
-    output_name: Optional[str] = None
-    serialized_output_data: Optional[str] = None
+    """Output sent via WebRTC data channel.
+
+    serialized_output_data contains a dictionary with workflow outputs:
+    - If data_output is None: all workflow outputs
+    - If data_output is []: None (no data sent)
+    - If data_output is ["field1", "field2"]: only those fields
+    """
+
+    serialized_output_data: Optional[Dict[str, Any]] = None
     video_metadata: Optional[WebRTCVideoMetadata] = None
     errors: List[str] = Field(default_factory=list)
 
diff --git a/inference/core/interfaces/webrtc_worker/utils.py b/inference/core/interfaces/webrtc_worker/utils.py
index 206c7aa0cd..8cebb061d3 100644
--- a/inference/core/interfaces/webrtc_worker/utils.py
+++ b/inference/core/interfaces/webrtc_worker/utils.py
@@ -21,7 +21,9 @@ def process_frame(
     inference_pipeline: InferencePipeline,
     stream_output: str,
     include_errors_on_frame: bool = True,
-) -> Tuple[Dict[str, Union[WorkflowImageData, Any]], VideoFrame, List[str]]:
+) -> Tuple[
+    Dict[str, Union[WorkflowImageData, Any]], VideoFrame, List[str], Optional[str]
+]:
     np_image = frame.to_ndarray(format="bgr24")
     workflow_output: Dict[str, Union[WorkflowImageData, Any]] = {}
     errors = []
@@ -41,6 +43,7 @@ def process_frame(
         workflow_output=workflow_output,
         frame_output_key=stream_output,
     )
+    detected_output = None
     if result_np_image is None:
         for k in workflow_output.keys():
             result_np_image = get_frame_from_workflow_output(
@@ -48,9 +51,12 @@ def process_frame(
                 frame_output_key=k,
             )
             if result_np_image is not None:
-                errors.append(
-                    f"'{stream_output}' not found in workflow outputs, using '{k}' instead"
-                )
+                detected_output = k  # Store detected output name
+                # Only show error if user explicitly specified an output that wasn't found
+                if stream_output is not None and stream_output != "":
+                    errors.append(
+                        f"'{stream_output}' not found in workflow outputs, using '{k}' instead"
+                    )
                 break
     if result_np_image is None:
         errors.append("Visualisation blocks were not executed")
@@ -72,6 +78,7 @@ def process_frame(
         workflow_output,
         VideoFrame.from_ndarray(result_np_image, format="bgr24"),
         errors,
+        detected_output,  # Return the auto-detected output name (or None)
     )
 
 
diff --git a/inference/core/interfaces/webrtc_worker/webrtc.py b/inference/core/interfaces/webrtc_worker/webrtc.py
index 906bd637cb..252e61a3a6 100644
--- a/inference/core/interfaces/webrtc_worker/webrtc.py
+++ b/inference/core/interfaces/webrtc_worker/webrtc.py
@@ -1,9 +1,12 @@
 import asyncio
+import base64
 import datetime
 import json
 import logging
-from typing import Callable, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import cv2
+import numpy as np
 import supervision as sv
 from aiortc import (
     RTCConfiguration,
@@ -14,6 +17,7 @@
     VideoStreamTrack,
 )
 from aiortc.contrib.media import MediaPlayer, MediaRelay, PlayerStreamTrack
+from aiortc.mediastreams import MediaStreamError
 from aiortc.rtcrtpreceiver import RemoteStreamTrack
 from av import VideoFrame
 from av import logging as av_logging
@@ -29,7 +33,9 @@
     MissingApiKeyError,
     RoboflowAPINotAuthorizedError,
     RoboflowAPINotNotFoundError,
+    WebRTCConfigurationError,
 )
+from inference.core.interfaces.camera.entities import VideoFrame as InferenceVideoFrame
 from inference.core.interfaces.camera.entities import VideoFrameProducer
 from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
 from inference.core.interfaces.stream_manager.manager_app.entities import (
@@ -38,11 +44,13 @@
 )
 from inference.core.interfaces.webrtc_worker.entities import (
     WebRTCOutput,
+    WebRTCOutputMode,
     WebRTCVideoMetadata,
     WebRTCWorkerRequest,
     WebRTCWorkerResult,
 )
 from inference.core.interfaces.webrtc_worker.utils import process_frame
+from inference.core.roboflow_api import get_workflow_specification
 from inference.core.workflows.core_steps.common.serializers import (
     serialise_sv_detections,
 )
@@ -53,6 +61,145 @@
 logging.getLogger("aiortc").setLevel(logging.WARNING)
 
 
+def serialize_workflow_output(
+    output_data: Any, is_explicit_request: bool
+) -> Tuple[Any, Optional[str]]:
+    """Serialize a workflow output value recursively.
+
+    Args:
+        output_data: The workflow output value to serialize
+        is_explicit_request: True if field was explicitly requested in data_output
+
+    Returns (serialized_value, error_message)
+    - serialized_value: The value ready for JSON serialization, or None if
+      skipped/failed
+    - error_message: Error string if serialization failed, None otherwise
+
+    Image serialization rules:
+    - Images are NEVER serialized UNLESS explicitly requested in data_output list
+    - If explicit: serialize to base64 JPEG (quality 85)
+    - If implicit (data_output=None): skip images
+
+    Handles nested structures recursively (dicts, lists) to ensure all complex
+    types are properly serialized.
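+
+    Example (illustrative):
+        >>> serialize_workflow_output({"count": 3}, is_explicit_request=False)
+        ({'count': 3}, None)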
+ """ + try: + # Handle WorkflowImageData (convert to base64 only if explicit) + if isinstance(output_data, WorkflowImageData): + if not is_explicit_request: + # Skip images when listing all outputs (data_output=None) + return None, None # Skip without error + + # Explicitly requested - serialize to base64 JPEG + try: + np_image = output_data.numpy_image + # Encode as JPEG with quality 85 (good quality, much smaller than PNG) + success, buffer = cv2.imencode( + ".jpg", np_image, [cv2.IMWRITE_JPEG_QUALITY, 85] + ) + if success: + base64_image = base64.b64encode(buffer).decode("utf-8") + return f"data:image/jpeg;base64,{base64_image}", None + else: + return None, "Failed to encode image as JPEG" + except Exception as e: + return None, f"Failed to serialize image: {str(e)}" + + # Handle sv.Detections (use existing serializer) + elif isinstance(output_data, sv.Detections): + try: + parsed_detections = serialise_sv_detections(output_data) + return parsed_detections, None + except Exception as e: + return None, f"Failed to serialize detections: {str(e)}" + + # Handle dict (serialize recursively) + elif isinstance(output_data, dict): + return _serialize_collection( + output_data.items(), is_explicit_request, as_dict=True + ) + + # Handle list (serialize recursively) + elif isinstance(output_data, list): + return _serialize_collection( + enumerate(output_data), is_explicit_request, as_dict=False + ) + + # Handle primitives (str, int, float, bool) + elif isinstance(output_data, (str, int, float, bool, type(None))): + return output_data, None + + # Handle numpy types + elif isinstance(output_data, (np.integer, np.floating)): + return output_data.item(), None + + # Handle numpy arrays + elif isinstance(output_data, np.ndarray): + try: + return output_data.tolist(), None + except Exception as e: + return None, f"Failed to serialize numpy array: {str(e)}" + + # Unknown type - convert to string as fallback + else: + return str(output_data), None + + except Exception as e: + return None, f"Unexpected error serializing output: {str(e)}" + + +def _serialize_collection( + items, is_explicit_request: bool, as_dict: bool +) -> Tuple[Any, Optional[str]]: + """Helper to serialize dict or list collections recursively. + + Args: + items: Iterator of (key, value) pairs for dict or (index, value) for list + is_explicit_request: Whether the parent field was explicitly requested + as_dict: True to return dict, False to return list + + Returns (serialized_collection, error_message) + + Note: If serialization fails for some fields, those fields are replaced with + error placeholders and the collection is still returned with valid fields. + The error message lists which fields failed. 
+
+    Error placeholder format:
+    - For dicts: {"__serialization_error__": "error message"} (key identifies field)
+    - For lists: {"__serialization_error__": "error message", "__field__": "index"}
+    """
+    result = {} if as_dict else []
+    errors = []
+
+    for key_or_idx, value in items:
+        serialized_value, error = serialize_workflow_output(value, is_explicit_request)
+
+        if error:
+            # Store error info and add placeholder
+            errors.append(f"{key_or_idx}: {error}")
+
+            if as_dict:
+                # For dict: key already identifies the field
+                result[key_or_idx] = {"__serialization_error__": error}
+            else:
+                # For list: include index in placeholder
+                result.append(
+                    {"__serialization_error__": error, "__field__": str(key_or_idx)}
+                )
+        elif serialized_value is not None:
+            if as_dict:
+                result[key_or_idx] = serialized_value
+            else:
+                result.append(serialized_value)
+        # else: skip None values (e.g., images when not explicit)
+
+    # Return result with placeholders + error message listing failed fields
+    if errors:
+        error_message = f"Partial serialization - errors in: {'; '.join(errors)}"
+        return result, error_message
+    return result, None
+
+
 class RTCPeerConnectionWithLoop(RTCPeerConnection):
     def __init__(
         self,
@@ -64,28 +211,111 @@ def __init__(
         self._loop = asyncio_loop
 
 
-class VideoTransformTrackWithLoop(VideoStreamTrack):
+class VideoFrameProcessor:
+    """Base class for processing video frames through workflow.
+
+    Can be used independently for data-only processing (no video track output)
+    or as a base for VideoTransformTrackWithLoop when video output is needed.
+    """
+
     def __init__(
         self,
         asyncio_loop: asyncio.AbstractEventLoop,
         workflow_configuration: WorkflowConfiguration,
         api_key: str,
-        data_output: Optional[str] = None,
+        data_output: Optional[List[str]] = None,
         stream_output: Optional[str] = None,
+        output_mode: WebRTCOutputMode = WebRTCOutputMode.BOTH,
         declared_fps: float = 30,
         termination_date: Optional[datetime.datetime] = None,
         terminate_event: Optional[asyncio.Event] = None,
-        *args,
-        **kwargs,
     ):
-        super().__init__(*args, **kwargs)
         self._loop = asyncio_loop
         self._termination_date = termination_date
         self._terminate_event = terminate_event
         self.track: Optional[RemoteStreamTrack] = None
         self._track_active: bool = False
         self._av_logging_set: bool = False
+        self._received_frames = 0
+        self._declared_fps = declared_fps
+        self._stop_processing = False
+
+        self.output_mode = output_mode
+        self.stream_output = stream_output
+        self.data_channel: Optional[RTCDataChannel] = None
+
+        # Normalize data_output to avoid edge cases
+        if data_output is None:
+            self.data_output = None
+        elif isinstance(data_output, list):
+            self.data_output = [f for f in data_output if f]
+        else:
+            raise WebRTCConfigurationError(
+                f"data_output must be list or None, got {type(data_output).__name__}"
+            )
+
+        # Validate that workflow is specified either by specification or workspace/workflow_id
+        has_specification = workflow_configuration.workflow_specification is not None
+        has_workspace_and_id = (
+            workflow_configuration.workspace_name is not None
+            and workflow_configuration.workflow_id is not None
+        )
+
+        if not has_specification and not has_workspace_and_id:
+            raise WebRTCConfigurationError(
+                "Either 'workflow_specification' or both 'workspace_name' and 'workflow_id' must be provided"
+            )
+
+        # Fetch workflow_specification from API if not provided directly
+        if not has_specification and has_workspace_and_id:
+            try:
+                logger.info(
+                    f"Fetching workflow specification for workspace={workflow_configuration.workspace_name}, "
+                    f"workflow_id={workflow_configuration.workflow_id}"
+                )
+                workflow_configuration.workflow_specification = (
+                    get_workflow_specification(
+                        api_key=api_key,
+                        workspace_id=workflow_configuration.workspace_name,
+                        workflow_id=workflow_configuration.workflow_id,
+                    )
+                )
+                # Clear workspace_name and workflow_id after fetch to avoid conflicts
+                # InferencePipeline requires these to be mutually exclusive with workflow_specification
+                workflow_configuration.workspace_name = None
+                workflow_configuration.workflow_id = None
+            except Exception as e:
+                raise WebRTCConfigurationError(
+                    f"Failed to fetch workflow specification from API: {str(e)}"
+                )
+
+        # Validate data_output and stream_output against workflow specification
+        if workflow_configuration.workflow_specification is not None:
+            workflow_outputs = workflow_configuration.workflow_specification.get(
+                "outputs", []
+            )
+            available_output_names = [o.get("name") for o in workflow_outputs]
+
+            # Validate data_output fields
+            if self.data_output is not None and len(self.data_output) > 0:
+                invalid_fields = [
+                    field
+                    for field in self.data_output
+                    if field not in available_output_names
+                ]
+
+                if invalid_fields:
+                    raise WebRTCConfigurationError(
+                        f"Invalid data_output fields: {invalid_fields}. "
+                        f"Available workflow outputs: {available_output_names}"
+                    )
+
+            # Validate stream_output field (if explicitly specified and not empty)
+            if self.stream_output and self.stream_output not in available_output_names:
+                raise WebRTCConfigurationError(
+                    f"Invalid stream_output field: '{self.stream_output}'. "
+                    f"Available workflow outputs: {available_output_names}"
+                )
 
         self._inference_pipeline = InferencePipeline.init_with_workflow(
             video_reference=VideoFrameProducer,
@@ -99,28 +329,17 @@ def __init__(
             cancel_thread_pool_tasks_on_exit=workflow_configuration.cancel_thread_pool_tasks_on_exit,
             video_metadata_input_name=workflow_configuration.video_metadata_input_name,
         )
-        self.data_output = data_output
-        self.stream_output = stream_output
-        self.data_channel: Optional[RTCDataChannel] = None
-        self._received_frames = 0
-        self._declared_fps = declared_fps
 
-    def set_track(
-        self,
-        track: RemoteStreamTrack,
-    ):
+    def set_track(self, track: RemoteStreamTrack):
         if not self.track:
             self.track = track
 
     def close(self):
         self._track_active = False
+        self._stop_processing = True
 
-    async def recv(self):
-        # Silencing swscaler warnings in multi-threading environment
-        if not self._av_logging_set:
-            av_logging.set_libav_level(av_logging.ERROR)
-            self._av_logging_set = True
-
+    def _check_termination(self):
+        """Check if we should terminate based on timeout"""
         if (
             self._termination_date
             and self._termination_date < datetime.datetime.now()
@@ -129,69 +348,255 @@ def __init__(
         ):
             logger.info("Timeout reached, terminating inference pipeline")
             self._terminate_event.set()
+            return True
+        return False
+
+    def _process_frame_data_only(
+        self, frame: VideoFrame, frame_id: int
+    ) -> Tuple[Dict[str, Any], List[str]]:
+        """Process frame through workflow without rendering visuals.
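+
+        Unlike utils.process_frame(), no output image is composed here; the frame
+        is pushed through InferencePipeline._on_video_frame() and the raw workflow
+        outputs are returned as-is.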
+
+        Returns (workflow_output, errors)
+        """
+        np_image = frame.to_ndarray(format="bgr24")
+        workflow_output = {}
+        errors = []
+
+        try:
+            video_frame = InferenceVideoFrame(
+                image=np_image,
+                frame_id=frame_id,
+                frame_timestamp=datetime.datetime.now(),
+                comes_from_video_file=False,
+                fps=self._declared_fps,
+                measured_fps=self._declared_fps,
+            )
+            workflow_output = self._inference_pipeline._on_video_frame([video_frame])[0]
+        except Exception as e:
+            logger.exception("Error in workflow processing")
+            errors.append(str(e))
+
+        return workflow_output, errors
+
+    async def _send_data_output(
+        self,
+        workflow_output: Dict[str, Any],
+        frame_timestamp: datetime.datetime,
+        frame: VideoFrame,
+        errors: List[str],
+    ):
+        """Send data via data channel based on data_output configuration.
+
+        - data_output = None: Send all workflow outputs
+        - data_output = []: Don't send any data (only metadata)
+        - data_output = ["field1", "field2"]: Send only specified fields
+        """
+        if not self.data_channel or self.data_channel.readyState != "open":
+            return
+
+        video_metadata = WebRTCVideoMetadata(
+            frame_id=self._received_frames,
+            received_at=frame_timestamp.isoformat(),
+            pts=frame.pts,
+            time_base=frame.time_base,
+            declared_fps=self._declared_fps,
+        )
+
+        webrtc_output = WebRTCOutput(
+            serialized_output_data=None,
+            video_metadata=video_metadata,
+            errors=list(errors),  # Copy errors list
+        )
+
+        # Determine which fields to send
+        if self.data_output is None:
+            # Send ALL workflow outputs
+            fields_to_send = list(workflow_output.keys())
+        elif len(self.data_output) == 0:
+            self.data_channel.send(json.dumps(webrtc_output.model_dump()))
+            return
+        else:
+            fields_to_send = self.data_output
+
+        # Serialize each field
+        serialized_outputs = {}
+
+        # Determine if this is an explicit request (fields listed) or implicit (all fields)
+        is_all_outputs = self.data_output is None
+
+        for field_name in fields_to_send:
+            if field_name not in workflow_output:
+                webrtc_output.errors.append(
+                    f"Requested output '{field_name}' not found in workflow outputs"
+                )
+                continue
+
+            output_data = workflow_output[field_name]
+
+            # Determine if this field was explicitly requested
+            if is_all_outputs:
+                # data_output=None means listing all, so not explicit for individual fields
+                is_explicit_request = False
+            else:
+                # Field is in the data_output list, so it's explicit
+                is_explicit_request = True
+
+            serialized_value, error = serialize_workflow_output(
+                output_data=output_data,
+                is_explicit_request=is_explicit_request,
+            )
+
+            if error:
+                # Add error to errors list and include placeholder in output
+                webrtc_output.errors.append(f"{field_name}: {error}")
+                serialized_outputs[field_name] = {"__serialization_error__": error}
+            elif serialized_value is not None:
+                serialized_outputs[field_name] = serialized_value
+            # else: serialized_value is None and no error = field was skipped (e.g., image in video track)
+
+        # Only set serialized_output_data if we have data to send
+        if serialized_outputs:
+            webrtc_output.serialized_output_data = serialized_outputs
+        self.data_channel.send(json.dumps(webrtc_output.model_dump()))
+
+    async def process_frames_data_only(self):
+        """Process frames for data extraction only, without video track output.
+
+        This is used when output_mode is DATA_ONLY and no video track is needed.
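+
+        Illustrative usage, where `processor` is a VideoFrameProcessor instance
+        (mirrors the RTSP branch below):
+            asyncio.create_task(processor.process_frames_data_only())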
+ """ + # Silencing swscaler warnings in multi-threading environment + if not self._av_logging_set: + av_logging.set_libav_level(av_logging.ERROR) + self._av_logging_set = True + + logger.info("Starting data-only frame processing") + + try: + while ( + self.track + and self.track.readyState != "ended" + and not self._stop_processing + ): + if self._check_termination(): + break + + # Drain queue if using PlayerStreamTrack (RTSP) + if isinstance(self.track, PlayerStreamTrack): + while self.track._queue.qsize() > 30: + self.track._queue.get_nowait() + + frame: VideoFrame = await self.track.recv() + self._received_frames += 1 + frame_timestamp = datetime.datetime.now() + + # Process workflow without rendering + loop = asyncio.get_running_loop() + workflow_output, errors = await loop.run_in_executor( + None, + self._process_frame_data_only, + frame, + self._received_frames, + ) + + # Send data via data channel + await self._send_data_output( + workflow_output, frame_timestamp, frame, errors + ) + + except asyncio.CancelledError: + logger.info("Data-only processing cancelled") + except MediaStreamError: + logger.info("Stream ended in data-only processing") + except Exception as exc: + logger.error("Error in data-only processing: %s", exc) + + +class VideoTransformTrackWithLoop(VideoStreamTrack, VideoFrameProcessor): + """Video track that processes frames through workflow and sends video back. + + Inherits from both VideoStreamTrack (for WebRTC video track functionality) + and VideoFrameProcessor (for workflow processing logic). + """ + + def __init__( + self, + asyncio_loop: asyncio.AbstractEventLoop, + workflow_configuration: WorkflowConfiguration, + api_key: str, + data_output: Optional[List[str]] = None, + stream_output: Optional[str] = None, + output_mode: WebRTCOutputMode = WebRTCOutputMode.BOTH, + declared_fps: float = 30, + termination_date: Optional[datetime.datetime] = None, + terminate_event: Optional[asyncio.Event] = None, + *args, + **kwargs, + ): + VideoStreamTrack.__init__(self, *args, **kwargs) + VideoFrameProcessor.__init__( + self, + asyncio_loop=asyncio_loop, + workflow_configuration=workflow_configuration, + api_key=api_key, + data_output=data_output, + stream_output=stream_output, + output_mode=output_mode, + declared_fps=declared_fps, + termination_date=termination_date, + terminate_event=terminate_event, + ) + + async def recv(self): + """Called by WebRTC to get the next frame to send. + + This method processes frames through the workflow and returns + the processed video frame for transmission. 
+ """ + # Silencing swscaler warnings in multi-threading environment + if not self._av_logging_set: + av_logging.set_libav_level(av_logging.ERROR) + self._av_logging_set = True + + # Check if we should terminate + if self._check_termination(): + raise MediaStreamError("Processing terminated due to timeout") + + # Drain queue if using PlayerStreamTrack (RTSP) if isinstance(self.track, PlayerStreamTrack): while self.track._queue.qsize() > 30: self.track._queue.get_nowait() - frame: VideoFrame = await self.track.recv() + frame: VideoFrame = await self.track.recv() self._received_frames += 1 frame_timestamp = datetime.datetime.now() + + # Process frame through workflow WITH rendering (for video output) loop = asyncio.get_running_loop() - workflow_output, new_frame, errors = await loop.run_in_executor( - None, - process_frame, - frame, - self._received_frames, - self._inference_pipeline, - self.stream_output, + workflow_output, new_frame, errors, detected_output = ( + await loop.run_in_executor( + None, + process_frame, + frame, + self._received_frames, + self._inference_pipeline, + self.stream_output, + ) ) + # Update stream_output if it was auto-detected (only when None) + if self.stream_output is None and detected_output is not None: + self.stream_output = detected_output + logger.info(f"Auto-detected and set stream_output to: {detected_output}") + new_frame.pts = frame.pts new_frame.time_base = frame.time_base - if self.data_channel and self.data_channel.readyState == "open": - video_metadata = WebRTCVideoMetadata( - frame_id=self._received_frames, - received_at=frame_timestamp.isoformat(), - pts=new_frame.pts, - time_base=new_frame.time_base, - declared_fps=self._declared_fps, - ) - webrtc_output = WebRTCOutput( - output_name=None, - serialized_output_data=None, - video_metadata=video_metadata, - errors=errors, + # Send data via data channel if needed (BOTH or DATA_ONLY modes) + if self.output_mode in [WebRTCOutputMode.BOTH, WebRTCOutputMode.DATA_ONLY]: + await self._send_data_output( + workflow_output, frame_timestamp, frame, errors ) - if self.data_output and self.data_output in workflow_output: - workflow_output = workflow_output[self.data_output] - serialized_data = None - if isinstance(workflow_output, WorkflowImageData): - webrtc_output.errors.append( - f"Selected data output '{self.data_output}' contains image, please use video output instead" - ) - elif isinstance(workflow_output, sv.Detections): - try: - parsed_detections = serialise_sv_detections(workflow_output) - serialized_data = json.dumps(parsed_detections) - except Exception: - webrtc_output.errors.append( - f"Failed to serialise output: {self.data_output}" - ) - elif isinstance(workflow_output, dict): - try: - serialized_data = json.dumps(workflow_output) - except Exception: - webrtc_output.errors.append( - f"Failed to serialise output: {self.data_output}" - ) - else: - serialized_data = str(workflow_output) - if serialized_data is not None: - webrtc_output.output_name = self.data_output - webrtc_output.serialized_output_data = serialized_data - self.data_channel.send(json.dumps(webrtc_output.model_dump())) return new_frame @@ -229,25 +634,56 @@ async def init_rtc_peer_connection_with_loop( logger.info("Setting termination date to %s", termination_date) except (TypeError, ValueError): pass + output_mode = webrtc_request.output_mode stream_output = None + + # Normalize stream_output if webrtc_request.stream_output: - # TODO: UI sends None as stream_output for wildcard outputs - stream_output = 
- data_output = None
- if webrtc_request.data_output:
- data_output = webrtc_request.data_output[0]
+ filtered = [s for s in webrtc_request.stream_output if s]
+ stream_output = filtered[0] if filtered else None
+
+ # Handle data_output as list
+ # - None or not provided: send all outputs
+ # - []: send nothing
+ # - ["field1", "field2"]: send only those fields
+ data_output = (
+ webrtc_request.data_output if webrtc_request.data_output is not None else None
+ )
+
+ # Determine if we should send video back based on output mode
+ should_send_video = output_mode in [
+ WebRTCOutputMode.VIDEO_ONLY,
+ WebRTCOutputMode.BOTH,
+ ]

 try:
- video_transform_track = VideoTransformTrackWithLoop(
- asyncio_loop=asyncio_loop,
- workflow_configuration=webrtc_request.workflow_configuration,
- api_key=webrtc_request.api_key,
- data_output=data_output,
- stream_output=stream_output,
- declared_fps=webrtc_request.declared_fps,
- termination_date=termination_date,
- terminate_event=terminate_event,
- )
+ # For DATA_ONLY mode, we use VideoFrameProcessor directly (no video track)
+ # For other modes, we use VideoTransformTrackWithLoop (includes video track)
+ if should_send_video:
+ video_processor = VideoTransformTrackWithLoop(
+ asyncio_loop=asyncio_loop,
+ workflow_configuration=webrtc_request.workflow_configuration,
+ api_key=webrtc_request.api_key,
+ data_output=data_output,
+ stream_output=stream_output,
+ output_mode=output_mode,
+ declared_fps=webrtc_request.declared_fps,
+ termination_date=termination_date,
+ terminate_event=terminate_event,
+ )
+ else:
+ # DATA_ONLY or OFF mode - use base VideoFrameProcessor
+ video_processor = VideoFrameProcessor(
+ asyncio_loop=asyncio_loop,
+ workflow_configuration=webrtc_request.workflow_configuration,
+ api_key=webrtc_request.api_key,
+ data_output=data_output,
+ stream_output=stream_output,
+ output_mode=output_mode,
+ declared_fps=webrtc_request.declared_fps,
+ termination_date=termination_date,
+ terminate_event=terminate_event,
+ )
 except (
 ValidationError,
 MissingApiKeyError,
@@ -261,6 +697,14 @@ async def init_rtc_peer_connection_with_loop(
 )
 )
 return
+ except WebRTCConfigurationError as error:
+ send_answer(
+ WebRTCWorkerResult(
+ exception_type=error.__class__.__name__,
+ error_message=str(error),
+ )
+ )
+ return
 except RoboflowAPINotAuthorizedError:
 send_answer(
 WebRTCWorkerResult(
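The tri-state data_output convention above (None = all outputs, [] = nothing, a list = only those fields) is worth pinning down, since two of the three values are falsy. A small sketch of the selection rule, assuming workflow outputs arrive as a plain dict (select_outputs is illustrative, not the worker's actual helper):

from typing import Dict, List, Optional

def select_outputs(
    workflow_output: Dict[str, object],
    data_output: Optional[List[str]],
) -> Dict[str, object]:
    # None means "send everything"; an empty list means "send nothing".
    if data_output is None:
        return dict(workflow_output)
    return {name: workflow_output[name] for name in data_output if name in workflow_output}

assert select_outputs({"preds": 1, "viz": 2}, None) == {"preds": 1, "viz": 2}
assert select_outputs({"preds": 1, "viz": 2}, []) == {}
assert select_outputs({"preds": 1, "viz": 2}, ["preds"]) == {"preds": 1}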
@@ -327,29 +771,44 @@
 "stimeout": "2000000", # 2s socket timeout
 },
 )
- video_transform_track.set_track(
- track=player.video,
- )
- peer_connection.addTrack(video_transform_track)
+ video_processor.set_track(track=player.video)
+
+ # Only add video track if we should send video back
+ if should_send_video:
+ peer_connection.addTrack(video_processor)
+ else:
+ # For DATA_ONLY, start data-only processing task
+ logger.info("Starting data-only processing for RTSP stream")
+ asyncio.create_task(video_processor.process_frames_data_only())

 @peer_connection.on("track")
 def on_track(track: RemoteStreamTrack):
- logger.info("track received")
- video_transform_track.set_track(
- track=relay.subscribe(
- track,
- buffered=False if webrtc_request.webrtc_realtime_processing else True,
- )
+ logger.info("Track received from client")
+ relayed_track = relay.subscribe(
+ track,
+ buffered=False if webrtc_request.webrtc_realtime_processing else True,
 )
- peer_connection.addTrack(video_transform_track)
+ video_processor.set_track(track=relayed_track)
+
+ # Only add video track back if we should send video
+ if should_send_video:
+ logger.info(f"Output mode: {output_mode} - Adding video track to send back")
+ peer_connection.addTrack(video_processor)
+ else:
+ # For DATA_ONLY, start data-only processing task
+ logger.info(
+ f"Output mode: {output_mode} - Starting data-only processing (no video track)"
+ )
+ asyncio.create_task(video_processor.process_frames_data_only())

 @peer_connection.on("connectionstatechange")
 async def on_connectionstatechange():
 logger.info("Connection state is %s", peer_connection.connectionState)
 if peer_connection.connectionState in {"failed", "closed"}:
- if video_transform_track.track:
- logger.info("Stopping video transform track")
- video_transform_track.track.stop()
+ if video_processor.track:
+ logger.info("Stopping video processor track")
+ video_processor.track.stop()
+ video_processor.close()
 logger.info("Stopping WebRTC peer")
 await peer_connection.close()
 terminate_event.set()
@@ -363,16 +822,19 @@ def on_datachannel(channel: RTCDataChannel):
 def on_message(message):
 logger.info("Data channel message received: %s", message)
 try:
- message = WebRTCData(**json.loads(message))
+ message_data = WebRTCData(**json.loads(message))
 except json.JSONDecodeError:
 logger.error("Failed to decode webrtc data payload: %s", message)
 return
- if message.stream_output is not None:
- video_transform_track.stream_output = message.stream_output or None
- if message.data_output is not None:
- video_transform_track.data_output = message.data_output or None
- video_transform_track.data_channel = channel
+ # Handle output changes (which workflow output to send)
+ if message_data.stream_output is not None:
+ video_processor.stream_output = message_data.stream_output or None
+
+ if message_data.data_output is not None:
+ video_processor.data_output = message_data.data_output
+
+ video_processor.data_channel = channel

 await peer_connection.setRemoteDescription(
 RTCSessionDescription(
@@ -402,8 +864,9 @@ def on_message(message):
 if peer_connection.connectionState != "closed":
 logger.info("Closing WebRTC connection")
 await peer_connection.close()
- if video_transform_track.track:
- logger.info("Stopping video transform track")
- video_transform_track.track.stop()
+ if video_processor.track:
+ logger.info("Stopping video processor track")
+ video_processor.track.stop()
+ video_processor.close()
 await usage_collector.async_push_usage_payloads()
 logger.info("WebRTC peer connection closed")
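Per the on_message handler above, a connected client reconfigures a live session by sending a JSON-encoded WebRTCData payload over the data channel, and a null field is treated as "leave unchanged". A sketch of the two common payloads (field names as in the handler above):

import json

# Render a different workflow output on the returned video track;
# data_output stays as-is because null fields are ignored by the server.
switch_video = json.dumps({"stream_output": "visualization", "data_output": None})

# Keep the video as-is but stop sending workflow data (metadata only).
mute_data = json.dumps({"stream_output": None, "data_output": []})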
diff --git a/inference_sdk/webrtc/__init__.py b/inference_sdk/webrtc/__init__.py
index c9deb7baf8..bd336d5060 100644
--- a/inference_sdk/webrtc/__init__.py
+++ b/inference_sdk/webrtc/__init__.py
@@ -1,7 +1,7 @@
 """WebRTC SDK for Inference - Unified streaming API."""

 from .client import WebRTCClient # noqa: F401
-from .config import StreamConfig # noqa: F401
+from .config import OutputMode, StreamConfig # noqa: F401
 from .session import VideoMetadata, WebRTCSession # noqa: F401
 from .sources import ( # noqa: F401
 ManualSource,
@@ -17,6 +17,7 @@
 "WebRTCSession",
 "StreamConfig",
 "VideoMetadata",
+ "OutputMode",
 # Source classes
 "StreamSource",
 "WebcamSource",
diff --git a/inference_sdk/webrtc/config.py b/inference_sdk/webrtc/config.py
index 6fc1460aa4..f4ca9f6448 100644
--- a/inference_sdk/webrtc/config.py
+++ b/inference_sdk/webrtc/config.py
@@ -1,9 +1,42 @@
 """Configuration for WebRTC streaming sessions."""

 from dataclasses import dataclass, field
+from enum import Enum
 from typing import Any, Dict, List, Optional


+class OutputMode(str, Enum):
+ """Output mode for WebRTC sessions.
+
+ Determines what data is sent back from the server during processing:
+
+ - DATA_ONLY: Only send JSON data via data channel (no video track sent back).
+ Use this when you only need inference results/metrics and want to
+ save bandwidth. The server won't send processed video frames back.
+
+ - VIDEO_ONLY: Only send processed video via video track (no data channel messages).
+ Use this when you only need to display the processed video and don't
+ need programmatic access to results.
+
+ - BOTH: Send both processed video and JSON data (default behavior).
+ Use this when you need both visual output and programmatic access to results.
+
+ Examples:
+ # Data-only mode for analytics/logging (saves bandwidth)
+ config = StreamConfig(output_mode=OutputMode.DATA_ONLY)
+
+ # Video-only mode for display-only applications
+ config = StreamConfig(output_mode=OutputMode.VIDEO_ONLY)
+
+ # Both (default) for full-featured applications
+ config = StreamConfig(output_mode=OutputMode.BOTH)
+ """
+
+ DATA_ONLY = "data_only"
+ VIDEO_ONLY = "video_only"
+ BOTH = "both"
+
+
 @dataclass
 class StreamConfig:
 """Unified configuration for all WebRTC stream types.
@@ -19,6 +52,9 @@ class StreamConfig:
 data_output: List[str] = field(default_factory=list)
 """List of workflow output names to receive via data channel"""

+ output_mode: OutputMode = OutputMode.BOTH
+ """Output mode: DATA_ONLY (data channel only), VIDEO_ONLY (video only), or BOTH (default)"""
+
 # Processing configuration
 realtime_processing: bool = True
 """Whether to process frames in realtime (drop if can't keep up) or queue all frames"""
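With the new field in place, opting a session into data-only operation is a one-liner on the config. A short sketch using the symbols exported by this diff (the "predictions" output name is a hypothetical example; use whatever outputs your workflow defines):

from inference_sdk.webrtc import OutputMode, StreamConfig

# No return video track; inference results arrive as JSON over the
# data channel, restricted to the "predictions" workflow output.
config = StreamConfig(
    output_mode=OutputMode.DATA_ONLY,
    data_output=["predictions"],
)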
diff --git a/inference_sdk/webrtc/session.py b/inference_sdk/webrtc/session.py
index fd03243bf9..0738d32963 100644
--- a/inference_sdk/webrtc/session.py
+++ b/inference_sdk/webrtc/session.py
@@ -18,6 +18,9 @@
 from inference_sdk.webrtc.config import StreamConfig
 from inference_sdk.webrtc.sources import StreamSource

+# Sentinel value to distinguish "not provided" from "None"
+_UNSET = object()
+
 if TYPE_CHECKING:
 from aiortc import RTCDataChannel, RTCPeerConnection
@@ -191,6 +194,7 @@ def __init__(
 self._frame_handlers: List[Callable] = []
 self._data_field_handlers: dict[str, List[Callable]] = {}
 self._data_global_handler: Optional[Callable] = None
+ self._data_channel: Optional["RTCDataChannel"] = None

 # Public APIs
 self.video = _VideoStream(self, self._video_queue)
@@ -477,6 +481,122 @@ def process(frame, metadata):
 self.close()
 raise

+ def _send_config_message(
+ self,
+ stream_output: Any = _UNSET, # noqa: ANN401
+ data_output: Any = _UNSET, # noqa: ANN401
+ ) -> None:
+ """Send configuration message to server via data channel.
+
+ Args:
+ stream_output: Value to set for stream_output field (_UNSET = don't change)
+ data_output: Value to set for data_output field (_UNSET = don't change)
+
+ Raises:
+ RuntimeError: If data channel is not open or not initialized
+ """
+ if not self._data_channel:
+ raise RuntimeError(
+ "Data channel not initialized. This method can only be called "
+ "within the WebRTCSession context (after __enter__)."
+ )
+
+ if self._data_channel.readyState != "open":
+ raise RuntimeError(
+ f"Data channel is not open (state: {self._data_channel.readyState}). "
+ "Wait for the connection to be established before changing configuration."
+ )
+
+ # Build message dict with only the fields to change
+ message_dict = {}
+ if stream_output is not _UNSET:
+ message_dict["stream_output"] = stream_output
+ if data_output is not _UNSET:
+ message_dict["data_output"] = data_output
+
+ if not message_dict:
+ # Nothing to send
+ return
+
+ # Serialize and send
+ message_json = json.dumps(message_dict)
+
+ # Send from the async event loop thread
+ def _send():
+ self._data_channel.send(message_json)
+
+ self._loop.call_soon_threadsafe(_send)
+
+ def set_stream_output(self, output_name: Optional[str]) -> None:
+ """Change which workflow output is rendered on the video track.
+
+ This allows dynamically switching which workflow output is used for
+ the video stream without reconnecting. Useful for workflows with
+ multiple visualization outputs.
+
+ Args:
+ output_name: Name of workflow output to use for video rendering.
+ - "" (empty string): Clear the selection; the server falls back
+ to auto-detecting a suitable output on the next frame
+ - "output_name": Use specific workflow output
+ - None: Ignored by the server (treated as "no change"), so
+ prefer "" to reset the selection
+
+ Raises:
+ RuntimeError: If data channel is not open or not initialized
+
+ Examples:
+ # Switch to specific output
+ session.set_stream_output("visualization")
+
+ # Clear selection and let server auto-detect best output
+ session.set_stream_output("")
+
+ Note:
+ The server does not validate the output name. If you specify an
+ invalid output, errors will appear in the data channel responses.
+ """
+ self._send_config_message(stream_output=output_name)
+
+ def set_data_outputs(self, output_names: Optional[List[str]]) -> None:
+ """Change which workflow outputs are sent via data channel.
+
+ This allows dynamically controlling which workflow outputs are sent
+ over the data channel without reconnecting. Useful for reducing
+ bandwidth or focusing on specific outputs.
+
+ Args:
+ output_names: List of workflow output names to send.
+ - []: Send NO outputs (metadata only)
+ - ["field1", "field2"]: Send only specified fields
+ - None: Ignored by the server (treated as "no change"); "all
+ outputs" can only be selected at session initialisation
+
+ Raises:
+ RuntimeError: If data channel is not open or not initialized
+
+ Examples:
+ # Send only metadata (no workflow outputs)
+ session.set_data_outputs([])
+
+ # Send specific fields
+ session.set_data_outputs(["predictions", "visualization"])
+
+ # Send single field
+ session.set_data_outputs(["predictions"])
+
+ Note:
+ - The server does not validate output names. Invalid names will
+ result in errors in the data channel responses.
+ - Images are only serialized when explicitly requested in the list.
+ - When all outputs are selected (data_output=None at session
+ initialisation), image serialization is skipped to save
+ bandwidth.
+ """ + self._send_config_message(data_output=output_names) + def _invoke_data_handler( self, handler: Callable, value: Any, metadata: Optional[VideoMetadata] ) -> None: # noqa: ANN401 @@ -654,7 +774,10 @@ async def _reader(): # Setup data channel ch = pc.createDataChannel("inference") - # Setup data channel message handler + # Store reference for dynamic configuration + self._data_channel = ch + + # Setup new data channel message handler @ch.on("message") def _on_data_message(message: Any) -> None: # noqa: ANN401 try: @@ -726,6 +849,7 @@ def _on_data_message(message: Any) -> None: # noqa: ANN401 "webrtc_realtime_processing": self._config.realtime_processing, "stream_output": self._config.stream_output, "data_output": self._config.data_output, + "output_mode": self._config.output_mode.value, } # Add TURN config if available (auto-fetched or user-provided) @@ -744,7 +868,16 @@ def _on_data_message(message: Any) -> None: # noqa: ANN401 url = f"{self._api_url}/initialise_webrtc_worker" headers = {"Content-Type": "application/json"} resp = requests.post(url, json=payload, headers=headers, timeout=90) - resp.raise_for_status() + try: + resp.raise_for_status() + except requests.exceptions.HTTPError as e: + # Try to get more details from the response + try: + error_detail = resp.json() + logger.error(f"Server returned error: {error_detail}") + except Exception: + logger.error(f"Server response body: {resp.text}") + raise ans: dict[str, Any] = resp.json() # Set remote description