Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ default_language_version:
python: python3.7
repos:
- repo: https://github.com/ambv/black
rev: 21.6b0
rev: 22.3.0
hooks:
- id: black
language_version: python3.7
Expand Down
2 changes: 2 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ stages:

language: python
install:
# Fix build on 3.7: https://github.com/pypa/setuptools/issues/3293
- pip3 install 'setuptools==60.9.0;python_version=="3.7"'
- pip3 install -r requirements-test.txt
script: tox

Expand Down
16 changes: 16 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.12.0]

### Changed

* `InferenceClient.do_bulk_inference` is now faster due to processing requests in parallel [#128]
* `InferenceClient.do_bulk_inference` is more resilient and handles errors internally. Instead of
raising an Exception if the inference request to the service fails, the `do_bulk_inference` method
will place a special error response object in the returned list. This can be considered a breaking API change,
because the special error response object will have a value of `None` for the `labels` key.
As this project is still versioned below 1.0.0, the breaking API change does not warrant a major version update.
See [#128] for details.

[#128]: https://github.com/SAP/data-attribute-recommendation-python-sdk/pull/128

## [0.11.0]

### Added

* Support for reading training jobs using model name in `read_job_by_model_name`

[#124]: https://github.com/SAP/data-attribute-recommendation-python-sdk/pull/124
Expand Down
100 changes: 84 additions & 16 deletions sap/aibus/dar/client/inference_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""
Client API for the Inference microservice.
"""
from typing import List
from concurrent.futures import ThreadPoolExecutor
from typing import List, Union

from requests import RequestException

from sap.aibus.dar.client.base_client import BaseClientWithSession
from sap.aibus.dar.client.exceptions import DARHTTPException
from sap.aibus.dar.client.inference_constants import InferencePaths
from sap.aibus.dar.client.util.lists import split_list

Expand Down Expand Up @@ -73,30 +77,52 @@ def do_bulk_inference(
objects: List[dict],
top_n: int = TOP_N,
retry: bool = True,
) -> List[dict]:
) -> List[Union[dict, None]]:
"""
Performs bulk inference for larger collections.

For *objects* collections larger than *LIMIT_OBJECTS_PER_CALL*, splits
the data into several smaller Inference requests.

Requests are executed in parallel.

Returns the aggregated values of the *predictions* of the original API response
as returned by :meth:`create_inference_request`.
as returned by :meth:`create_inference_request`. If one of the inference
requests to the service fails, an artificial prediction object is inserted with
the `labels` key set to `None` for each of the objects in the failing request.

Example of a prediction object which indicates an error:

.. code-block:: python

{
'objectId': 'b5cbcb34-7ab9-4da5-b7ec-654c90757eb9',
'labels': None,
'_sdk_error': 'RequestException: Request Error'
}

In case the `objects` passed to this method do not contain the `objectId` field,
the value is set to `None` in the error prediction object:

.. code-block:: python

{
'objectId': None,
'labels': None,
'_sdk_error': 'RequestException: Request Error'
}


.. note::

This method calls the inference endpoint multiple times to process all data.
For non-trial service instances, each call will incur a cost.

If one of the calls fails, this method will raise an Exception and the
progress will be lost. In this case, all calls until the Exception happened
will be charged.

To reduce the likelihood of a failed request terminating the bulk inference
process, this method will retry failed requests.
To reduce the impact of a failed request, this method will retry failed
requests.

There is a small chance that even retried requests will be charged, e.g.
if a problem occurs with the request on the client side outside of the
if a problem occurs with the request on the client side outside the
control of the service and after the service has processed the request.
To disable `retry` behavior simply pass `retry=False` to the method.

Expand All @@ -107,20 +133,62 @@ def do_bulk_inference(
The default for the `retry` parameter changed from `retry=False` to
`retry=True` for increased reliability in day-to-day operations.

.. versionchanged:: 0.12.0
Requests are now executed in parallel with up to four threads.

Errors are now handled in this method instead of raising an exception and
discarding inference results from previous requests. For objects where the
inference request did not succeed, a replacement `dict` object is placed in
the returned `list`.
This `dict` follows the format of the `ObjectPrediction` object sent by the
service. To indicate that this is a client-side generated placeholder, the
`labels` key for all ObjectPrediction dicts of the failed inference request
has the value `None`.
A `_sdk_error` key is added with the Exception details.


:param model_name: name of the model used for inference
:param objects: Objects to be classified
:param top_n: How many predictions to return per object
:param retry: whether to retry on errors. Default: True
:return: the aggregated ObjectPrediction dictionaries
"""
result = [] # type: List[dict]
for work_package in split_list(objects, LIMIT_OBJECTS_PER_CALL):
response = self.create_inference_request(
model_name, work_package, top_n=top_n, retry=retry

def predict_call(work_package):
try:
response = self.create_inference_request(
model_name, work_package, top_n=top_n, retry=retry
)
return response["predictions"]
except (DARHTTPException, RequestException) as exc:
self.log.warning(
"Caught %s during bulk inference. "
"Setting results to None for this batch!",
exc,
exc_info=True,
)

prediction_error = [
{
"objectId": inference_object.get("objectId", None),
"labels": None,
"_sdk_error": "{}: {}".format(exc.__class__.__name__, str(exc)),
}
for inference_object in work_package
]
return prediction_error

results = []

with ThreadPoolExecutor(max_workers=4) as pool:
results_iterator = pool.map(
predict_call, split_list(objects, LIMIT_OBJECTS_PER_CALL)
)
result.extend(response["predictions"])
return result

for predictions in results_iterator:
results.extend(predictions)

return results

def create_inference_request_with_url(
self,
Expand Down
2 changes: 1 addition & 1 deletion sap/aibus/dar/client/model_manager_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def create_job(
:param model_template_id: Model template ID for training
:param business_blueprint_id: Business Blueprint template ID for training
:raises CreateTrainingJobFailed: When business_blueprint_id
and model_template_id are provided or when both are not provided
and model_template_id are provided or when both are not provided
:return: newly created Job as dict
"""
self.log.info(
Expand Down
29 changes: 17 additions & 12 deletions tests/sap/aibus/dar/client/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,28 @@

url = "http://localhost:4321/test/"

correlation_id = "412d84ae-0eb5-4421-863d-956570c2da54"
vcap_request_id = "d9cd7dec-4d74-4a7a-a953-4ca583c8d912"


def create_mock_response_404():
mock_response = create_mock_response()

mock_response.headers["X-Correlation-ID"] = correlation_id
mock_response.headers["X-Vcap-Request-Id"] = vcap_request_id
mock_response.headers["Server"] = "Gunicorn"
mock_response.headers["X-Cf-Routererror"] = "unknown_route"
mock_response.status_code = 404
mock_response.request.method = "GET"
mock_response.reason = b"\xc4\xd6\xdc Not Found"
return mock_response


class TestDARHTTPException:
url = "http://localhost:4321/test/"

def test_basic(self):
mock_response = create_mock_response()

correlation_id = "412d84ae-0eb5-4421-863d-956570c2da54"
mock_response.headers["X-Correlation-ID"] = correlation_id
vcap_request_id = "d9cd7dec-4d74-4a7a-a953-4ca583c8d912"
mock_response.headers["X-Vcap-Request-Id"] = vcap_request_id
mock_response.headers["Server"] = "Gunicorn"
mock_response.headers["X-Cf-Routererror"] = "unknown_route"
mock_response.status_code = 404
mock_response.request.method = "GET"
mock_response.reason = b"\xc4\xd6\xdc Not Found"
mock_response = create_mock_response_404()

exception = DARHTTPException.create_from_response(url, mock_response)

Expand Down Expand Up @@ -130,7 +136,6 @@ class TestDARHTTPExceptionReason:
# status line: https://tools.ietf.org/html/rfc7230#section-3.1.2

def test_reason_works_iso8859_1(self):

mock_response = create_mock_response()
# ÄÖÜ encoded as ISO-8859-1
mock_response.reason = b"\xc4\xd6\xdc"
Expand Down
Loading