Skip to content

Commit 7941f4c

Browse files
committed
Improve resilience and performance of do_bulk_inference
In case of errors, the `InferenceClient.do_bulk_inference` method will now return `None` for the affected objects instead of aborting the entire bulk inference operation (and discarding any successfully processed objects). Fixes issue #68. The fix for #68 is different from what is described in #68: instead of using a generator-based approach, which would require the SDK consumer to implement the error handling themselves, the SDK itself now handles the errors. The downside of not using a generator is a larger memory footprint, because the results are accumulated in a list. As an alternative, we can consider using a generator to either yield the successfully processed inference results or the list containing `None`. This approach would save memory. Additionally, this commit introduces parallel processing in `InferenceClient.do_bulk_inference`. This will greatly improve performance. Due to the non-lazy implementation of `ThreadPoolExecutor.map`, this increases memory usage slightly ([cpython issue #74028]). [cpython issue #74028]: python/cpython#74028
1 parent 018da32 commit 7941f4c

File tree

4 files changed

+89
-26
lines changed

4 files changed

+89
-26
lines changed

Diff for: .pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ default_language_version:
22
python: python3.7
33
repos:
44
- repo: https://github.com/ambv/black
5-
rev: 21.6b0
5+
rev: 22.3.0
66
hooks:
77
- id: black
88
language_version: python3.7

Diff for: sap/aibus/dar/client/inference_client.py

+33-13
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
"""
22
Client API for the Inference microservice.
33
"""
4-
from typing import List
4+
from concurrent.futures import ThreadPoolExecutor
5+
from typing import List, Union
6+
7+
from requests import RequestException
58

69
from sap.aibus.dar.client.base_client import BaseClientWithSession
10+
from sap.aibus.dar.client.exceptions import DARHTTPException
711
from sap.aibus.dar.client.inference_constants import InferencePaths
812
from sap.aibus.dar.client.util.lists import split_list
913

