@@ -41,6 +41,7 @@ async def handle_sse(request):
4141from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
4242from pydantic import ValidationError
4343from sse_starlette import EventSourceResponse
44+ from starlette .background import BackgroundTask
4445from starlette .requests import Request
4546from starlette .responses import Response
4647from starlette .types import Receive , Scope , Send
@@ -78,6 +79,18 @@ def __init__(self, endpoint: str) -> None:
7879 self ._read_stream_writers = {}
7980 logger .debug (f"SseServerTransport initialized with endpoint: { endpoint } " )
8081
82+ async def _remove_stream_writer (self , session_id : UUID ) -> None :
83+ """
84+ Remove the SSE session with the given session ID.
85+ """
86+ logger .debug (f"Remove SSE session with ID: { session_id } " )
87+ writer = self ._read_stream_writers .pop (session_id , None )
88+ if writer :
89+ await writer .aclose ()
90+ logger .debug (f"Closed SSE session with ID: { session_id } " )
91+ else :
92+ logger .warning (f"Session ID { session_id } not found for removal" )
93+
8194 @asynccontextmanager
8295 async def connect_sse (self , scope : Scope , receive : Receive , send : Send ):
8396 if scope ["type" ] != "http" :
@@ -119,10 +132,11 @@ async def sse_writer():
119132 ),
120133 }
121134 )
122-
135+ background_task = BackgroundTask ( self . _remove_stream_writer , session_id )
123136 async with anyio .create_task_group () as tg :
124137 response = EventSourceResponse (
125- content = sse_stream_reader , data_sender_callable = sse_writer
138+ content = sse_stream_reader , data_sender_callable = sse_writer ,
139+ background = background_task ,
126140 )
127141 logger .debug ("Starting SSE response task" )
128142 tg .start_soon (response , scope , receive , send )
0 commit comments