
Commit 31da951

do_bulk_inference: add worker_count parameter
This is mainly useful for fixing the tests, which rely on the mocks being called in a certain order. One of the tests supports concurrency through improved mocking, but this was not feasible for the other tests.
1 parent f0d4169 commit 31da951
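
For orientation, here is a minimal, hypothetical usage sketch of the new parameter. The `client` instance, the model name, and the object payload layout are illustrative assumptions; `do_bulk_inference`, `top_n`, and `worker_count` are as introduced by this commit.

# Sketch only: assumes `client` is an already constructed InferenceClient
# and that the payload layout below matches the service's inference format.
objects = [
    {
        "objectId": "obj-1",  # hypothetical identifier
        "features": [{"name": "description", "value": "cotton shirt"}],
    }
]

# New in 0.12.0: worker_count caps the number of concurrent request
# threads (1 to 4). worker_count=1 disables concurrent execution.
predictions = client.do_bulk_inference(
    model_name="my-model",
    objects=objects,
    top_n=1,
    worker_count=1,
)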

File tree: 6 files changed, +81 −7 lines changed

docs/requirements.txt

Lines changed: 2 additions & 2 deletions

@@ -1,3 +1,3 @@
 # docs
-sphinx==2.4.1
-sphinx_rtd_theme==0.5.0
+sphinx==5.0.2
+sphinx_rtd_theme==1.0.0

sap/aibus/dar/client/exceptions.py

Lines changed: 8 additions & 0 deletions

@@ -123,6 +123,14 @@ class JobNotFound(DARException):
     pass


+class InvalidWorkerCount(DARException):
+    """
+    Invalid worker_count parameter is specified.
+
+    .. versionadded:: 0.12.0
+    """
+
+
 class ModelAlreadyExists(DARException):
     """
     Model already exists and must be deleted first.
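
Because the new exception subclasses DARException, callers that already handle DARException will also catch InvalidWorkerCount. A small sketch; the import path is taken from the diff above, and DARException is assumed to live in the same module, as the class definitions here suggest:

from sap.aibus.dar.client.exceptions import DARException, InvalidWorkerCount

try:
    # Simulate what do_bulk_inference raises for a bad worker_count.
    raise InvalidWorkerCount("worker_count cannot be None!")
except DARException as exc:
    # The broad DARException handler catches the new subclass as well.
    assert isinstance(exc, InvalidWorkerCount)
    print(type(exc).__name__, "->", exc)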

sap/aibus/dar/client/inference_client.py

Lines changed: 23 additions & 2 deletions

@@ -7,7 +7,7 @@
 from requests import RequestException

 from sap.aibus.dar.client.base_client import BaseClientWithSession
-from sap.aibus.dar.client.exceptions import DARHTTPException
+from sap.aibus.dar.client.exceptions import DARHTTPException, InvalidWorkerCount
 from sap.aibus.dar.client.inference_constants import InferencePaths
 from sap.aibus.dar.client.util.lists import split_list

@@ -17,6 +17,8 @@
 #: How many labels to predict for a single object by default
 TOP_N = 1

+# pylint: disable=too-many-arguments
+

 class InferenceClient(BaseClientWithSession):
     """
@@ -77,6 +79,7 @@ def do_bulk_inference(
         objects: List[dict],
         top_n: int = TOP_N,
         retry: bool = True,
+        worker_count: int = 4,
     ) -> List[Union[dict, None]]:
         """
         Performs bulk inference for larger collections.
@@ -146,14 +149,32 @@ def do_bulk_inference(
         has value `None`.
         A `_sdk_error` key is added with the Exception details.

+        .. versionadded:: 0.12.0
+           The `worker_count` parameter allows to fine-tune the number of concurrent
+           request threads. Set `worker_count` to `1` to disable concurrent execution of
+           requests.
+
         :param model_name: name of the model used for inference
         :param objects: Objects to be classified
         :param top_n: How many predictions to return per object
         :param retry: whether to retry on errors. Default: True
+        :param worker_count: maximum number of concurrent requests
+        :raises: InvalidWorkerCount if worker_count param is incorrect
         :return: the aggregated ObjectPrediction dictionaries
         """

+        if worker_count is None:
+            raise InvalidWorkerCount("worker_count cannot be None!")
+
+        if worker_count > 4:
+            msg = "worker_count too high: %s. Up to 4 allowed." % worker_count
+            raise InvalidWorkerCount(msg)
+
+        if worker_count <= 0:
+            msg = "worker_count must be greater than 0!"
+            raise InvalidWorkerCount(msg)
+
         def predict_call(work_package):
             try:
                 response = self.create_inference_request(
@@ -180,7 +201,7 @@ def predict_call(work_package):

         results = []

-        with ThreadPoolExecutor(max_workers=4) as pool:
+        with ThreadPoolExecutor(max_workers=worker_count) as pool:
             results_iterator = pool.map(
                 predict_call, split_list(objects, LIMIT_OBJECTS_PER_CALL)
             )
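
The change threads `worker_count` straight into the existing `ThreadPoolExecutor` pipeline: validate the count, split the objects into work packages, then map the request function over the packages. Below is a self-contained sketch of that pattern; `chunk` is a local stand-in for the SDK's `split_list`, the limit constant is a placeholder value, and a plain ValueError replaces InvalidWorkerCount to keep the example runnable on its own.

from concurrent.futures import ThreadPoolExecutor
from typing import Iterable, List

LIMIT_OBJECTS_PER_CALL = 50  # placeholder; the real limit lives in the SDK

def chunk(items: List[dict], size: int) -> Iterable[List[dict]]:
    """Local stand-in for split_list: yield fixed-size work packages."""
    for start in range(0, len(items), size):
        yield items[start : start + size]

def bulk_predict(objects: List[dict], worker_count: int = 4) -> List[dict]:
    # Same validation order as the diff: None, too high, then non-positive.
    if worker_count is None:
        raise ValueError("worker_count cannot be None!")
    if worker_count > 4:
        raise ValueError("worker_count too high: %s. Up to 4 allowed." % worker_count)
    if worker_count <= 0:
        raise ValueError("worker_count must be greater than 0!")

    def predict_call(work_package: List[dict]) -> List[dict]:
        # The real predict_call performs an HTTP inference request here.
        return [{"objectId": obj.get("objectId")} for obj in work_package]

    results: List[dict] = []
    with ThreadPoolExecutor(max_workers=worker_count) as pool:
        # pool.map yields results in submission order, so the aggregated
        # output stays ordered even with worker_count > 1.
        for partial in pool.map(predict_call, chunk(objects, LIMIT_OBJECTS_PER_CALL)):
            results.extend(partial)
    return results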

system_tests/workflow/test_end_to_end.py

Lines changed: 7 additions & 0 deletions

@@ -249,12 +249,19 @@ def _assert_inference_works(self, inference_client, model_name):
         # One object has been classified
         assert len(response["predictions"]) == 1

+        # do_bulk_inference with concurrency
         big_to_be_classified = [to_be_classified[0] for _ in range(123)]
         response = inference_client.do_bulk_inference(
             model_name=model_name, objects=big_to_be_classified
         )
         assert len(response) == 123

+        # do_bulk_inference without concurrency
+        response = inference_client.do_bulk_inference(
+            model_name=model_name, objects=big_to_be_classified, worker_count=1
+        )
+        assert len(response) == 123
+
         url = os.environ["DAR_URL"]
         if url[-1] == "/":
             url = url[:-1]

tests/sap/aibus/dar/client/test_exceptions.py

Lines changed: 4 additions & 1 deletion

@@ -1,7 +1,10 @@
 import datetime
 from unittest.mock import PropertyMock

-from sap.aibus.dar.client.exceptions import DARHTTPException, ModelAlreadyExists
+from sap.aibus.dar.client.exceptions import (
+    DARHTTPException,
+    ModelAlreadyExists,
+)
 from tests.sap.aibus.dar.client.test_dar_session import create_mock_response

 # TODO: test __str__

tests/sap/aibus/dar/client/test_inference_client.py

Lines changed: 37 additions & 2 deletions

@@ -8,7 +8,7 @@
 import pytest
 from requests import RequestException, Timeout

-from sap.aibus.dar.client.exceptions import DARHTTPException
+from sap.aibus.dar.client.exceptions import DARHTTPException, InvalidWorkerCount
 from sap.aibus.dar.client.inference_client import InferenceClient
 from tests.sap.aibus.dar.client.test_data_manager_client import (
     AbstractDARClientConstruction,
@@ -180,7 +180,11 @@ def _assert_bulk_inference_works(
            retry_kwarg["retry"] = retry_flag

         response = inference_client.do_bulk_inference(
-            model_name="test-model", objects=many_objects, top_n=4, **retry_kwarg
+            model_name="test-model",
+            objects=many_objects,
+            top_n=4,
+            worker_count=1,  # Disable concurrency to make tests deterministic.
+            **retry_kwarg,
         )

         # The return value is the concatenation of all 'predictions' of the individual
@@ -348,6 +352,7 @@ def test_bulk_inference_error_no_object_ids(
             model_name="test-model",
             objects=inference_objects,
             top_n=4,
+            worker_count=1,  # disable concurrency to make tests deterministic
         )
         expected_error_response = {
             "objectId": None,
@@ -368,3 +373,33 @@ def test_bulk_inference_error_no_object_ids(
         )

         assert response == expected_response
+
+    def test_worker_count_validation(self, inference_client: InferenceClient):
+
+        many_objects = [self.objects()[0] for _ in range(75)]
+
+        with pytest.raises(InvalidWorkerCount) as context:
+            inference_client.do_bulk_inference(
+                model_name="test-model", objects=many_objects, worker_count=5
+            )
+        assert "worker_count too high: 5. Up to 4 allowed." in str(context.value)
+
+        with pytest.raises(InvalidWorkerCount) as context:
+            inference_client.do_bulk_inference(
+                model_name="test-model", objects=many_objects, worker_count=0
+            )
+        assert "worker_count must be greater than 0" in str(context.value)
+
+        with pytest.raises(InvalidWorkerCount) as context:
+            inference_client.do_bulk_inference(
+                model_name="test-model", objects=many_objects, worker_count=-1
+            )
+        assert "worker_count must be greater than 0" in str(context.value)
+
+        with pytest.raises(InvalidWorkerCount) as context:
+            inference_client.do_bulk_inference(
+                model_name="test-model",
+                objects=many_objects,
+                worker_count=None,
+            )
+        assert "worker_count cannot be None" in str(context.value)
