Improve resilience and performance of do_bulk_inference #128

Merged · 15 commits · Jun 17, 2022

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -2,7 +2,7 @@ default_language_version:
  python: python3.7
repos:
  - repo: https://github.com/ambv/black
    rev: 21.6b0
    rev: 22.3.0
    hooks:
      - id: black
        language_version: python3.7
2 changes: 2 additions & 0 deletions .travis.yml
@@ -43,6 +43,8 @@ stages:

language: python
install:
# Fix build on 3.7: https://github.com/pypa/setuptools/issues/3293
- pip3 install 'setuptools==60.9.0;python_version=="3.7"'
- pip3 install -r requirements-test.txt
script: tox

16 changes: 16 additions & 0 deletions CHANGELOG.md
@@ -6,8 +6,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.12.0]

### Changed

* `InferenceClient.do_bulk_inference` is now faster due to processing requests in parallel [#128]
* `InferenceClient.do_bulk_inference` is more resilient and handles errors internally. Instead of
raising an Exception if the inference request to the service fails, the `do_bulk_inference` method
will place a special error response object in the returned list. This can be considered a breaking API change,
because the special error response object will have a value of `None` for the `labels` key.
As this project is still versioned below 1.0.0, the breaking API change does not warrant a major version update.
See [#128] for details.

[#128]: https://github.com/SAP/data-attribute-recommendation-python-sdk/pull/128

## [0.11.0]

### Added

* Support for reading training jobs using model name in `read_job_by_model_name`

[#124]: https://github.com/SAP/data-attribute-recommendation-python-sdk/pull/124
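
For illustration, here is a minimal sketch of how a caller might separate successful predictions from the client-side error placeholders described in this changelog entry. The `client` instance, the model name, and the `objects` payload are assumptions for the sketch, not taken from this PR:

```python
# Sketch: detect failed batches in a do_bulk_inference result.
# Assumes an authenticated InferenceClient named `client` and a
# deployed model "my-model" (both hypothetical).
predictions = client.do_bulk_inference(model_name="my-model", objects=objects)

# Error placeholders are exactly the entries whose "labels" key is None.
failed = [p for p in predictions if p["labels"] is None]
succeeded = [p for p in predictions if p["labels"] is not None]
```
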
4 changes: 2 additions & 2 deletions docs/requirements.txt
@@ -1,3 +1,3 @@
# docs
sphinx==2.4.1
sphinx_rtd_theme==0.5.0
sphinx==5.0.2
sphinx_rtd_theme==1.0.0
8 changes: 8 additions & 0 deletions sap/aibus/dar/client/exceptions.py
@@ -123,6 +123,14 @@ class JobNotFound(DARException):
    pass


class InvalidWorkerCount(DARException):
    """
    An invalid worker_count parameter was specified.

    .. versionadded:: 0.12.0
    """


class ModelAlreadyExists(DARException):
    """
    Model already exists and must be deleted first.
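
Because the new `InvalidWorkerCount` is raised before any request is sent, a caller can catch it without worrying about partially charged work. A short sketch, assuming a hypothetical `client` and `objects`:

```python
from sap.aibus.dar.client.exceptions import InvalidWorkerCount

try:
    # worker_count is validated up front; only values 1 through 4 are accepted.
    predictions = client.do_bulk_inference(
        model_name="my-model", objects=objects, worker_count=8
    )
except InvalidWorkerCount as exc:
    print("Invalid worker_count:", exc)
```
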
121 changes: 105 additions & 16 deletions sap/aibus/dar/client/inference_client.py
@@ -1,9 +1,13 @@
"""
Client API for the Inference microservice.
"""
from typing import List
from concurrent.futures import ThreadPoolExecutor
from typing import List, Union

from requests import RequestException

from sap.aibus.dar.client.base_client import BaseClientWithSession
from sap.aibus.dar.client.exceptions import DARHTTPException, InvalidWorkerCount
from sap.aibus.dar.client.inference_constants import InferencePaths
from sap.aibus.dar.client.util.lists import split_list

@@ -13,6 +17,8 @@
#: How many labels to predict for a single object by default
TOP_N = 1

# pylint: disable=too-many-arguments


class InferenceClient(BaseClientWithSession):
"""
@@ -73,30 +79,53 @@ def do_bulk_inference(
    objects: List[dict],
    top_n: int = TOP_N,
    retry: bool = True,
) -> List[dict]:
    worker_count: int = 4,
) -> List[Union[dict, None]]:
"""
Performs bulk inference for larger collections.

For *objects* collections larger than *LIMIT_OBJECTS_PER_CALL*, splits
the data into several smaller Inference requests.

Requests are executed in parallel.

Returns the aggregated values of the *predictions* of the original API response
as returned by :meth:`create_inference_request`.
as returned by :meth:`create_inference_request`. If one of the inference
requests to the service fails, an artificial prediction object is inserted with
the `labels` key set to `None` for each of the objects in the failing request.

Example of a prediction object which indicates an error:

.. code-block:: python

{
'objectId': 'b5cbcb34-7ab9-4da5-b7ec-654c90757eb9',
'labels': None,
'_sdk_error': 'RequestException: Request Error'
}

If the `objects` passed to this method do not contain the `objectId` field,
the value is set to `None` in the error prediction object:

.. code-block:: python

{
'objectId': None,
'labels': None,
'_sdk_error': 'RequestException: Request Error'
}


.. note::

This method calls the inference endpoint multiple times to process all data.
For non-trial service instances, each call will incur a cost.

If one of the calls fails, this method will raise an Exception and the
progress will be lost. In this case, all calls until the Exception happened
will be charged.

To reduce the likelihood of a failed request terminating the bulk inference
process, this method will retry failed requests.
To reduce the impact of a failed request, this method will retry failed
requests.

There is a small chance that even retried requests will be charged, e.g.
if a problem occurs with the request on the client side outside of the
if a problem occurs with the request on the client side outside the
control of the service and after the service has processed the request.
To disable the `retry` behavior, simply pass `retry=False` to the method.

Expand All @@ -107,20 +136,80 @@ def do_bulk_inference(
The default for the `retry` parameter changed from `retry=False` to
`retry=True` for increased reliability in day-to-day operations.

.. versionchanged:: 0.12.0
Requests are now executed in parallel with up to four threads.

Errors are now handled in this method instead of raising an exception and
discarding inference results from previous requests. For objects where the
inference request did not succeed, a replacement `dict` object is placed in
the returned `list`.
This `dict` follows the format of the `ObjectPrediction` object sent by the
service. To indicate that it is a client-side generated placeholder, the
`labels` key for all `ObjectPrediction` dicts of the failed inference request
has the value `None`.
A `_sdk_error` key is added with the Exception details.

.. versionadded:: 0.12.0
The `worker_count` parameter allows fine-tuning the number of concurrent
request threads. Set `worker_count` to `1` to disable concurrent execution of
requests.


:param model_name: name of the model used for inference
:param objects: Objects to be classified
:param top_n: How many predictions to return per object
:param retry: whether to retry on errors. Default: True
:param worker_count: maximum number of concurrent requests
:raises: InvalidWorkerCount if the worker_count parameter is invalid
:return: the aggregated ObjectPrediction dictionaries
"""
result = []  # type: List[dict]
for work_package in split_list(objects, LIMIT_OBJECTS_PER_CALL):
    response = self.create_inference_request(
        model_name, work_package, top_n=top_n, retry=retry
    )
    result.extend(response["predictions"])
return result

if worker_count is None:
    raise InvalidWorkerCount("worker_count cannot be None!")

if worker_count > 4:
    msg = "worker_count too high: %s. Up to 4 allowed." % worker_count
    raise InvalidWorkerCount(msg)

if worker_count <= 0:
    msg = "worker_count must be greater than 0!"
    raise InvalidWorkerCount(msg)

def predict_call(work_package):
    try:
        response = self.create_inference_request(
            model_name, work_package, top_n=top_n, retry=retry
        )
        return response["predictions"]
    except (DARHTTPException, RequestException) as exc:
        self.log.warning(
            "Caught %s during bulk inference. "
            "Setting results to None for this batch!",
            exc,
            exc_info=True,
        )

        prediction_error = [
            {
                "objectId": inference_object.get("objectId", None),
                "labels": None,
                "_sdk_error": "{}: {}".format(exc.__class__.__name__, str(exc)),
            }
            for inference_object in work_package
        ]
        return prediction_error

results = []

with ThreadPoolExecutor(max_workers=worker_count) as pool:
    results_iterator = pool.map(
        predict_call, split_list(objects, LIMIT_OBJECTS_PER_CALL)
    )

    for predictions in results_iterator:
        results.extend(predictions)

return results

def create_inference_request_with_url(
self,
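
Taken together, the new implementation validates `worker_count`, splits `objects` into batches of `LIMIT_OBJECTS_PER_CALL`, and fans the batches out over a thread pool. `ThreadPoolExecutor.map` yields batch results in submission order, so the aggregated list stays aligned with the input. A usage sketch under stated assumptions; the model name and the feature payload shape are illustrative, not taken from this PR:

```python
# Sketch: classify 123 objects with up to 4 concurrent requests.
# `client` is assumed to be an authenticated InferenceClient; the exact
# feature format depends on the schema of the deployed model.
objects = [
    {"objectId": str(i), "features": [{"name": "description", "value": "..."}]}
    for i in range(123)
]

predictions = client.do_bulk_inference(
    model_name="my-model",  # hypothetical model name
    objects=objects,
    top_n=3,
    worker_count=4,  # the default; pass worker_count=1 to disable concurrency
)

# pool.map preserves batch order, so results line up with the input objects.
errors = [p for p in predictions if p["labels"] is None]
print("{} results, {} failed".format(len(predictions), len(errors)))
```
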
2 changes: 1 addition & 1 deletion sap/aibus/dar/client/model_manager_client.py
@@ -174,7 +174,7 @@ def create_job(
:param model_template_id: Model template ID for training
:param business_blueprint_id: Business Blueprint template ID for training
:raises CreateTrainingJobFailed: When business_blueprint_id
and model_template_id are provided or when both are not provided
    and model_template_id are provided or when both are not provided
:return: newly created Job as dict
"""
self.log.info(
7 changes: 7 additions & 0 deletions system_tests/workflow/test_end_to_end.py
@@ -249,12 +249,19 @@ def _assert_inference_works(self, inference_client, model_name):
# One object has been classified
assert len(response["predictions"]) == 1

# do_bulk_inference with concurrency
big_to_be_classified = [to_be_classified[0] for _ in range(123)]
response = inference_client.do_bulk_inference(
    model_name=model_name, objects=big_to_be_classified
)
assert len(response) == 123

# do_bulk_inference without concurrency
response = inference_client.do_bulk_inference(
    model_name=model_name, objects=big_to_be_classified, worker_count=1
)
assert len(response) == 123

url = os.environ["DAR_URL"]
if url[-1] == "/":
    url = url[:-1]
34 changes: 21 additions & 13 deletions tests/sap/aibus/dar/client/test_exceptions.py
@@ -1,29 +1,38 @@
import datetime
from unittest.mock import PropertyMock

from sap.aibus.dar.client.exceptions import DARHTTPException, ModelAlreadyExists
from sap.aibus.dar.client.exceptions import (
    DARHTTPException,
    ModelAlreadyExists,
)
from tests.sap.aibus.dar.client.test_dar_session import create_mock_response

# TODO: test __str__

url = "http://localhost:4321/test/"

correlation_id = "412d84ae-0eb5-4421-863d-956570c2da54"
vcap_request_id = "d9cd7dec-4d74-4a7a-a953-4ca583c8d912"


def create_mock_response_404():
    mock_response = create_mock_response()

    mock_response.headers["X-Correlation-ID"] = correlation_id
    mock_response.headers["X-Vcap-Request-Id"] = vcap_request_id
    mock_response.headers["Server"] = "Gunicorn"
    mock_response.headers["X-Cf-Routererror"] = "unknown_route"
    mock_response.status_code = 404
    mock_response.request.method = "GET"
    mock_response.reason = b"\xc4\xd6\xdc Not Found"
    return mock_response


class TestDARHTTPException:
    url = "http://localhost:4321/test/"

    def test_basic(self):
        mock_response = create_mock_response()

        correlation_id = "412d84ae-0eb5-4421-863d-956570c2da54"
        mock_response.headers["X-Correlation-ID"] = correlation_id
        vcap_request_id = "d9cd7dec-4d74-4a7a-a953-4ca583c8d912"
        mock_response.headers["X-Vcap-Request-Id"] = vcap_request_id
        mock_response.headers["Server"] = "Gunicorn"
        mock_response.headers["X-Cf-Routererror"] = "unknown_route"
        mock_response.status_code = 404
        mock_response.request.method = "GET"
        mock_response.reason = b"\xc4\xd6\xdc Not Found"
        mock_response = create_mock_response_404()

        exception = DARHTTPException.create_from_response(url, mock_response)

@@ -130,7 +139,6 @@ class TestDARHTTPExceptionReason:
    # status line: https://tools.ietf.org/html/rfc7230#section-3.1.2

    def test_reason_works_iso8859_1(self):

        mock_response = create_mock_response()
        # ÄÖÜ encoded as ISO-8859-1
        mock_response.reason = b"\xc4\xd6\xdc"