1+ # Copyright (c) 2025, NVIDIA CORPORATION.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
115import asyncio
216import copy
317import os
721 ConditionalDisaggConfig ,
822 DisaggClusterConfig ,
923 DisaggServerConfig ,
24+ MetadataServerConfig ,
1025 ServerRole ,
1126)
1227from tensorrt_llm .logger import logger
@@ -37,8 +52,9 @@ def __init__(
3752 config : DisaggServerConfig ,
3853 ctx_router : Router ,
3954 gen_router : Router ,
40- client_factory : Callable [[Router , str ], OpenAIClient ] = None ,
55+ client_factory : Callable [[Router , str ], OpenAIClient ],
4156 metadata_server : Optional [JsonDictionary ] = None ,
57+ metadata_config : Optional [MetadataServerConfig ] = None ,
4258 req_timeout_secs : int = 180 ,
4359 server_start_timeout_secs : int = 180 ,
4460 perf_metrics_collector : Optional [DisaggPerfMetricsCollector ] = None ,
@@ -49,6 +65,7 @@ def __init__(
4965 self ._gen_router = gen_router
5066 self ._client_factory = client_factory
5167 self ._metadata_server = metadata_server
68+ self ._metadata_config = metadata_config
5269 self ._req_timeout_secs = req_timeout_secs
5370 self ._server_start_timeout_secs = server_start_timeout_secs
5471 self ._perf_metrics_collector = perf_metrics_collector
@@ -86,11 +103,6 @@ async def openai_chat_completion(
86103 async def _send_disagg_request (
87104 self , request : UCompletionRequest , hooks : Optional [ResponseHooks ] = None
88105 ) -> Union [UCompletionResponse , CompletionResponseGenerator ]:
89- """This is the main disaggregated serving logic:
90- 1. send context request to the context server if ctx is needed, return the context response if gen is not needed
91- 2. build a generation request based on the context response and send it to the generation server if gen is needed,
92- return the generation response
93- """
94106 if hooks :
95107 hooks .on_req_begin (request )
96108 # empty server means client decides which server to use
@@ -104,7 +116,7 @@ async def _send_disagg_request(
104116 if need_ctx :
105117 ctx_req = self ._get_ctx_request (request )
106118 # ctx generator is empty
107- ctx_response = await self ._ctx_client .send_request (ctx_server , ctx_req )
119+ ctx_response = await self ._ctx_client .send_request (ctx_server , ctx_req , hooks )
108120 await self ._verify_ctx_response (ctx_response )
109121 gen_req = self ._get_gen_request (request , ctx_response )
110122 if ctx_response is None or self ._need_gen (ctx_response ):
@@ -206,13 +218,13 @@ async def setup(self) -> None:
206218 await self ._disagg_cluster_manager .watch_workers (on_event = self ._on_worker_event )
207219 logger .info ("Disagg cluster manager started" )
208220 else :
209- if self ._metadata_server :
221+ if self ._metadata_server and self . _metadata_config :
210222 logger .info ("Starting server monitoring via metadata service" )
211223 await self ._ctx_router .start_server_monitoring (
212- self .metadata_server .refresh_interval
224+ self ._metadata_config .refresh_interval
213225 )
214226 await self ._gen_router .start_server_monitoring (
215- self .metadata_server .refresh_interval
227+ self ._metadata_config .refresh_interval
216228 )
217229 await self ._wait_for_servers_ready ()
218230
@@ -274,7 +286,11 @@ async def _verify_ctx_response(self, ctx_response: UCompletionResponse) -> None:
274286# FIXME: This is a demo to show the basic idea of disagg-service with pre-allocating generation
275287class OpenAIDisaggregatedPreAllocService (OpenAIDisaggregatedService ):
def _need_gen(self, request: UCompletionRequest) -> bool:
    """Decide whether a generation phase is required for *request*.

    Generation is only needed when the caller explicitly asked for more
    than one output token; a missing token limit (or an unrecognized
    request type) is treated as "context-only".
    """
    token_limit = None
    if isinstance(request, CompletionRequest):
        token_limit = request.max_tokens
    elif isinstance(request, ChatCompletionRequest):
        token_limit = request.max_completion_tokens
    # With no explicit limit we conservatively skip the generation phase.
    return token_limit is not None and token_limit > 1
278294
279295 async def _send_disagg_request (
280296 self , request : UCompletionRequest , hooks : Optional [ResponseHooks ] = None
@@ -290,23 +306,24 @@ async def _send_disagg_request(
290306 need_gen = self ._need_gen (request )
291307 # send ctx and gen requests in parallel
292308 assert need_gen or need_ctx , "Neither generation nor context is required"
293- with asyncio .TaskGroup () as tg :
294- if need_ctx :
295-
296- async def _run_ctx_task ():
297- # send ctx request and gen request in parallel
298- ctx_req = self ._get_ctx_request (request )
299- ctx_response = await self ._ctx_client .send_completion_request (
300- ctx_server , ctx_req
301- )
302- return ctx_response
303-
304- ctx_task = tg .create_task (_run_ctx_task ())
305- if need_gen :
306- gen_task = tg .create_task (
307- self ._gen_client .send_completion_request (gen_server , request , hooks )
308- )
309+ gen_task = None
310+ ctx_task = None
311+ tasks = []
312+
313+ async def _run_ctx_task ():
314+ # send ctx request and gen request in parallel
315+ ctx_req = self ._get_ctx_request (request )
316+ ctx_response = await self ._ctx_client .send_request (ctx_server , ctx_req , hooks )
317+ return ctx_response
318+
319+ if need_ctx :
320+ ctx_task = asyncio .create_task (_run_ctx_task ())
321+ if need_gen :
322+ gen_task = asyncio .create_task (
323+ self ._gen_client .send_request (gen_server , request , hooks )
324+ )
325+ tasks .append (gen_task )
326+ await asyncio .gather (* tasks )
309327 if need_gen :
310328 return gen_task .result ()
311- else :
312- return ctx_task .result ()
329+ return ctx_task .result ()
0 commit comments