88from vllm .v1 .executor .abstract import Executor
99from vllm .v1 .request import Request
1010
11- from tpu_inference .core .adapters import (VllmConfigAdapter , VllmEngineAdapter ,
12- VllmRequestAdapter )
1311from tpu_inference .core .core_tpu import (DisaggEngineCore ,
1412 DisaggEngineCoreProc ,
1513 _DisaggOrchestrator )
16- from tpu_inference .interfaces .config import IConfig
17- from tpu_inference .interfaces .engine import IEngineCore
1814
1915
2016class TestDisaggEngineCore (unittest .TestCase ):
@@ -67,13 +63,11 @@ def test_initialization(self):
6763
6864 self .mock_orchestrator .assert_called_once ()
6965 args , kwargs = self .mock_orchestrator .call_args
70- self .assertIsInstance (kwargs ['config' ], VllmConfigAdapter )
71- self .assertEqual (kwargs ['config' ]. vllm_config , self .mock_vllm_config )
66+ self .assertIsInstance (kwargs ['config' ], VllmConfig )
67+ self .assertEqual (kwargs ['config' ], self .mock_vllm_config )
7268 self .assertEqual (kwargs ['output_queue' ], engine .output_queue )
7369 self .assertEqual (len (kwargs ['prefill_engines' ]), 1 )
74- self .assertIsInstance (kwargs ['prefill_engines' ][0 ], VllmEngineAdapter )
7570 self .assertEqual (len (kwargs ['decode_engines' ]), 1 )
76- self .assertIsInstance (kwargs ['decode_engines' ][0 ], VllmEngineAdapter )
7771 self .assertEqual (kwargs ['prefill_slice_sizes' ], (4 , ))
7872 self .assertEqual (kwargs ['decode_slice_sizes' ], (2 , ))
7973
@@ -94,13 +88,12 @@ def test_add_request(self):
9488
9589 self .mock_orchestrator .return_value .add_request .assert_called_once ()
9690 # Get the argument passed to add_request
97- passed_request_adapter = self .mock_orchestrator .return_value .add_request .call_args [
91+ passed_request = self .mock_orchestrator .return_value .add_request .call_args [
9892 0 ][0 ]
9993
100- # Assert it's the correct type and wraps the correct underlying request
101- self .assertIsInstance (passed_request_adapter , VllmRequestAdapter )
102- self .assertIsInstance (passed_request_adapter .vllm_request , Request )
103- self .assertEqual (passed_request_adapter .request_id , "test_req" )
94+ # Assert it's the correct type (the Request directly)
95+ self .assertIsInstance (passed_request , Request )
96+ self .assertEqual (passed_request .request_id , "test_req" )
10497
10598 def test_shutdown (self ):
10699 """Tests that the adapter correctly delegates shutdown to the orchestrator."""
@@ -204,13 +197,11 @@ def test_initialization(self):
204197
205198 self .mock_orchestrator .assert_called_once ()
206199 args , kwargs = self .mock_orchestrator .call_args
207- self .assertIsInstance (kwargs ['config' ], VllmConfigAdapter )
208- self .assertEqual (kwargs ['config' ]. vllm_config , self .mock_vllm_config )
200+ self .assertIsInstance (kwargs ['config' ], VllmConfig )
201+ self .assertEqual (kwargs ['config' ], self .mock_vllm_config )
209202 self .assertEqual (kwargs ['output_queue' ], proc .output_queue )
210203 self .assertEqual (len (kwargs ['prefill_engines' ]), 1 )
211- self .assertIsInstance (kwargs ['prefill_engines' ][0 ], VllmEngineAdapter )
212204 self .assertEqual (len (kwargs ['decode_engines' ]), 1 )
213- self .assertIsInstance (kwargs ['decode_engines' ][0 ], VllmEngineAdapter )
214205 self .assertEqual (kwargs ['prefill_slice_sizes' ], (4 , ))
215206 self .assertEqual (kwargs ['decode_slice_sizes' ], (2 , ))
216207
@@ -239,13 +230,12 @@ def test_add_request(self):
239230
240231 self .mock_orchestrator .return_value .add_request .assert_called_once ()
241232 # Get the argument passed to add_request
242- passed_request_adapter = self .mock_orchestrator .return_value .add_request .call_args [
233+ passed_request = self .mock_orchestrator .return_value .add_request .call_args [
243234 0 ][0 ]
244235
245- # Assert it's the correct type and wraps the correct underlying request
246- self .assertIsInstance (passed_request_adapter , VllmRequestAdapter )
247- self .assertIsInstance (passed_request_adapter .vllm_request , Request )
248- self .assertEqual (passed_request_adapter .request_id , "test_req" )
236+ # Assert it's the correct type (the Request directly)
237+ self .assertIsInstance (passed_request , Request )
238+ self .assertEqual (passed_request .request_id , "test_req" )
249239
250240 def test_shutdown (self ):
251241 """Tests that the adapter correctly delegates shutdown to the orchestrator."""
@@ -321,15 +311,15 @@ def test_handle_client_request_utility(self):
321311class TestDisaggOrchestrator (unittest .TestCase ):
322312
323313 def setUp (self ):
324- self .mock_config = MagicMock (spec = IConfig )
314+ self .mock_config = MagicMock (spec = VllmConfig )
325315 self .mock_config .scheduler_config = MagicMock ()
326316 self .mock_config .scheduler_config .max_num_seqs = 16
327317 self .mock_config .cache_config = MagicMock ()
328318 self .mock_config .cache_config .block_size = 5
329319
330320 self .mock_output_queue = MagicMock ()
331- self .mock_prefill_engine = MagicMock (spec = IEngineCore )
332- self .mock_decode_engine = MagicMock (spec = IEngineCore )
321+ self .mock_prefill_engine = MagicMock ()
322+ self .mock_decode_engine = MagicMock ()
333323
334324 # The orchestrator accesses the scheduler on the engine.
335325 self .mock_prefill_engine .scheduler = MagicMock ()
@@ -374,7 +364,7 @@ def test_add_request(self):
374364 decode_slice_sizes = (2 , ),
375365 )
376366 mock_request = MagicMock ()
377- mock_request .vllm_request . request_id = "test_req"
367+ mock_request .request_id = "test_req"
378368
379369 orchestrator .add_request (mock_request )
380370
@@ -469,7 +459,7 @@ def test_decode_logic(self):
469459
470460 # Mock request
471461 mock_request = MagicMock ()
472- mock_request .vllm_request . num_computed_tokens = 10
462+ mock_request .num_computed_tokens = 10
473463 orchestrator ._requests ["test_req" ] = mock_request
474464
475465 # Mock scheduler and model runner states for the loop condition
0 commit comments