diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a00f8a7..b68f51d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -2,7 +2,7 @@ default_language_version:
   python: python3.7
 repos:
   - repo: https://github.com/ambv/black
-    rev: 21.6b0
+    rev: 22.3.0
     hooks:
       - id: black
         language_version: python3.7
diff --git a/.travis.yml b/.travis.yml
index c60b84e..13a52d5 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -43,6 +43,8 @@ stages:
       language: python
       install:
+        # Fix build on 3.7: https://github.com/pypa/setuptools/issues/3293
+        - pip3 install 'setuptools==60.9.0;python_version=="3.7"'
         - pip3 install -r requirements-test.txt
       script: tox
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a07f582..46e63c8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,8 +6,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [0.12.0]
+
+### Changed
+
+* `InferenceClient.do_bulk_inference` is now faster due to processing requests in parallel [#128]
+* `InferenceClient.do_bulk_inference` is more resilient and handles errors internally. Instead of
+  raising an Exception if the inference request to the service fails, the `do_bulk_inference` method
+  will place a special error response object in the returned list. This can be considered a breaking API change,
+  because the special error response object will have a value of `None` for the `labels` key.
+  As this project is still versioned below 1.0.0, the breaking API change does not warrant a major version update.
+  See [#128] for details.
+
+[#128]: https://github.com/SAP/data-attribute-recommendation-python-sdk/pull/128
+
 ## [0.11.0]
 
+### Added
+
 * Support for reading training jobs using model name in `read_job_by_model_name`
 
 [#124]: https://github.com/SAP/data-attribute-recommendation-python-sdk/pull/124
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 05b0310..4fdec42 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,3 +1,3 @@
 # docs
-sphinx==2.4.1
-sphinx_rtd_theme==0.5.0
+sphinx==5.0.2
+sphinx_rtd_theme==1.0.0
diff --git a/sap/aibus/dar/client/exceptions.py b/sap/aibus/dar/client/exceptions.py
index 8541efb..ccf6b47 100644
--- a/sap/aibus/dar/client/exceptions.py
+++ b/sap/aibus/dar/client/exceptions.py
@@ -123,6 +123,14 @@ class JobNotFound(DARException):
     pass
 
 
+class InvalidWorkerCount(DARException):
+    """
+    Invalid worker_count parameter is specified.
+
+    .. versionadded:: 0.12.0
+    """
+
+
 class ModelAlreadyExists(DARException):
     """
     Model already exists and must be deleted first.
diff --git a/sap/aibus/dar/client/inference_client.py b/sap/aibus/dar/client/inference_client.py
index 9ca67b9..3f8c218 100644
--- a/sap/aibus/dar/client/inference_client.py
+++ b/sap/aibus/dar/client/inference_client.py
@@ -1,9 +1,13 @@
 """
 Client API for the Inference microservice.
""" -from typing import List +from concurrent.futures import ThreadPoolExecutor +from typing import List, Union + +from requests import RequestException from sap.aibus.dar.client.base_client import BaseClientWithSession +from sap.aibus.dar.client.exceptions import DARHTTPException, InvalidWorkerCount from sap.aibus.dar.client.inference_constants import InferencePaths from sap.aibus.dar.client.util.lists import split_list @@ -13,6 +17,8 @@ #: How many labels to predict for a single object by default TOP_N = 1 +# pylint: disable=too-many-arguments + class InferenceClient(BaseClientWithSession): """ @@ -73,30 +79,53 @@ def do_bulk_inference( objects: List[dict], top_n: int = TOP_N, retry: bool = True, - ) -> List[dict]: + worker_count: int = 4, + ) -> List[Union[dict, None]]: """ Performs bulk inference for larger collections. For *objects* collections larger than *LIMIT_OBJECTS_PER_CALL*, splits the data into several smaller Inference requests. + Requests are executed in parallel. + Returns the aggregated values of the *predictions* of the original API response - as returned by :meth:`create_inference_request`. + as returned by :meth:`create_inference_request`. If one of the inference + requests to the service fails, an artificial prediction object is inserted with + the `labels` key set to `None` for each of the objects in the failing request. + + Example of a prediction object which indicates an error: + + .. code-block:: python + + { + 'objectId': 'b5cbcb34-7ab9-4da5-b7ec-654c90757eb9', + 'labels': None, + '_sdk_error': 'RequestException: Request Error' + } + + In case the `objects` passed to this method do not contain the `objectId` field, + the value is set to `None` in the error prediction object: + + .. code-block:: python + + { + 'objectId': None, + 'labels': None, + '_sdk_error': 'RequestException: Request Error' + } + .. note:: This method calls the inference endpoint multiple times to process all data. For non-trial service instances, each call will incur a cost. - If one of the calls fails, this method will raise an Exception and the - progress will be lost. In this case, all calls until the Exception happened - will be charged. - - To reduce the likelihood of a failed request terminating the bulk inference - process, this method will retry failed requests. + To reduce the impact of a failed request, this method will retry failed + requests. There is a small chance that even retried requests will be charged, e.g. - if a problem occurs with the request on the client side outside of the + if a problem occurs with the request on the client side outside the control of the service and after the service has processed the request. To disable `retry` behavior simply pass `retry=False` to the method. @@ -107,20 +136,80 @@ def do_bulk_inference( The default for the `retry` parameter changed from `retry=False` to `retry=True` for increased reliability in day-to-day operations. + .. versionchanged:: 0.12.0 + Requests are now executed in parallel with up to four threads. + + Errors are now handled in this method instead of raising an exception and + discarding inference results from previous requests. For objects where the + inference request did not succeed, a replacement `dict` object is placed in + the returned `list`. + This `dict` follows the format of the `ObjectPrediction` object sent by the + service. To indicate that this is a client-side generated placeholder, the + `labels` key for all ObjectPrediction dicts of the failed inference request + has value `None`. 
+            A `_sdk_error` key is added with the Exception details.
+
+        .. versionadded:: 0.12.0
+            The `worker_count` parameter allows fine-tuning the number of concurrent
+            request threads. Set `worker_count` to `1` to disable concurrent execution of
+            requests.
+
         :param model_name: name of the model used for inference
         :param objects: Objects to be classified
         :param top_n: How many predictions to return per object
         :param retry: whether to retry on errors. Default: True
+        :param worker_count: maximum number of concurrent requests
+        :raises: InvalidWorkerCount if the worker_count parameter is invalid
         :return: the aggregated ObjectPrediction dictionaries
         """
-        result = []  # type: List[dict]
-        for work_package in split_list(objects, LIMIT_OBJECTS_PER_CALL):
-            response = self.create_inference_request(
-                model_name, work_package, top_n=top_n, retry=retry
+
+        if worker_count is None:
+            raise InvalidWorkerCount("worker_count cannot be None!")
+
+        if worker_count > 4:
+            msg = "worker_count too high: %s. Up to 4 allowed." % worker_count
+            raise InvalidWorkerCount(msg)
+
+        if worker_count <= 0:
+            msg = "worker_count must be greater than 0!"
+            raise InvalidWorkerCount(msg)
+
+        def predict_call(work_package):
+            try:
+                response = self.create_inference_request(
+                    model_name, work_package, top_n=top_n, retry=retry
+                )
+                return response["predictions"]
+            except (DARHTTPException, RequestException) as exc:
+                self.log.warning(
+                    "Caught %s during bulk inference. "
+                    "Setting results to None for this batch!",
+                    exc,
+                    exc_info=True,
+                )
+
+                prediction_error = [
+                    {
+                        "objectId": inference_object.get("objectId", None),
+                        "labels": None,
+                        "_sdk_error": "{}: {}".format(exc.__class__.__name__, str(exc)),
+                    }
+                    for inference_object in work_package
+                ]
+                return prediction_error
+
+        results = []
+
+        with ThreadPoolExecutor(max_workers=worker_count) as pool:
+            results_iterator = pool.map(
+                predict_call, split_list(objects, LIMIT_OBJECTS_PER_CALL)
             )
-            result.extend(response["predictions"])
-        return result
+
+            for predictions in results_iterator:
+                results.extend(predictions)
+
+        return results
 
     def create_inference_request_with_url(
         self,
diff --git a/sap/aibus/dar/client/model_manager_client.py b/sap/aibus/dar/client/model_manager_client.py
index ca94e38..fd7113c 100644
--- a/sap/aibus/dar/client/model_manager_client.py
+++ b/sap/aibus/dar/client/model_manager_client.py
@@ -174,7 +174,7 @@ def create_job(
         :param model_template_id: Model template ID for training
         :param business_blueprint_id: Business Blueprint template ID for training
         :raises CreateTrainingJobFailed: When business_blueprint_id
-        and model_template_id are provided or when both are not provided
+            and model_template_id are both provided or when neither is provided
         :return: newly created Job as dict
         """
         self.log.info(
diff --git a/system_tests/workflow/test_end_to_end.py b/system_tests/workflow/test_end_to_end.py
index e8b3d39..28829b4 100644
--- a/system_tests/workflow/test_end_to_end.py
+++ b/system_tests/workflow/test_end_to_end.py
@@ -249,12 +249,19 @@ def _assert_inference_works(self, inference_client, model_name):
         # One object has been classified
         assert len(response["predictions"]) == 1
 
+        # do_bulk_inference with concurrency
         big_to_be_classified = [to_be_classified[0] for _ in range(123)]
         response = inference_client.do_bulk_inference(
             model_name=model_name, objects=big_to_be_classified
         )
         assert len(response) == 123
 
+        # do_bulk_inference without concurrency
+        response = inference_client.do_bulk_inference(
+            model_name=model_name, objects=big_to_be_classified, worker_count=1
+        )
+        assert len(response) == 123
+
         url = os.environ["DAR_URL"]
         if url[-1] == "/":
             url = url[:-1]
diff --git a/tests/sap/aibus/dar/client/test_exceptions.py b/tests/sap/aibus/dar/client/test_exceptions.py
index 4d0e81c..7bf17a8 100644
--- a/tests/sap/aibus/dar/client/test_exceptions.py
+++ b/tests/sap/aibus/dar/client/test_exceptions.py
@@ -1,29 +1,38 @@
 import datetime
 from unittest.mock import PropertyMock
 
-from sap.aibus.dar.client.exceptions import DARHTTPException, ModelAlreadyExists
+from sap.aibus.dar.client.exceptions import (
+    DARHTTPException,
+    ModelAlreadyExists,
+)
 from tests.sap.aibus.dar.client.test_dar_session import create_mock_response
 
 # TODO: test __str__
 
 url = "http://localhost:4321/test/"
 
+correlation_id = "412d84ae-0eb5-4421-863d-956570c2da54"
+vcap_request_id = "d9cd7dec-4d74-4a7a-a953-4ca583c8d912"
+
+
+def create_mock_response_404():
+    mock_response = create_mock_response()
+
+    mock_response.headers["X-Correlation-ID"] = correlation_id
+    mock_response.headers["X-Vcap-Request-Id"] = vcap_request_id
+    mock_response.headers["Server"] = "Gunicorn"
+    mock_response.headers["X-Cf-Routererror"] = "unknown_route"
+    mock_response.status_code = 404
+    mock_response.request.method = "GET"
+    mock_response.reason = b"\xc4\xd6\xdc Not Found"
+    return mock_response
+
 
 class TestDARHTTPException:
     url = "http://localhost:4321/test/"
 
     def test_basic(self):
-        mock_response = create_mock_response()
-
-        correlation_id = "412d84ae-0eb5-4421-863d-956570c2da54"
-        mock_response.headers["X-Correlation-ID"] = correlation_id
-        vcap_request_id = "d9cd7dec-4d74-4a7a-a953-4ca583c8d912"
-        mock_response.headers["X-Vcap-Request-Id"] = vcap_request_id
-        mock_response.headers["Server"] = "Gunicorn"
-        mock_response.headers["X-Cf-Routererror"] = "unknown_route"
-        mock_response.status_code = 404
-        mock_response.request.method = "GET"
-        mock_response.reason = b"\xc4\xd6\xdc Not Found"
+        mock_response = create_mock_response_404()
 
         exception = DARHTTPException.create_from_response(url, mock_response)
 
@@ -130,7 +139,6 @@ class TestDARHTTPExceptionReason:
     # status line: https://tools.ietf.org/html/rfc7230#section-3.1.2
 
     def test_reason_works_iso8859_1(self):
-
         mock_response = create_mock_response()
         # ÄÖÜ encoded as ISO-8859-1
         mock_response.reason = b"\xc4\xd6\xdc"
diff --git a/tests/sap/aibus/dar/client/test_inference_client.py b/tests/sap/aibus/dar/client/test_inference_client.py
index 44c308f..1577b6e 100644
--- a/tests/sap/aibus/dar/client/test_inference_client.py
+++ b/tests/sap/aibus/dar/client/test_inference_client.py
@@ -2,17 +2,19 @@
 # The pragma above causes mypy to ignore this file:
 # mypy cannot deal with some of the monkey-patching we do below.
 # https://github.com/python/mypy/issues/2427
-
-
-from unittest.mock import call
+from typing import Optional
+from unittest.mock import call, Mock
 
 import pytest
+from requests import RequestException, Timeout
 
+from sap.aibus.dar.client.exceptions import DARHTTPException, InvalidWorkerCount
 from sap.aibus.dar.client.inference_client import InferenceClient
 from tests.sap.aibus.dar.client.test_data_manager_client import (
     AbstractDARClientConstruction,
     prepare_client,
 )
+from tests.sap.aibus.dar.client.test_exceptions import create_mock_response_404
 
 
 class TestInferenceClientConstruction(AbstractDARClientConstruction):
@@ -28,20 +30,24 @@ def inference_client():
 
 
 class TestInferenceClient:
-    @property
-    def objects(self):
+    def objects(
+        self, object_id: Optional[str] = "b5cbcb34-7ab9-4da5-b7ec-654c90757eb9"
+    ):
         """
         Returns sample Objects used as classification inputs.
         """
         return [
             {
-                "objectId": "b5cbcb34-7ab9-4da5-b7ec-654c90757eb9",
+                "objectId": object_id,
                 "features": [{"name": "manufacturer", "value": "ACME"}],
             }
         ]
 
     @staticmethod
-    def inference_response(prediction_count):
+    def inference_response(
+        prediction_count,
+        object_id: Optional[str] = "b5cbcb34-7ab9-4da5-b7ec-654c90757eb9",
+    ):
         """
         Returns a sample InferenceResponseSchema with the given number of
         predictions.
@@ -52,7 +58,7 @@
             "processedTime": "2018-08-31T11:45:54.727934+00:00",
             "predictions": [
                 {
-                    "objectId": "b5cbcb34-7ab9-4da5-b7ec-654c90757eb9",
+                    "objectId": object_id,
                     "labels": [{"name": "category", "value": "ANVIL"}],
                 }
                 for _ in range(prediction_count)
@@ -64,12 +70,12 @@ def test_create_inference_request(self, inference_client: InferenceClient):
         """
         Checks inference call.
         """
         response = inference_client.create_inference_request(
-            "my-model", objects=self.objects
+            "my-model", objects=self.objects()
         )
 
         expected_call = call(
             "/inference/api/v3/models/my-model/versions/1",
-            payload={"topN": 1, "objects": self.objects},
+            payload={"topN": 1, "objects": self.objects()},
             retry=False,
         )
@@ -89,11 +95,11 @@ def test_create_inference_request_with_top_n(
         Checks if top_n parameter is passed correctly.
         """
         response = inference_client.create_inference_request(
-            "a-test-model", objects=self.objects, top_n=99, retry=False
+            "a-test-model", objects=self.objects(), top_n=99, retry=False
         )
 
         expected_call = call(
             "/inference/api/v3/models/a-test-model/versions/1",
-            payload={"topN": 99, "objects": self.objects},
+            payload={"topN": 99, "objects": self.objects()},
             retry=False,
         )
@@ -113,12 +119,12 @@ def test_create_inference_request_with_retry_enabled(
         Checks if retry parameter is passsed correctly.
         """
         response = inference_client.create_inference_request(
-            "my-model", objects=self.objects, retry=True
+            "my-model", objects=self.objects(), retry=True
         )
 
         expected_call = call(
             "/inference/api/v3/models/my-model/versions/1",
-            payload={"topN": 1, "objects": self.objects},
+            payload={"topN": 1, "objects": self.objects()},
             retry=True,
         )
@@ -159,7 +165,7 @@ def _assert_bulk_inference_works(
         # passed to InferenceClient.do_bulk_inference - the default is assumed to be
         # False and the internal calls to Inference.create_inference_request will
         # be checked for this.
-        many_objects = [self.objects[0] for _ in range(75)]
+        many_objects = [self.objects()[0] for _ in range(75)]
         assert len(many_objects) == 75
 
         # On first call, return response with 50 predictions. On second call,
@@ -174,7 +180,11 @@ def _assert_bulk_inference_works(
             retry_kwarg["retry"] = retry_flag
 
         response = inference_client.do_bulk_inference(
-            model_name="test-model", objects=many_objects, top_n=4, **retry_kwarg
+            model_name="test-model",
+            objects=many_objects,
+            top_n=4,
+            worker_count=1,  # Disable concurrency to make tests deterministic.
+            **retry_kwarg,
         )
 
         # The return value is the concatenation of all 'predictions' of the individual
@@ -210,12 +220,12 @@ def test_create_inference_with_url_works(self, inference_client: InferenceClient
         """
         url = DAR_URL + "inference/api/v3/models/my-model/versions/1"
         response = inference_client.create_inference_request_with_url(
-            url, objects=self.objects
+            url, objects=self.objects()
         )
 
         expected_call = call(
             url,
-            payload={"topN": 1, "objects": self.objects},
+            payload={"topN": 1, "objects": self.objects()},
             retry=False,
         )
@@ -235,12 +245,12 @@ def test_create_inference_request_with_url_retry_enabled(
 
         url = DAR_URL + "inference/api/v3/models/my-model/versions/1"
         response = inference_client.create_inference_request_with_url(
-            url=url, objects=self.objects, retry=True
+            url=url, objects=self.objects(), retry=True
         )
 
         expected_call = call(
             url,
-            payload={"topN": 1, "objects": self.objects},
+            payload={"topN": 1, "objects": self.objects()},
             retry=True,
         )
@@ -250,3 +260,146 @@ def test_create_inference_request_with_url_retry_enabled(
             inference_client.session.post_to_url.return_value.json.return_value
             == response
         )
+
+    def test_bulk_inference_error(self, inference_client: InferenceClient):
+        """
+        Tests if do_bulk_inference method will recover from errors.
+        """
+
+        response_404 = create_mock_response_404()
+        url = "http://localhost:4321/test/"
+
+        exception_404 = DARHTTPException.create_from_response(url, response_404)
+
+        # The old trick to return different values in a Mock based on the call order
+        # does not work here because the code is concurrent. Instead, we use a different
+        # objectId for those objects where we want the request to fail.
+        def make_mock_post(exc):
+            def post_to_endpoint(*args, **kwargs):
+                payload = kwargs.pop("payload")
+                object_id = payload["objects"][0]["objectId"]
+                if object_id == "expected-to-fail":
+                    raise exc
+                elif object_id == "b5cbcb34-7ab9-4da5-b7ec-654c90757eb9":
+                    response = Mock()
+                    response.json.return_value = self.inference_response(
+                        len(payload["objects"])
+                    )
+                    return response
+                else:
+                    raise ValueError("objectId '%s' not handled in test." % object_id)
+
+            return post_to_endpoint
+
+        # Try different exceptions
+        exceptions = [
+            exception_404,
+            RequestException("Request Error"),
+            Timeout("Timeout"),
+        ]
+        for exc in exceptions:
+            inference_client.session.post_to_endpoint.side_effect = make_mock_post(exc)
+
+            many_objects = []
+            many_objects.extend([self.objects()[0] for _ in range(50)])
+            many_objects.extend(
+                [self.objects(object_id="expected-to-fail")[0] for _ in range(50)]
+            )
+            many_objects.extend([self.objects()[0] for _ in range(40)])
+            assert len(many_objects) == 50 + 50 + 40
+
+            response = inference_client.do_bulk_inference(
+                model_name="test-model",
+                objects=many_objects,
+                top_n=4,
+            )
+
+            expected_error_response = {
+                "objectId": "expected-to-fail",
+                "labels": None,
+                # If this assertion fails, pytest/PyCharm can appear to hang because
+                # difflib takes too much time computing the diff.
+ "_sdk_error": "{}: {}".format(exc.__class__.__name__, str(exc)), + } + + expected_response = [] + expected_response.extend(self.inference_response(50)["predictions"]) + expected_response.extend(expected_error_response for _ in range(50)) + expected_response.extend(self.inference_response(40)["predictions"]) + + assert len(response) == len(expected_response) + assert response == expected_response + + def test_bulk_inference_error_no_object_ids( + self, inference_client: InferenceClient + ): + response_404 = create_mock_response_404() + url = "http://localhost:4321/test/" + + exception_404 = DARHTTPException.create_from_response(url, response_404) + + inference_client.session.post_to_endpoint.return_value.json.side_effect = [ + self.inference_response(50, object_id=None), + exception_404, + self.inference_response(22, object_id=None), + ] + + inference_objects = [ + self.objects(object_id=None)[0] for _ in range(50 + 50 + 22) + ] + + response = inference_client.do_bulk_inference( + model_name="test-model", + objects=inference_objects, + top_n=4, + worker_count=1, # disable concurrency to make tests deterministic + ) + expected_error_response = { + "objectId": None, + "labels": None, + # If this test fails, I found it can make pytest/PyCharm hang because it + # takes too much time in difflib. + "_sdk_error": "{}: {}".format( + exception_404.__class__.__name__, str(exception_404) + ), + } + expected_response = [] + expected_response.extend( + self.inference_response(50, object_id=None)["predictions"] + ) + expected_response.extend(expected_error_response for _ in range(50)) + expected_response.extend( + self.inference_response(22, object_id=None)["predictions"] + ) + + assert response == expected_response + + def test_worker_count_validation(self, inference_client: InferenceClient): + + many_objects = [self.objects()[0] for _ in range(75)] + + with pytest.raises(InvalidWorkerCount) as context: + inference_client.do_bulk_inference( + model_name="test-model", objects=many_objects, worker_count=5 + ) + assert "worker_count too high: 5. Up to 4 allowed." in str(context.value) + + with pytest.raises(InvalidWorkerCount) as context: + inference_client.do_bulk_inference( + model_name="test-model", objects=many_objects, worker_count=0 + ) + assert "worker_count must be greater than 0" in str(context.value) + + with pytest.raises(InvalidWorkerCount) as context: + inference_client.do_bulk_inference( + model_name="test-model", objects=many_objects, worker_count=-1 + ) + assert "worker_count must be greater than 0" in str(context.value) + + with pytest.raises(InvalidWorkerCount) as context: + inference_client.do_bulk_inference( + model_name="test-model", + objects=many_objects, + worker_count=None, + ) + assert "worker_count cannot be None" in str(context.value) diff --git a/tox.ini b/tox.ini index 576459c..7d7d920 100644 --- a/tox.ini +++ b/tox.ini @@ -30,6 +30,7 @@ commands = -o junit_suite_name={envname} \ -o console_output_style=classic \ -o junit_family=xunit2 \ + -vv \ tests/ \ sap/ cov: coveralls diff --git a/version.txt b/version.txt index d9df1bb..ac454c6 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.11.0 +0.12.0