Skip to content

Commit 68cf17a

Browse files
authored
Merge pull request #128 from SAP/bulk_inference_resilience
Improve resilience and performance of do_bulk_inference
2 parents 2c59234 + 5bbf4a2 commit 68cf17a

File tree

12 files changed

+338
-54
lines changed

12 files changed

+338
-54
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ default_language_version:
22
python: python3.7
33
repos:
44
- repo: https://github.com/ambv/black
5-
rev: 21.6b0
5+
rev: 22.3.0
66
hooks:
77
- id: black
88
language_version: python3.7

.travis.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ stages:
4343

4444
language: python
4545
install:
46+
# Fix build on 3.7: https://github.com/pypa/setuptools/issues/3293
47+
- pip3 install 'setuptools==60.9.0;python_version=="3.7"'
4648
- pip3 install -r requirements-test.txt
4749
script: tox
4850

CHANGELOG.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [Unreleased]
88

9+
## [0.12.0]
10+
11+
### Changed
12+
13+
* `InferenceClient.do_bulk_inference` is now faster due to processing requests in parallel [#128]
14+
* `InferenceClient.do_bulk_inference` is more resilient and handles errors internally. Instead of
15+
raising an Exception if the inference request to the service fails, the `do_bulk_inference` method
16+
will place a special error response object in the returned list. This can be considered a breaking API change,
17+
because the special error response object will have a value of `None` for the `labels` key.
18+
As this project is still versioned below 1.0.0, the breaking API change does not warrant a major version update.
19+
See [#128] for details.
20+
21+
[#128]: https://github.com/SAP/data-attribute-recommendation-python-sdk/pull/128
22+
923
## [0.11.0]
1024

25+
### Added
26+
1127
* Support for reading training jobs using model name in `read_job_by_model_name`
1228

1329
[#124]: https://github.com/SAP/data-attribute-recommendation-python-sdk/pull/124

docs/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# docs
2-
sphinx==2.4.1
3-
sphinx_rtd_theme==0.5.0
2+
sphinx==5.0.2
3+
sphinx_rtd_theme==1.0.0

sap/aibus/dar/client/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,14 @@ class JobNotFound(DARException):
123123
pass
124124

125125

126+
class InvalidWorkerCount(DARException):
127+
"""
128+
Raised when an invalid worker_count parameter is specified.
129+
130+
.. versionadded:: 0.12.0
131+
"""
132+
133+
126134
class ModelAlreadyExists(DARException):
127135
"""
128136
Model already exists and must be deleted first.

sap/aibus/dar/client/inference_client.py

Lines changed: 105 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
"""
22
Client API for the Inference microservice.
33
"""
4-
from typing import List
4+
from concurrent.futures import ThreadPoolExecutor
5+
from typing import List, Union
6+
7+
from requests import RequestException
58

69
from sap.aibus.dar.client.base_client import BaseClientWithSession
10+
from sap.aibus.dar.client.exceptions import DARHTTPException, InvalidWorkerCount
711
from sap.aibus.dar.client.inference_constants import InferencePaths
812
from sap.aibus.dar.client.util.lists import split_list
913

@@ -13,6 +17,8 @@
1317
#: How many labels to predict for a single object by default
1418
TOP_N = 1
1519

20+
# pylint: disable=too-many-arguments
21+
1622

1723
class InferenceClient(BaseClientWithSession):
1824
"""
@@ -73,30 +79,53 @@ def do_bulk_inference(
7379
objects: List[dict],
7480
top_n: int = TOP_N,
7581
retry: bool = True,
76-
) -> List[dict]:
82+
worker_count: int = 4,
83+
) -> List[Union[dict, None]]:
7784
"""
7885
Performs bulk inference for larger collections.
7986
8087
For *objects* collections larger than *LIMIT_OBJECTS_PER_CALL*, splits
8188
the data into several smaller Inference requests.
8289
90+
Requests are executed in parallel.
91+
8392
Returns the aggregated values of the *predictions* of the original API response
84-
as returned by :meth:`create_inference_request`.
93+
as returned by :meth:`create_inference_request`. If one of the inference
94+
requests to the service fails, an artificial prediction object is inserted with
95+
the `labels` key set to `None` for each of the objects in the failing request.
96+
97+
Example of a prediction object which indicates an error:
98+
99+
.. code-block:: python
100+
101+
{
102+
'objectId': 'b5cbcb34-7ab9-4da5-b7ec-654c90757eb9',
103+
'labels': None,
104+
'_sdk_error': 'RequestException: Request Error'
105+
}
106+
107+
In case the `objects` passed to this method do not contain the `objectId` field,
108+
the value is set to `None` in the error prediction object:
109+
110+
.. code-block:: python
111+
112+
{
113+
'objectId': None,
114+
'labels': None,
115+
'_sdk_error': 'RequestException: Request Error'
116+
}
117+
85118
86119
.. note::
87120
88121
This method calls the inference endpoint multiple times to process all data.
89122
For non-trial service instances, each call will incur a cost.
90123
91-
If one of the calls fails, this method will raise an Exception and the
92-
progress will be lost. In this case, all calls until the Exception happened
93-
will be charged.
94-
95-
To reduce the likelihood of a failed request terminating the bulk inference
96-
process, this method will retry failed requests.
124+
To reduce the impact of a failed request, this method will retry failed
125+
requests.
97126
98127
There is a small chance that even retried requests will be charged, e.g.
99-
if a problem occurs with the request on the client side outside of the
128+
if a problem occurs with the request on the client side outside the
100129
control of the service and after the service has processed the request.
101130
To disable `retry` behavior simply pass `retry=False` to the method.
102131
@@ -107,20 +136,80 @@ def do_bulk_inference(
107136
The default for the `retry` parameter changed from `retry=False` to
108137
`retry=True` for increased reliability in day-to-day operations.
109138
139+
.. versionchanged:: 0.12.0
140+
Requests are now executed in parallel with up to four threads.
141+
142+
Errors are now handled in this method instead of raising an exception and
143+
discarding inference results from previous requests. For objects where the
144+
inference request did not succeed, a replacement `dict` object is placed in
145+
the returned `list`.
146+
This `dict` follows the format of the `ObjectPrediction` object sent by the
147+
service. To indicate that this is a client-side generated placeholder, the
148+
`labels` key for all ObjectPrediction dicts of the failed inference request
149+
has value `None`.
150+
A `_sdk_error` key is added with the Exception details.
151+
152+
.. versionadded:: 0.12.0
153+
The `worker_count` parameter allows fine-tuning the number of concurrent
154+
request threads. Set `worker_count` to `1` to disable concurrent execution of
155+
requests.
156+
110157
111158
:param model_name: name of the model used for inference
112159
:param objects: Objects to be classified
113160
:param top_n: How many predictions to return per object
114161
:param retry: whether to retry on errors. Default: True
162+
:param worker_count: maximum number of concurrent requests
163+
:raises: InvalidWorkerCount if the worker_count parameter is invalid
115164
:return: the aggregated ObjectPrediction dictionaries
116165
"""
117-
result = [] # type: List[dict]
118-
for work_package in split_list(objects, LIMIT_OBJECTS_PER_CALL):
119-
response = self.create_inference_request(
120-
model_name, work_package, top_n=top_n, retry=retry
166+
167+
if worker_count is None:
168+
raise InvalidWorkerCount("worker_count cannot be None!")
169+
170+
if worker_count > 4:
171+
msg = "worker_count too high: %s. Up to 4 allowed." % worker_count
172+
raise InvalidWorkerCount(msg)
173+
174+
if worker_count <= 0:
175+
msg = "worker_count must be greater than 0!"
176+
raise InvalidWorkerCount(msg)
177+
178+
def predict_call(work_package):
179+
try:
180+
response = self.create_inference_request(
181+
model_name, work_package, top_n=top_n, retry=retry
182+
)
183+
return response["predictions"]
184+
except (DARHTTPException, RequestException) as exc:
185+
self.log.warning(
186+
"Caught %s during bulk inference. "
187+
"Setting results to None for this batch!",
188+
exc,
189+
exc_info=True,
190+
)
191+
192+
prediction_error = [
193+
{
194+
"objectId": inference_object.get("objectId", None),
195+
"labels": None,
196+
"_sdk_error": "{}: {}".format(exc.__class__.__name__, str(exc)),
197+
}
198+
for inference_object in work_package
199+
]
200+
return prediction_error
201+
202+
results = []
203+
204+
with ThreadPoolExecutor(max_workers=worker_count) as pool:
205+
results_iterator = pool.map(
206+
predict_call, split_list(objects, LIMIT_OBJECTS_PER_CALL)
121207
)
122-
result.extend(response["predictions"])
123-
return result
208+
209+
for predictions in results_iterator:
210+
results.extend(predictions)
211+
212+
return results
124213

125214
def create_inference_request_with_url(
126215
self,

sap/aibus/dar/client/model_manager_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def create_job(
174174
:param model_template_id: Model template ID for training
175175
:param business_blueprint_id: Business Blueprint template ID for training
176176
:raises CreateTrainingJobFailed: When business_blueprint_id
177-
and model_template_id are provided or when both are not provided
177+
and model_template_id are provided or when both are not provided
178178
:return: newly created Job as dict
179179
"""
180180
self.log.info(

system_tests/workflow/test_end_to_end.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,12 +249,19 @@ def _assert_inference_works(self, inference_client, model_name):
249249
# One object has been classified
250250
assert len(response["predictions"]) == 1
251251

252+
# do_bulk_inference with concurrency
252253
big_to_be_classified = [to_be_classified[0] for _ in range(123)]
253254
response = inference_client.do_bulk_inference(
254255
model_name=model_name, objects=big_to_be_classified
255256
)
256257
assert len(response) == 123
257258

259+
# do_bulk_inference without concurrency
260+
response = inference_client.do_bulk_inference(
261+
model_name=model_name, objects=big_to_be_classified, worker_count=1
262+
)
263+
assert len(response) == 123
264+
258265
url = os.environ["DAR_URL"]
259266
if url[-1] == "/":
260267
url = url[:-1]

tests/sap/aibus/dar/client/test_exceptions.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,38 @@
11
import datetime
22
from unittest.mock import PropertyMock
33

4-
from sap.aibus.dar.client.exceptions import DARHTTPException, ModelAlreadyExists
4+
from sap.aibus.dar.client.exceptions import (
5+
DARHTTPException,
6+
ModelAlreadyExists,
7+
)
58
from tests.sap.aibus.dar.client.test_dar_session import create_mock_response
69

710
# TODO: test __str__
811

912
url = "http://localhost:4321/test/"
1013

14+
correlation_id = "412d84ae-0eb5-4421-863d-956570c2da54"
15+
vcap_request_id = "d9cd7dec-4d74-4a7a-a953-4ca583c8d912"
16+
17+
18+
def create_mock_response_404():
19+
mock_response = create_mock_response()
20+
21+
mock_response.headers["X-Correlation-ID"] = correlation_id
22+
mock_response.headers["X-Vcap-Request-Id"] = vcap_request_id
23+
mock_response.headers["Server"] = "Gunicorn"
24+
mock_response.headers["X-Cf-Routererror"] = "unknown_route"
25+
mock_response.status_code = 404
26+
mock_response.request.method = "GET"
27+
mock_response.reason = b"\xc4\xd6\xdc Not Found"
28+
return mock_response
29+
1130

1231
class TestDARHTTPException:
1332
url = "http://localhost:4321/test/"
1433

1534
def test_basic(self):
16-
mock_response = create_mock_response()
17-
18-
correlation_id = "412d84ae-0eb5-4421-863d-956570c2da54"
19-
mock_response.headers["X-Correlation-ID"] = correlation_id
20-
vcap_request_id = "d9cd7dec-4d74-4a7a-a953-4ca583c8d912"
21-
mock_response.headers["X-Vcap-Request-Id"] = vcap_request_id
22-
mock_response.headers["Server"] = "Gunicorn"
23-
mock_response.headers["X-Cf-Routererror"] = "unknown_route"
24-
mock_response.status_code = 404
25-
mock_response.request.method = "GET"
26-
mock_response.reason = b"\xc4\xd6\xdc Not Found"
35+
mock_response = create_mock_response_404()
2736

2837
exception = DARHTTPException.create_from_response(url, mock_response)
2938

@@ -130,7 +139,6 @@ class TestDARHTTPExceptionReason:
130139
# status line: https://tools.ietf.org/html/rfc7230#section-3.1.2
131140

132141
def test_reason_works_iso8859_1(self):
133-
134142
mock_response = create_mock_response()
135143
# ÄÖÜ encoded as ISO-8859-1
136144
mock_response.reason = b"\xc4\xd6\xdc"

0 commit comments

Comments
 (0)