
Commit 5bbf4a2
do_bulk_inference: add worker_count parameter
This is mainly useful for fixing the tests, which rely on the mocks being called in a certain order. One of the tests supports concurrency by mocking in a better way, but this was not feasible for the other tests. This commit also updates the documentation build tools to the latest version to fix the documentation build on my local machine.
Parent: f0d4169
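For illustration, a minimal sketch of the new call shape introduced by this commit, assuming an already-constructed InferenceClient instance named inference_client and placeholder object payloads (neither is part of this commit):

    # Sketch only: `inference_client` and the payloads below are assumptions.
    objects = [{"objectId": str(i), "features": []} for i in range(100)]

    # worker_count=1 disables concurrent requests, so mocked calls happen in a
    # deterministic order; values above 4 are rejected.
    predictions = inference_client.do_bulk_inference(
        model_name="my-model",
        objects=objects,
        worker_count=1,
    )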

File tree

6 files changed: +81 -7 lines changed


docs/requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,3 +1,3 @@
 # docs
-sphinx==2.4.1
-sphinx_rtd_theme==0.5.0
+sphinx==5.0.2
+sphinx_rtd_theme==1.0.0

sap/aibus/dar/client/exceptions.py

Lines changed: 8 additions & 0 deletions
@@ -123,6 +123,14 @@ class JobNotFound(DARException):
     pass


+class InvalidWorkerCount(DARException):
+    """
+    Invalid worker_count parameter is specified.
+
+    .. versionadded:: 0.12.0
+    """
+
+
 class ModelAlreadyExists(DARException):
     """
     Model already exists and must be deleted first.
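A hedged sketch of how the new exception surfaces to callers; the client instance, model name, and objects list are placeholders, not from this commit:

    from sap.aibus.dar.client.exceptions import InvalidWorkerCount

    try:
        inference_client.do_bulk_inference(
            model_name="my-model", objects=objects, worker_count=99
        )
    except InvalidWorkerCount as exc:
        # The message states the allowed range, e.g.
        # "worker_count too high: 99. Up to 4 allowed."
        print("rejected:", exc)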

sap/aibus/dar/client/inference_client.py

Lines changed: 23 additions & 2 deletions
@@ -7,7 +7,7 @@
 from requests import RequestException

 from sap.aibus.dar.client.base_client import BaseClientWithSession
-from sap.aibus.dar.client.exceptions import DARHTTPException
+from sap.aibus.dar.client.exceptions import DARHTTPException, InvalidWorkerCount
 from sap.aibus.dar.client.inference_constants import InferencePaths
 from sap.aibus.dar.client.util.lists import split_list

@@ -17,6 +17,8 @@
 #: How many labels to predict for a single object by default
 TOP_N = 1

+# pylint: disable=too-many-arguments
+

 class InferenceClient(BaseClientWithSession):
     """
@@ -77,6 +79,7 @@ def do_bulk_inference(
         objects: List[dict],
         top_n: int = TOP_N,
         retry: bool = True,
+        worker_count: int = 4,
     ) -> List[Union[dict, None]]:
         """
         Performs bulk inference for larger collections.
@@ -146,14 +149,32 @@ def do_bulk_inference(
         has value `None`.
         A `_sdk_error` key is added with the Exception details.

+        .. versionadded:: 0.12.0
+            The `worker_count` parameter allows to fine-tune the number of concurrent
+            request threads. Set `worker_count` to `1` to disable concurrent execution of
+            requests.
+

         :param model_name: name of the model used for inference
         :param objects: Objects to be classified
         :param top_n: How many predictions to return per object
         :param retry: whether to retry on errors. Default: True
+        :param worker_count: maximum number of concurrent requests
+        :raises: InvalidWorkerCount if worker_count param is incorrect
         :return: the aggregated ObjectPrediction dictionaries
         """

+        if worker_count is None:
+            raise InvalidWorkerCount("worker_count cannot be None!")
+
+        if worker_count > 4:
+            msg = "worker_count too high: %s. Up to 4 allowed." % worker_count
+            raise InvalidWorkerCount(msg)
+
+        if worker_count <= 0:
+            msg = "worker_count must be greater than 0!"
+            raise InvalidWorkerCount(msg)
+
         def predict_call(work_package):
             try:
                 response = self.create_inference_request(
@@ -180,7 +201,7 @@ def predict_call(work_package):

         results = []

-        with ThreadPoolExecutor(max_workers=4) as pool:
+        with ThreadPoolExecutor(max_workers=worker_count) as pool:
             results_iterator = pool.map(
                 predict_call, split_list(objects, LIMIT_OBJECTS_PER_CALL)
             )
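To summarize the pattern introduced above, here is a standalone sketch (not the SDK code itself) of how a validated worker_count bounds a ThreadPoolExecutor over chunked work; the chunk_size parameter and handle() stub are illustrative assumptions:

    from concurrent.futures import ThreadPoolExecutor
    from typing import List


    def bulk_apply(items: List[dict], chunk_size: int, worker_count: int = 4) -> list:
        # Mirrors the parameter checks added to do_bulk_inference.
        if worker_count is None:
            raise ValueError("worker_count cannot be None!")
        if worker_count > 4:
            raise ValueError("worker_count too high: %s. Up to 4 allowed." % worker_count)
        if worker_count <= 0:
            raise ValueError("worker_count must be greater than 0!")

        def handle(chunk: List[dict]) -> list:
            # Stand-in for one create_inference_request call per chunk.
            return [{"objectId": obj.get("objectId")} for obj in chunk]

        chunks = [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)]
        results: list = []
        # pool.map yields results in input order, even with multiple workers.
        with ThreadPoolExecutor(max_workers=worker_count) as pool:
            for partial in pool.map(handle, chunks):
                results.extend(partial)
        return results

With worker_count=1 the pool degrades to sequential execution, which is what the updated tests rely on for a deterministic mock call order.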

system_tests/workflow/test_end_to_end.py

Lines changed: 7 additions & 0 deletions
@@ -249,12 +249,19 @@ def _assert_inference_works(self, inference_client, model_name):
         # One object has been classified
         assert len(response["predictions"]) == 1

+        # do_bulk_inference with concurrency
         big_to_be_classified = [to_be_classified[0] for _ in range(123)]
         response = inference_client.do_bulk_inference(
             model_name=model_name, objects=big_to_be_classified
         )
         assert len(response) == 123

+        # do_bulk_inference without concurrency
+        response = inference_client.do_bulk_inference(
+            model_name=model_name, objects=big_to_be_classified, worker_count=1
+        )
+        assert len(response) == 123
+
 url = os.environ["DAR_URL"]
 if url[-1] == "/":
     url = url[:-1]

tests/sap/aibus/dar/client/test_exceptions.py

Lines changed: 4 additions & 1 deletion
@@ -1,7 +1,10 @@
 import datetime
 from unittest.mock import PropertyMock

-from sap.aibus.dar.client.exceptions import DARHTTPException, ModelAlreadyExists
+from sap.aibus.dar.client.exceptions import (
+    DARHTTPException,
+    ModelAlreadyExists,
+)
 from tests.sap.aibus.dar.client.test_dar_session import create_mock_response

 # TODO: test __str__

tests/sap/aibus/dar/client/test_inference_client.py

Lines changed: 37 additions & 2 deletions
@@ -8,7 +8,7 @@
 import pytest
 from requests import RequestException, Timeout

-from sap.aibus.dar.client.exceptions import DARHTTPException
+from sap.aibus.dar.client.exceptions import DARHTTPException, InvalidWorkerCount
 from sap.aibus.dar.client.inference_client import InferenceClient
 from tests.sap.aibus.dar.client.test_data_manager_client import (
     AbstractDARClientConstruction,
@@ -180,7 +180,11 @@ def _assert_bulk_inference_works(
         retry_kwarg["retry"] = retry_flag

         response = inference_client.do_bulk_inference(
-            model_name="test-model", objects=many_objects, top_n=4, **retry_kwarg
+            model_name="test-model",
+            objects=many_objects,
+            top_n=4,
+            worker_count=1,  # Disable concurrency to make tests deterministic.
+            **retry_kwarg,
         )

         # The return value is the concatenation of all 'predictions' of the individual
@@ -348,6 +352,7 @@ def test_bulk_inference_error_no_object_ids(
             model_name="test-model",
             objects=inference_objects,
             top_n=4,
+            worker_count=1,  # disable concurrency to make tests deterministic
         )
         expected_error_response = {
             "objectId": None,
@@ -368,3 +373,33 @@
         )

         assert response == expected_response
+
+    def test_worker_count_validation(self, inference_client: InferenceClient):
+
+        many_objects = [self.objects()[0] for _ in range(75)]
+
+        with pytest.raises(InvalidWorkerCount) as context:
+            inference_client.do_bulk_inference(
+                model_name="test-model", objects=many_objects, worker_count=5
+            )
+        assert "worker_count too high: 5. Up to 4 allowed." in str(context.value)
+
+        with pytest.raises(InvalidWorkerCount) as context:
+            inference_client.do_bulk_inference(
+                model_name="test-model", objects=many_objects, worker_count=0
+            )
+        assert "worker_count must be greater than 0" in str(context.value)
+
+        with pytest.raises(InvalidWorkerCount) as context:
+            inference_client.do_bulk_inference(
+                model_name="test-model", objects=many_objects, worker_count=-1
+            )
+        assert "worker_count must be greater than 0" in str(context.value)
+
+        with pytest.raises(InvalidWorkerCount) as context:
+            inference_client.do_bulk_inference(
+                model_name="test-model",
+                objects=many_objects,
+                worker_count=None,
+            )
+        assert "worker_count cannot be None" in str(context.value)
