88import pytest
99from requests import RequestException , Timeout
1010
11- from sap .aibus .dar .client .exceptions import DARHTTPException
11+ from sap .aibus .dar .client .exceptions import DARHTTPException , InvalidWorkerCount
1212from sap .aibus .dar .client .inference_client import InferenceClient
1313from tests .sap .aibus .dar .client .test_data_manager_client import (
1414 AbstractDARClientConstruction ,
@@ -180,7 +180,11 @@ def _assert_bulk_inference_works(
180180 retry_kwarg ["retry" ] = retry_flag
181181
182182 response = inference_client .do_bulk_inference (
183- model_name = "test-model" , objects = many_objects , top_n = 4 , ** retry_kwarg
183+ model_name = "test-model" ,
184+ objects = many_objects ,
185+ top_n = 4 ,
186+ worker_count = 1 , # Disable concurrency to make tests deterministic.
187+ ** retry_kwarg ,
184188 )
185189
186190 # The return value is the concatenation of all 'predictions' of the individual
@@ -348,6 +352,7 @@ def test_bulk_inference_error_no_object_ids(
348352 model_name = "test-model" ,
349353 objects = inference_objects ,
350354 top_n = 4 ,
355+ worker_count = 1 , # disable concurrency to make tests deterministic
351356 )
352357 expected_error_response = {
353358 "objectId" : None ,
@@ -368,3 +373,33 @@ def test_bulk_inference_error_no_object_ids(
368373 )
369374
370375 assert response == expected_response
376+
377+ def test_worker_count_validation (self , inference_client : InferenceClient ):
378+
379+ many_objects = [self .objects ()[0 ] for _ in range (75 )]
380+
381+ with pytest .raises (InvalidWorkerCount ) as context :
382+ inference_client .do_bulk_inference (
383+ model_name = "test-model" , objects = many_objects , worker_count = 5
384+ )
385+ assert "worker_count too high: 5. Up to 4 allowed." in str (context .value )
386+
387+ with pytest .raises (InvalidWorkerCount ) as context :
388+ inference_client .do_bulk_inference (
389+ model_name = "test-model" , objects = many_objects , worker_count = 0
390+ )
391+ assert "worker_count must be greater than 0" in str (context .value )
392+
393+ with pytest .raises (InvalidWorkerCount ) as context :
394+ inference_client .do_bulk_inference (
395+ model_name = "test-model" , objects = many_objects , worker_count = - 1
396+ )
397+ assert "worker_count must be greater than 0" in str (context .value )
398+
399+ with pytest .raises (InvalidWorkerCount ) as context :
400+ inference_client .do_bulk_inference (
401+ model_name = "test-model" ,
402+ objects = many_objects ,
403+ worker_count = None ,
404+ )
405+ assert "worker_count cannot be None" in str (context .value )
0 commit comments