@@ -73,7 +77,7 @@ def do_bulk_inference(
7377
objects: List[dict],
7478
top_n: int = TOP_N,
7579
retry: bool = True,
76-
) -> List[dict]:
80+
) -> List[Union[dict, None]]:
7781
"""
7882
Performs bulk inference for larger collections.
7983
@@ -88,15 +92,11 @@ def do_bulk_inference(
8892
This method calls the inference endpoint multiple times to process all data.
8993
For non-trial service instances, each call will incur a cost.
9094
91-
If one of the calls fails, this method will raise an Exception and the
92-
progress will be lost. In this case, all calls until the Exception happened
93-
will be charged.
94-
9595
To reduce the likelihood of a failed request terminating the bulk inference
9696
process, this method will retry failed requests.
9797
9898
There is a small chance that even retried requests will be charged, e.g.
99-
if a problem occurs with the request on the client side outside of the
99+
if a problem occurs with the request on the client side outside the
100100
control of the service and after the service has processed the request.
101101
To disable `retry` behavior simply pass `retry=False` to the method.
102102
@@ -114,10 +114,30 @@ def do_bulk_inference(
114114
:param retry: whether to retry on errors. Default: True
115115
:return: the aggregated ObjectPrediction dictionaries
116116
"""
117-
result = [] # type: List[dict]
118-
for work_package in split_list(objects, LIMIT_OBJECTS_PER_CALL):
119-
response = self.create_inference_request(
120-
model_name, work_package, top_n=top_n, retry=retry
117+
118+
def predict_call(work_package):
119+
try:
120+
response = self.create_inference_request(
121+
model_name, work_package, top_n=top_n, retry=retry
122+
)
123+
return response["predictions"]
124+
except (DARHTTPException, RequestException) as exc:
125+
self.log.warning(
126+
"Caught %s during bulk inference. "
127+
"Setting results to None for this batch!",
128+
exc,
129+
exc_info=True,
130+
)
131+
return [None for _ in range(len(work_package))]
132+
133+
results = []
134+
135+
with ThreadPoolExecutor(max_workers=4) as pool:
136+
results_iterator = pool.map(
137+
predict_call, split_list(objects, LIMIT_OBJECTS_PER_CALL)
121138
)
122-
result.extend(response["predictions"])
123-
return result
139+
140+
for predictions in results_iterator:
141+
results.extend(predictions)
142+
143+
return results

Diff for: tests/sap/aibus/dar/client/test_exceptions.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,28 @@
88

99
url = "http://localhost:4321/test/"
1010

11+
correlation_id = "412d84ae-0eb5-4421-863d-956570c2da54"
12+
vcap_request_id = "d9cd7dec-4d74-4a7a-a953-4ca583c8d912"
13+
14+
15+
def create_mock_response_404():
16+
mock_response = create_mock_response()
17+
18+
mock_response.headers["X-Correlation-ID"] = correlation_id
19+
mock_response.headers["X-Vcap-Request-Id"] = vcap_request_id
20+
mock_response.headers["Server"] = "Gunicorn"
21+
mock_response.headers["X-Cf-Routererror"] = "unknown_route"
22+
mock_response.status_code = 404
23+
mock_response.request.method = "GET"
24+
mock_response.reason = b"\xc4\xd6\xdc Not Found"
25+
return mock_response
26+
1127

1228
class TestDARHTTPException:
1329
url = "http://localhost:4321/test/"
1430

1531
def test_basic(self):
16-
mock_response = create_mock_response()
17-
18-
correlation_id = "412d84ae-0eb5-4421-863d-956570c2da54"
19-
mock_response.headers["X-Correlation-ID"] = correlation_id
20-
vcap_request_id = "d9cd7dec-4d74-4a7a-a953-4ca583c8d912"
21-
mock_response.headers["X-Vcap-Request-Id"] = vcap_request_id
22-
mock_response.headers["Server"] = "Gunicorn"
23-
mock_response.headers["X-Cf-Routererror"] = "unknown_route"
24-
mock_response.status_code = 404
25-
mock_response.request.method = "GET"
26-
mock_response.reason = b"\xc4\xd6\xdc Not Found"
32+
mock_response = create_mock_response_404()
2733

2834
exception = DARHTTPException.create_from_response(url, mock_response)
2935

@@ -130,7 +136,6 @@ class TestDARHTTPExceptionReason:
130136
# status line: https://tools.ietf.org/html/rfc7230#section-3.1.2
131137

132138
def test_reason_works_iso8859_1(self):
133-
134139
mock_response = create_mock_response()
135140
# ÄÖÜ encoded as ISO-8859-1
136141
mock_response.reason = b"\xc4\xd6\xdc"

Diff for: tests/sap/aibus/dar/client/test_inference_client.py

+38
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
from unittest.mock import call
88

99
import pytest
10+
from requests import RequestException, Timeout
1011

12+
from sap.aibus.dar.client.exceptions import DARHTTPException
1113
from sap.aibus.dar.client.inference_client import InferenceClient
1214
from tests.sap.aibus.dar.client.test_data_manager_client import (
1315
AbstractDARClientConstruction,
1416
prepare_client,
1517
)
18+
from tests.sap.aibus.dar.client.test_exceptions import create_mock_response_404
1619

1720

1821
class TestInferenceClientConstruction(AbstractDARClientConstruction):
@@ -203,3 +206,38 @@ def _assert_bulk_inference_works(
203206
inference_client.session.post_to_endpoint.call_args_list
204207
== expected_calls_to_post
205208
)
209+
210+
def test_bulk_inference_error(self, inference_client: InferenceClient):
211+
"""
212+
Tests if do_bulk_inference method will recover from errors.
213+
"""
214+
215+
response_404 = create_mock_response_404()
216+
url = "http://localhost:4321/test/"
217+
218+
exception_404 = DARHTTPException.create_from_response(url, response_404)
219+
220+
exceptions = [exception_404, RequestException, Timeout]
221+
# Try different exceptions
222+
for exc in exceptions:
223+
inference_client.session.post_to_endpoint.return_value.json.side_effect = [
224+
self.inference_response(50),
225+
exc,
226+
self.inference_response(40),
227+
]
228+
229+
many_objects = [self.objects[0] for _ in range(50 + 50 + 40)]
230+
assert len(many_objects) == 50 + 50 + 40
231+
232+
response = inference_client.do_bulk_inference(
233+
model_name="test-model",
234+
objects=many_objects,
235+
top_n=4,
236+
)
237+
238+
expected_response = []
239+
expected_response.extend(self.inference_response(50)["predictions"])
240+
expected_response.extend(None for _ in range(50))
241+
expected_response.extend(self.inference_response(40)["predictions"])
242+
243+
assert response == expected_response

0 commit comments

Comments
 (0)