diff --git a/agents-core/vision_agents/core/agents/agents.py b/agents-core/vision_agents/core/agents/agents.py index 9f6bee1f..04519143 100644 --- a/agents-core/vision_agents/core/agents/agents.py +++ b/agents-core/vision_agents/core/agents/agents.py @@ -13,8 +13,8 @@ from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import TrackType from ..edge import sfu_events -from ..edge.events import AudioReceivedEvent, TrackAddedEvent, CallEndedEvent -from ..edge.types import Connection, Participant, PcmData, User, OutputAudioTrack +from ..edge.events import AudioReceivedEvent, TrackAddedEvent, TrackRemovedEvent, CallEndedEvent +from ..edge.types import Connection, Participant, PcmData, User from ..events.manager import EventManager from ..llm import events as llm_events from ..llm.events import ( @@ -33,6 +33,8 @@ from ..tts.tts import TTS from ..tts.events import TTSAudioEvent from ..turn_detection import TurnDetector, TurnStartedEvent, TurnEndedEvent +from ..utils.video_forwarder import VideoForwarder +from ..utils.video_utils import ensure_even_dimensions from ..vad import VAD from ..vad.events import VADAudioEvent from . import events @@ -159,6 +161,10 @@ def __init__( self._interval_task = None self._callback_executed = False self._track_tasks: Dict[str, asyncio.Task] = {} + # Track metadata: track_id -> (track_type, participant, forwarder) + self._active_video_tracks: Dict[str, tuple[int, Any, Any]] = {} + self._video_forwarders: List[VideoForwarder] = [] + self._current_video_track_id: Optional[str] = None self._connection: Optional[Connection] = None self._audio_track: Optional[OutputAudioTrack] = None self._video_track: Optional[VideoStreamTrack] = None @@ -669,10 +675,48 @@ async def on_track(event: TrackAddedEvent): if not track_id or not track_type: return + # If track is already being processed, just switch to it + if track_id in self._active_video_tracks: + track_type_name = TrackType.Name(track_type) + self.logger.info(f"🎥 Track re-added: {track_type_name} ({track_id}), switching to it") + + if self.realtime_mode and isinstance(self.llm, Realtime): + # Get the existing forwarder and switch to this track + _, _, forwarder = self._active_video_tracks[track_id] + track = self.edge.add_track_subscriber(track_id) + if track and forwarder: + await self.llm._watch_video_track(track, shared_forwarder=forwarder) + self._current_video_track_id = track_id + return + task = asyncio.create_task(self._process_track(track_id, track_type, user)) self._track_tasks[track_id] = task task.add_done_callback(_log_task_exception) + @self.edge.events.subscribe + async def on_track_removed(event: TrackRemovedEvent): + track_id = event.track_id + track_type = event.track_type + if not track_id: + return + + track_type_name = TrackType.Name(track_type) if track_type else "unknown" + self.logger.info(f"🎥 Track removed: {track_type_name} ({track_id})") + + # Cancel the processing task for this track + if track_id in self._track_tasks: + self._track_tasks[track_id].cancel() + self._track_tasks.pop(track_id) + + # Clean up track metadata + self._active_video_tracks.pop(track_id, None) + + # If this was the active track, switch to any other available track + if track_id == self._current_video_track_id and self.realtime_mode and isinstance(self.llm, Realtime): + self.logger.info("🎥 Active video track removed, switching to next available") + self._current_video_track_id = None + await self._switch_to_next_available_track() + async def _reply_to_audio( self, pcm_data: PcmData, participant: Participant ) -> None: @@ -701,125 +745,193 @@ async def _reply_to_audio( self.logger.debug(f"🎵 Processing audio from {participant}") await self.stt.process_audio(pcm_data, participant) - async def _process_track(self, track_id: str, track_type: int, participant): - # TODO: handle CancelledError - # we only process video tracks - if track_type != TrackType.TRACK_TYPE_VIDEO: - return - - # subscribe to the video track - track = self.edge.add_track_subscriber(track_id) - if not track: - self.logger.error(f"Failed to subscribe to {track_id}") + async def _switch_to_next_available_track(self) -> None: + """Switch to any available video track.""" + if not self._active_video_tracks: + self.logger.info("🎥 No video tracks available") + self._current_video_track_id = None return + + # Just pick the first available video track + for track_id, (track_type, participant, forwarder) in self._active_video_tracks.items(): + # Only consider video tracks (camera or screenshare) + if track_type not in (TrackType.TRACK_TYPE_VIDEO, TrackType.TRACK_TYPE_SCREEN_SHARE): + continue + + track_type_name = TrackType.Name(track_type) + self.logger.info(f"🎥 Switching to track: {track_type_name} ({track_id})") + + # Get the track and forwarder + track = self.edge.add_track_subscriber(track_id) + if track and forwarder and isinstance(self.llm, Realtime): + # Send to Realtime provider + await self.llm._watch_video_track(track, shared_forwarder=forwarder) + self._current_video_track_id = track_id + return + else: + self.logger.error(f"Failed to switch to track {track_id}") + + self.logger.warning("🎥 No suitable video tracks found") - # Import VideoForwarder - from ..utils.video_forwarder import VideoForwarder - - # Create a SHARED VideoForwarder for the RAW incoming track - # This prevents multiple recv() calls competing on the same track - raw_forwarder = VideoForwarder( - track, # type: ignore[arg-type] - max_buffer=30, - fps=30, # Max FPS for the producer (individual consumers can throttle down) - name=f"raw_video_forwarder_{track_id}", - ) - await raw_forwarder.start() - self.logger.info("🎥 Created raw VideoForwarder for track %s", track_id) - - # Track forwarders for cleanup - if not hasattr(self, "_video_forwarders"): - self._video_forwarders = [] - self._video_forwarders.append(raw_forwarder) + async def _process_track(self, track_id: str, track_type: int, participant): + raw_forwarder = None + processed_forwarder = None + + try: + # we only process video tracks (camera video or screenshare) + if track_type not in (TrackType.TRACK_TYPE_VIDEO, TrackType.TRACK_TYPE_SCREEN_SHARE): + return - # If Realtime provider supports video, determine which track to send - if self.realtime_mode: - if self._video_track: - # We have a video publisher (e.g., YOLO processor) - # Create a separate forwarder for the PROCESSED video track - self.logger.info( - "🎥 Forwarding PROCESSED video frames to Realtime provider" - ) - processed_forwarder = VideoForwarder( - self._video_track, # type: ignore[arg-type] - max_buffer=30, - fps=30, - name=f"processed_video_forwarder_{track_id}", - ) - await processed_forwarder.start() - self._video_forwarders.append(processed_forwarder) + # subscribe to the video track + track = self.edge.add_track_subscriber(track_id) + if not track: + self.logger.error(f"Failed to subscribe to {track_id}") + return - if isinstance(self.llm, Realtime): - # Send PROCESSED frames with the processed forwarder - await self.llm._watch_video_track( - self._video_track, shared_forwarder=processed_forwarder + # Wrap screenshare tracks to ensure even dimensions for H.264 encoding + if track_type == TrackType.TRACK_TYPE_SCREEN_SHARE: + class _EvenDimensionsTrack(VideoStreamTrack): + def __init__(self, src): + super().__init__() + self.src = src + async def recv(self): + return ensure_even_dimensions(await self.src.recv()) + + track = _EvenDimensionsTrack(track) # type: ignore[arg-type] + + # Create a SHARED VideoForwarder for the RAW incoming track + # This prevents multiple recv() calls competing on the same track + raw_forwarder = VideoForwarder( + track, # type: ignore[arg-type] + max_buffer=30, + fps=30, # Max FPS for the producer (individual consumers can throttle down) + name=f"raw_video_forwarder_{track_id}", + ) + await raw_forwarder.start() + self.logger.info("🎥 Created raw VideoForwarder for track %s", track_id) + + # Track forwarders for cleanup + self._video_forwarders.append(raw_forwarder) + + # Store track metadata + self._active_video_tracks[track_id] = (track_type, participant, raw_forwarder) + + # If Realtime provider supports video, switch to this new track + track_type_name = TrackType.Name(track_type) + + if self.realtime_mode: + if self._video_track: + # We have a video publisher (e.g., YOLO processor) + # Create a separate forwarder for the PROCESSED video track + self.logger.info( + "🎥 Forwarding PROCESSED video frames to Realtime provider" ) - else: - # No video publisher, send raw frames - self.logger.info("🎥 Forwarding RAW video frames to Realtime provider") - if isinstance(self.llm, Realtime): - await self.llm._watch_video_track( - track, shared_forwarder=raw_forwarder + processed_forwarder = VideoForwarder( + self._video_track, # type: ignore[arg-type] + max_buffer=30, + fps=30, + name=f"processed_video_forwarder_{track_id}", + ) + await processed_forwarder.start() + self._video_forwarders.append(processed_forwarder) + + if isinstance(self.llm, Realtime): + # Send PROCESSED frames with the processed forwarder + await self.llm._watch_video_track( + self._video_track, shared_forwarder=processed_forwarder + ) + self._current_video_track_id = track_id + else: + # No video publisher, send raw frames - switch to this new track + self.logger.info(f"🎥 Switching to {track_type_name} track: {track_id}") + if isinstance(self.llm, Realtime): + await self.llm._watch_video_track( + track, shared_forwarder=raw_forwarder + ) + self._current_video_track_id = track_id + + has_image_processors = len(self.image_processors) > 0 + + # video processors - pass the raw forwarder (they process incoming frames) + for processor in self.video_processors: + try: + await processor.process_video( + track, participant.user_id, shared_forwarder=raw_forwarder + ) + except Exception as e: + self.logger.error( + f"Error in video processor {type(processor).__name__}: {e}" ) - hasImageProcessers = len(self.image_processors) > 0 - - # video processors - pass the raw forwarder (they process incoming frames) - for processor in self.video_processors: - try: - await processor.process_video( - track, participant.user_id, shared_forwarder=raw_forwarder - ) - except Exception as e: - self.logger.error( - f"Error in video processor {type(processor).__name__}: {e}" + # Use raw forwarder for image processors - only if there are image processors + if not has_image_processors: + # No image processors, just keep the connection alive + self.logger.info( + "No image processors, video processing handled by video processors only" ) + return - # Use raw forwarder for image processors - only if there are image processors - if not hasImageProcessers: - # No image processors, just keep the connection alive - self.logger.info( - "No image processors, video processing handled by video processors only" - ) - return - - # Initialize error tracking counters - timeout_errors = 0 - consecutive_errors = 0 - - while True: - try: - # Use the raw forwarder instead of competing for track.recv() - video_frame = await raw_forwarder.next_frame(timeout=2.0) - - if video_frame: - # Reset error counts on successful frame processing - timeout_errors = 0 - consecutive_errors = 0 - - if hasImageProcessers: - img = video_frame.to_image() - - for processor in self.image_processors: - try: - await processor.process_image(img, participant.user_id) - except Exception as e: - self.logger.error( - f"Error in image processor {type(processor).__name__}: {e}" - ) - - else: - self.logger.warning("🎥VDP: Received empty frame") - consecutive_errors += 1 + # Initialize error tracking counters + timeout_errors = 0 + consecutive_errors = 0 - except asyncio.TimeoutError: - # Exponential backoff for timeout errors - timeout_errors += 1 - backoff_delay = min(2.0 ** min(timeout_errors, 5), 30.0) - self.logger.debug( - f"🎥VDP: Applying backoff delay: {backoff_delay:.1f}s" - ) - await asyncio.sleep(backoff_delay) + while True: + try: + # Use the raw forwarder instead of competing for track.recv() + video_frame = await raw_forwarder.next_frame(timeout=2.0) + + if video_frame: + # Reset error counts on successful frame processing + timeout_errors = 0 + consecutive_errors = 0 + + if has_image_processors: + img = video_frame.to_image() + + for processor in self.image_processors: + try: + await processor.process_image(img, participant.user_id) + except Exception as e: + self.logger.error( + f"Error in image processor {type(processor).__name__}: {e}" + ) + + else: + self.logger.warning("🎥VDP: Received empty frame") + consecutive_errors += 1 + + except asyncio.TimeoutError: + # Exponential backoff for timeout errors + timeout_errors += 1 + backoff_delay = min(2.0 ** min(timeout_errors, 5), 30.0) + self.logger.debug( + f"🎥VDP: Applying backoff delay: {backoff_delay:.1f}s" + ) + await asyncio.sleep(backoff_delay) + except asyncio.CancelledError: + # Task was cancelled (e.g., track removed) + # Clean up forwarders that were created for this track + self.logger.debug(f"🎥 Cleaning up forwarders for cancelled track {track_id}") + + # Stop and remove the raw forwarder if it was created + if raw_forwarder is not None and hasattr(self, '_video_forwarders'): + if raw_forwarder in self._video_forwarders: + try: + await raw_forwarder.stop() + self._video_forwarders.remove(raw_forwarder) + except Exception as e: + self.logger.error(f"Error stopping raw forwarder: {e}") + + # Stop and remove processed forwarder if it was created + if processed_forwarder is not None and hasattr(self, '_video_forwarders'): + if processed_forwarder in self._video_forwarders: + try: + await processed_forwarder.stop() + self._video_forwarders.remove(processed_forwarder) + except Exception as e: + self.logger.error(f"Error stopping processed forwarder: {e}") + + return async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None: """Handle turn detection events.""" diff --git a/agents-core/vision_agents/core/utils/video_utils.py b/agents-core/vision_agents/core/utils/video_utils.py new file mode 100644 index 00000000..de433459 --- /dev/null +++ b/agents-core/vision_agents/core/utils/video_utils.py @@ -0,0 +1,27 @@ +"""Video frame utilities.""" + +import av + + +def ensure_even_dimensions(frame: av.VideoFrame) -> av.VideoFrame: + """ + Ensure frame has even dimensions for H.264 yuv420p encoding. + + Crops by 1 pixel if width or height is odd. + """ + needs_width_adjust = frame.width % 2 != 0 + needs_height_adjust = frame.height % 2 != 0 + + if not needs_width_adjust and not needs_height_adjust: + return frame + + new_width = frame.width - (1 if needs_width_adjust else 0) + new_height = frame.height - (1 if needs_height_adjust else 0) + + cropped = frame.reformat(width=new_width, height=new_height) + cropped.pts = frame.pts + if frame.time_base is not None: + cropped.time_base = frame.time_base + + return cropped + diff --git a/examples/other_examples/openai_realtime_webrtc/openai_realtime_example.py b/examples/other_examples/openai_realtime_webrtc/openai_realtime_example.py index 9f0883d9..6ce8592d 100644 --- a/examples/other_examples/openai_realtime_webrtc/openai_realtime_example.py +++ b/examples/other_examples/openai_realtime_webrtc/openai_realtime_example.py @@ -16,7 +16,7 @@ from vision_agents.core.agents import Agent from getstream import AsyncStream -logging.basicConfig(level=logging.WARNING) +logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) load_dotenv() diff --git a/plugins/getstream/tests/test_getstream_plugin.py b/plugins/getstream/tests/test_getstream_plugin.py index d7d2f461..be1db4d0 100644 --- a/plugins/getstream/tests/test_getstream_plugin.py +++ b/plugins/getstream/tests/test_getstream_plugin.py @@ -1,15 +1,72 @@ -import pytest -from dotenv import load_dotenv +from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import TrackType +from vision_agents.core.events.manager import EventManager +from vision_agents.core.edge.events import TrackAddedEvent, TrackRemovedEvent -load_dotenv() +class TestTrackRepublishing: + """ + Regression test for screenshare republishing bug. + + Bug: When a user stopped and restarted screensharing, the second TrackAddedEvent + was not emitted, so the agent couldn't switch back to the screenshare. + + Fix: stream_edge_transport._on_track_published() now emits TrackAddedEvent even + when the track_key already exists in _track_map. + """ + + async def test_track_events_flow_correctly(self): + """Verify that track events (add -> remove -> add) flow through the event system.""" + event_manager = EventManager() + event_manager.register(TrackAddedEvent) + event_manager.register(TrackRemovedEvent) + + # Collect emitted events + events = [] + + @event_manager.subscribe + async def collect_track_events(event: TrackAddedEvent | TrackRemovedEvent): + events.append(event) + + # Simulate track lifecycle: start -> stop -> start again + track_id = "screenshare-track-1" + track_type = TrackType.TRACK_TYPE_SCREEN_SHARE + + # 1. Start screenshare + event_manager.send(TrackAddedEvent( + plugin_name="getstream", + track_id=track_id, + track_type=track_type, + )) + await event_manager.wait() + + assert len(events) == 1 + assert isinstance(events[0], TrackAddedEvent) + assert events[0].track_id == track_id + + # 2. Stop screenshare + event_manager.send(TrackRemovedEvent( + plugin_name="getstream", + track_id=track_id, + track_type=track_type, + )) + await event_manager.wait() + + assert len(events) == 2 + assert isinstance(events[1], TrackRemovedEvent) + + # 3. Start screenshare again (critical test) + event_manager.send(TrackAddedEvent( + plugin_name="getstream", + track_id=track_id, + track_type=track_type, + )) + await event_manager.wait() + + # Before the fix: The agent would never receive this third event + assert len(events) == 3, "Republishing track should emit TrackAddedEvent" + assert isinstance(events[2], TrackAddedEvent) + assert events[2].track_id == track_id -class TestGetStreamPlugin: - def test_regular(self): - assert True - - # example integration test (run daily on CI) - @pytest.mark.integration - async def test_simple(self): - assert True + # Cleanup + event_manager.unsubscribe(collect_track_events) diff --git a/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py b/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py index 67131e7a..321be544 100644 --- a/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py +++ b/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py @@ -106,11 +106,27 @@ async def _on_track_published(self, event: sfu_events.TrackPublishedEvent): track_key = (user_id, session_id, track_type_int) is_agent_track = user_id == self.agent_user_id + # Skip processing the agent's own tracks - we don't subscribe to them + if is_agent_track: + self.logger.debug(f"Skipping agent's own track: {track_type_int} from {user_id}") + return + # First check if track already exists in map (e.g., from previous unpublish/republish) if track_key in self._track_map: self._track_map[track_key]["published"] = True + track_id = self._track_map[track_key]["track_id"] self.logger.info( - f"Track marked as published (already existed): {track_key}" + f"Track re-published: {track_type_int} from {user_id}, track_id: {track_id}" + ) + + # Emit TrackAddedEvent so agent can switch to this track + self.events.send( + events.TrackAddedEvent( + plugin_name="getstream", + track_id=track_id, + track_type=track_type_int, + user=event.participant, + ) ) return @@ -149,18 +165,15 @@ async def _on_track_published(self, event: sfu_events.TrackPublishedEvent): f"Trackmap published: {track_type_int} from {user_id}, track_id: {track_id} (waited {elapsed:.2f}s)" ) - # Only emit TrackAddedEvent for remote participants, not for agent's own tracks - if not is_agent_track: - # NOW spawn TrackAddedEvent with correct type - self.events.send( - events.TrackAddedEvent( - plugin_name="getstream", - track_id=track_id, - track_type=track_type_int, - user=event.participant, - user_metadata=event.participant, - ) + # Emit TrackAddedEvent with correct type + self.events.send( + events.TrackAddedEvent( + plugin_name="getstream", + track_id=track_id, + track_type=track_type_int, + user=event.participant, ) + ) else: raise TimeoutError( f"Timeout waiting for pending track: {track_type_int} ({expected_kind}) from user {user_id}, " @@ -213,7 +226,6 @@ async def _on_track_removed( track_id=track_id, track_type=track_type_int, user=participant, - user_metadata=participant, ) ) # Mark as unpublished instead of removing @@ -270,13 +282,12 @@ async def on_track(track_id, track_type, user): self.events.silent(events.AudioReceivedEvent) @connection.on("audio") - async def on_audio_received(pcm: PcmData, participant: Participant): + async def on_audio_received(pcm: PcmData | None, participant: Participant): self.events.send( events.AudioReceivedEvent( plugin_name="getstream", pcm_data=pcm, participant=participant, - user_metadata=participant, ) ) @@ -388,10 +399,10 @@ async def open_demo(self, call: Call) -> str: "token": token, "skip_lobby": "true", "user_name": name, - "video_encoder": "vp8", - "bitrate": 12000000, + "video_encoder": "h264", # Use H.264 instead of VP8 for better compatibility + "bitrate": 12000000, "w": 1920, - "h": 1080, + "h": 1080, "channel_type": self.channel_type, } diff --git a/tests/test_utils.py b/tests/test_utils.py index c1c4068b..77dc28a1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,9 @@ import os import tempfile import numpy as np +import av from vision_agents.core.utils.utils import parse_instructions, Instructions +from vision_agents.core.utils.video_utils import ensure_even_dimensions from vision_agents.core.edge.types import PcmData @@ -488,4 +490,40 @@ def test_pcm_data_resample_av_array_shape_fix(self): assert resampled.samples.ndim == 1 -# Shared fixtures for integration tests +class TestEnsureEvenDimensions: + """Test suite for ensure_even_dimensions function.""" + + def test_even_dimensions_unchanged(self): + """Test that frames with even dimensions pass through unchanged.""" + # Create a frame with even dimensions (1920x1080) + frame = av.VideoFrame(width=1920, height=1080, format="yuv420p") + + result = ensure_even_dimensions(frame) + + assert result.width == 1920 + assert result.height == 1080 + + def test_both_dimensions_odd_cropped(self): + """Test that frames with both odd dimensions are cropped.""" + # Create a frame with both odd dimensions (1921x1081) + frame = av.VideoFrame(width=1921, height=1081, format="yuv420p") + + result = ensure_even_dimensions(frame) + + assert result.width == 1920 # Cropped from 1921 + assert result.height == 1080 # Cropped from 1081 + + def test_timing_information_preserved(self): + """Test that pts and time_base are preserved after cropping.""" + from fractions import Fraction + + # Create a frame with timing information + frame = av.VideoFrame(width=1921, height=1081, format="yuv420p") + frame.pts = 12345 + frame.time_base = Fraction(1, 30) + + result = ensure_even_dimensions(frame) + + assert result.pts == 12345 + assert result.time_base == Fraction(1, 30) +