11"""
22Client API for the Inference microservice.
33"""
4- from typing import List
4+ from concurrent .futures import ThreadPoolExecutor
5+ from typing import List , Union
6+
7+ from requests import RequestException
58
69from sap .aibus .dar .client .base_client import BaseClientWithSession
10+ from sap .aibus .dar .client .exceptions import DARHTTPException , InvalidWorkerCount
711from sap .aibus .dar .client .inference_constants import InferencePaths
812from sap .aibus .dar .client .util .lists import split_list
913
1317#: How many labels to predict for a single object by default
1418TOP_N = 1
1519
20+ # pylint: disable=too-many-arguments
21+
1622
1723class InferenceClient (BaseClientWithSession ):
1824 """
@@ -73,30 +79,53 @@ def do_bulk_inference(
         objects: List[dict],
         top_n: int = TOP_N,
         retry: bool = True,
-    ) -> List[dict]:
+        worker_count: int = 4,
+    ) -> List[Union[dict, None]]:
         """
         Performs bulk inference for larger collections.
 
         For *objects* collections larger than *LIMIT_OBJECTS_PER_CALL*, splits
         the data into several smaller Inference requests.
 
+        Requests are executed in parallel.
+
         Returns the aggregated values of the *predictions* of the original API response
-        as returned by :meth:`create_inference_request`.
+        as returned by :meth:`create_inference_request`. If one of the inference
+        requests to the service fails, an artificial prediction object is inserted with
+        the `labels` key set to `None` for each of the objects in the failing request.
+
+        Example of a prediction object which indicates an error:
+
+        .. code-block:: python
+
+            {
+                'objectId': 'b5cbcb34-7ab9-4da5-b7ec-654c90757eb9',
+                'labels': None,
+                '_sdk_error': 'RequestException: Request Error'
+            }
+
+        If the `objects` passed to this method do not contain the `objectId` field,
+        the value is set to `None` in the error prediction object:
+
+        .. code-block:: python
+
+            {
+                'objectId': None,
+                'labels': None,
+                '_sdk_error': 'RequestException: Request Error'
+            }
+
 
         .. note::
 
             This method calls the inference endpoint multiple times to process all data.
             For non-trial service instances, each call will incur a cost.
 
-            If one of the calls fails, this method will raise an Exception and the
-            progress will be lost. In this case, all calls until the Exception happened
-            will be charged.
-
-            To reduce the likelihood of a failed request terminating the bulk inference
-            process, this method will retry failed requests.
+            To reduce the impact of a failed request, this method will retry failed
+            requests.
 
             There is a small chance that even retried requests will be charged, e.g.
-            if a problem occurs with the request on the client side outside of the
+            if a problem occurs with the request on the client side outside the
             control of the service and after the service has processed the request.
             To disable `retry` behavior simply pass `retry=False` to the method.
 
@@ -107,20 +136,80 @@ def do_bulk_inference(
             The default for the `retry` parameter changed from `retry=False` to
             `retry=True` for increased reliability in day-to-day operations.
 
+        .. versionchanged:: 0.12.0
+            Requests are now executed in parallel with up to four threads.
+
+            Errors are now handled in this method instead of raising an exception and
+            discarding inference results from previous requests. For objects where the
+            inference request did not succeed, a replacement `dict` object is placed in
+            the returned `list`.
+            This `dict` follows the format of the `ObjectPrediction` object sent by the
+            service. To indicate that this is a client-side generated placeholder, the
+            `labels` key for all ObjectPrediction dicts of the failed inference request
+            has the value `None`.
+            A `_sdk_error` key is added with the Exception details.
+
+        .. versionadded:: 0.12.0
+            The `worker_count` parameter allows fine-tuning the number of concurrent
+            request threads. Set `worker_count` to `1` to disable concurrent execution
+            of requests.
+
 
         :param model_name: name of the model used for inference
         :param objects: Objects to be classified
         :param top_n: How many predictions to return per object
         :param retry: whether to retry on errors. Default: True
+        :param worker_count: maximum number of concurrent requests
+        :raises: InvalidWorkerCount if the worker_count parameter is invalid
         :return: the aggregated ObjectPrediction dictionaries
         """
-        result = []  # type: List[dict]
-        for work_package in split_list(objects, LIMIT_OBJECTS_PER_CALL):
-            response = self.create_inference_request(
-                model_name, work_package, top_n=top_n, retry=retry
+
+        if worker_count is None:
+            raise InvalidWorkerCount("worker_count cannot be None!")
+
+        if worker_count > 4:
+            msg = "worker_count too high: %s. Up to 4 allowed." % worker_count
+            raise InvalidWorkerCount(msg)
+
+        if worker_count <= 0:
+            msg = "worker_count must be greater than 0!"
+            raise InvalidWorkerCount(msg)
+
+        def predict_call(work_package):
+            try:
+                response = self.create_inference_request(
+                    model_name, work_package, top_n=top_n, retry=retry
+                )
+                return response["predictions"]
+            except (DARHTTPException, RequestException) as exc:
+                self.log.warning(
+                    "Caught %s during bulk inference. "
+                    "Setting results to None for this batch!",
+                    exc,
+                    exc_info=True,
+                )
+
+                prediction_error = [
+                    {
+                        "objectId": inference_object.get("objectId", None),
+                        "labels": None,
+                        "_sdk_error": "{}: {}".format(exc.__class__.__name__, str(exc)),
+                    }
+                    for inference_object in work_package
+                ]
+                return prediction_error
+
+        results = []
+
+        with ThreadPoolExecutor(max_workers=worker_count) as pool:
+            results_iterator = pool.map(
+                predict_call, split_list(objects, LIMIT_OBJECTS_PER_CALL)
             )
-            result.extend(response["predictions"])
-        return result
+
+            for predictions in results_iterator:
+                results.extend(predictions)
+
+        return results
 
     def create_inference_request_with_url(
         self,
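
For reference, a minimal usage sketch of the changed method (not part of this commit): it assumes an already configured `InferenceClient` instance, and the model name and object payloads are hypothetical placeholders whose exact schema depends on the model. It shows how a caller can separate real predictions from the client-side error placeholders described in the docstring above.

.. code-block:: python

    # Hypothetical, pre-configured client instance (construction omitted).
    inference_client = ...  # type: InferenceClient

    # Illustrative input objects; the payload shape is only an example.
    objects = [
        {"objectId": "1", "features": [{"name": "description", "value": "Hammer"}]},
        {"objectId": "2", "features": [{"name": "description", "value": "Nails"}]},
    ]

    # worker_count may be 1 to 4; worker_count=1 disables concurrent requests.
    predictions = inference_client.do_bulk_inference(
        model_name="my-model",  # hypothetical model name
        objects=objects,
        worker_count=2,
    )

    # Placeholder entries for failed batches have labels=None plus "_sdk_error".
    succeeded = [p for p in predictions if p["labels"] is not None]
    failed = [p for p in predictions if p["labels"] is None]
    for entry in failed:
        print(entry["objectId"], entry["_sdk_error"])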