"""
Client API for the Inference microservice.
"""
- from typing import List
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import List, Union
+
+ from requests import RequestException

from sap.aibus.dar.client.base_client import BaseClientWithSession
+ from sap.aibus.dar.client.exceptions import DARHTTPException, InvalidWorkerCount
from sap.aibus.dar.client.inference_constants import InferencePaths
from sap.aibus.dar.client.util.lists import split_list

#: How many labels to predict for a single object by default
TOP_N = 1

+ # pylint: disable=too-many-arguments
+


class InferenceClient(BaseClientWithSession):
    """
@@ -73,30 +79,53 @@ def do_bulk_inference(
        objects: List[dict],
        top_n: int = TOP_N,
        retry: bool = True,
-     ) -> List[dict]:
+         worker_count: int = 4,
+     ) -> List[Union[dict, None]]:
        """
        Performs bulk inference for larger collections.

        For *objects* collections larger than *LIMIT_OBJECTS_PER_CALL*, splits
        the data into several smaller Inference requests.

+         Requests are executed in parallel.
+
        Returns the aggregated values of the *predictions* of the original API response
-         as returned by :meth:`create_inference_request`.
+         as returned by :meth:`create_inference_request`. If one of the inference
+         requests to the service fails, an artificial prediction object is inserted with
+         the `labels` key set to `None` for each of the objects in the failing request.
+
+         Example of a prediction object which indicates an error:
+
+         .. code-block:: python
+
+             {
+                 'objectId': 'b5cbcb34-7ab9-4da5-b7ec-654c90757eb9',
+                 'labels': None,
+                 '_sdk_error': 'RequestException: Request Error'
+             }
+
+         In case the `objects` passed to this method do not contain the `objectId` field,
+         the value is set to `None` in the error prediction object:
+
+         .. code-block:: python
+
+             {
+                 'objectId': None,
+                 'labels': None,
+                 '_sdk_error': 'RequestException: Request Error'
+             }
+
        .. note::

            This method calls the inference endpoint multiple times to process all data.
            For non-trial service instances, each call will incur a cost.

-             If one of the calls fails, this method will raise an Exception and the
-             progress will be lost. In this case, all calls until the Exception happened
-             will be charged.
-
-             To reduce the likelihood of a failed request terminating the bulk inference
-             process, this method will retry failed requests.
+             To reduce the impact of a failed request, this method will retry failed
+             requests.

            There is a small chance that even retried requests will be charged, e.g.
-             if a problem occurs with the request on the client side outside of the
+             if a problem occurs with the request on the client side outside the
            control of the service and after the service has processed the request.
            To disable `retry` behavior simply pass `retry=False` to the method.

@@ -107,20 +136,80 @@ def do_bulk_inference(
            The default for the `retry` parameter changed from `retry=False` to
            `retry=True` for increased reliability in day-to-day operations.

+         .. versionchanged:: 0.12.0
+             Requests are now executed in parallel with up to four threads.
+
+             Errors are now handled in this method instead of raising an exception and
+             discarding inference results from previous requests. For objects where the
+             inference request did not succeed, a replacement `dict` object is placed in
+             the returned `list`.
+             This `dict` follows the format of the `ObjectPrediction` object sent by the
+             service. To indicate that this is a client-side generated placeholder, the
+             `labels` key for all ObjectPrediction dicts of the failed inference request
+             has value `None`.
+             A `_sdk_error` key is added with the Exception details.
+
+         .. versionadded:: 0.12.0
+             The `worker_count` parameter allows fine-tuning the number of concurrent
+             request threads. Set `worker_count` to `1` to disable concurrent execution
+             of requests.
+

        :param model_name: name of the model used for inference
        :param objects: Objects to be classified
        :param top_n: How many predictions to return per object
        :param retry: whether to retry on errors. Default: True
+         :param worker_count: maximum number of concurrent requests
+         :raises: InvalidWorkerCount if worker_count param is incorrect
        :return: the aggregated ObjectPrediction dictionaries
        """
-         result = []  # type: List[dict]
-         for work_package in split_list(objects, LIMIT_OBJECTS_PER_CALL):
-             response = self.create_inference_request(
-                 model_name, work_package, top_n=top_n, retry=retry
+
+         if worker_count is None:
+             raise InvalidWorkerCount("worker_count cannot be None!")
+
+         if worker_count > 4:
+             msg = "worker_count too high: %s. Up to 4 allowed." % worker_count
+             raise InvalidWorkerCount(msg)
+
+         if worker_count <= 0:
+             msg = "worker_count must be greater than 0!"
+             raise InvalidWorkerCount(msg)
+
+         def predict_call(work_package):
+             try:
+                 response = self.create_inference_request(
+                     model_name, work_package, top_n=top_n, retry=retry
+                 )
+                 return response["predictions"]
+             except (DARHTTPException, RequestException) as exc:
+                 self.log.warning(
+                     "Caught %s during bulk inference. "
+                     "Setting results to None for this batch!",
+                     exc,
+                     exc_info=True,
+                 )
+
+                 prediction_error = [
+                     {
+                         "objectId": inference_object.get("objectId", None),
+                         "labels": None,
+                         "_sdk_error": "{}: {}".format(exc.__class__.__name__, str(exc)),
+                     }
+                     for inference_object in work_package
+                 ]
+                 return prediction_error
+
+         results = []
+
+         with ThreadPoolExecutor(max_workers=worker_count) as pool:
+             results_iterator = pool.map(
+                 predict_call, split_list(objects, LIMIT_OBJECTS_PER_CALL)
            )
-             result.extend(response["predictions"])
-         return result
+
+             for predictions in results_iterator:
+                 results.extend(predictions)
+
+         return results

    def create_inference_request_with_url(
        self,
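For reference, a minimal usage sketch of the changed `do_bulk_inference` call. It assumes an already-constructed `InferenceClient` instance named `client`, a deployed model named `"my-model"`, and a list `objects_to_classify`; these three names are illustrative and not part of this change.

.. code-block:: python

    # `client` and `objects_to_classify` are assumed to exist already.
    # Batching happens inside do_bulk_inference; up to two requests run in parallel.
    predictions = client.do_bulk_inference(
        model_name="my-model",
        objects=objects_to_classify,
        top_n=3,
        worker_count=2,
    )

    # Failed batches come back as placeholder dicts with 'labels' set to None
    # and an '_sdk_error' key, so they can be separated from real predictions.
    failed = [p for p in predictions if p["labels"] is None]
    succeeded = [p for p in predictions if p["labels"] is not None]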