inference_client.py
"""
Client API for the Inference microservice.
"""
from concurrent.futures import ThreadPoolExecutor
from typing import List, Union
from requests import RequestException
from sap.aibus.dar.client.base_client import BaseClientWithSession
from sap.aibus.dar.client.exceptions import DARHTTPException, InvalidWorkerCount
from sap.aibus.dar.client.inference_constants import InferencePaths
from sap.aibus.dar.client.util.lists import split_list
#: How many objects can be processed per inference request
LIMIT_OBJECTS_PER_CALL = 50
#: How many labels to predict for a single object by default
TOP_N = 1
# pylint: disable=too-many-arguments
class InferenceClient(BaseClientWithSession):
"""
A client for the DAR Inference microservice.
This class implements all basic API calls as well as some convenience methods
which wrap individual API calls.
If an API call fails, all methods will raise a :exc:`DARHTTPException`.
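A minimal usage sketch, assuming ``inference_client`` is an already constructed
instance of this class, ``"my-model"`` names a deployed model and ``my_objects``
is a list of objects to classify (all placeholders for illustration):

.. code-block:: python

    from sap.aibus.dar.client.exceptions import DARHTTPException

    try:
        response = inference_client.create_inference_request(
            "my-model", objects=my_objects
        )
        print(response["predictions"])
    except DARHTTPException as exc:
        print("Inference request failed:", exc)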
"""
def create_inference_request(
self,
model_name: str,
objects: List[dict],
top_n: int = TOP_N,
retry: bool = False,
) -> dict:
"""
Performs inference for the given *objects* with *model_name*.
For each object in *objects*, returns the *top_n* best predictions.
The *retry* parameter determines whether to retry on HTTP errors indicated by
the remote API endpoint or for other connection problems. See :ref:`retry` for
trade-offs involved here.
.. note::
The endpoint called by this method has a limit of *LIMIT_OBJECTS_PER_CALL*
on the number of *objects*. See :meth:`do_bulk_inference` to circumvent
this limit.
:param model_name: name of the model used for inference
:param objects: Objects to be classified
:param top_n: How many predictions to return per object
:param retry: whether to retry on errors. Default: False
:return: API response
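A minimal usage sketch, assuming ``inference_client`` is a constructed instance
of this class. The model name, ``objectId`` and feature names are placeholders,
and the object layout shown here (a ``features`` list of name/value pairs) is
illustrative; it must match the dataset schema the model was trained on:

.. code-block:: python

    response = inference_client.create_inference_request(
        model_name="my-model",
        objects=[
            {
                "objectId": "order-1",
                "features": [
                    # Feature names and values depend on your dataset schema.
                    {"name": "description", "value": "Ergonomic office chair"},
                ],
            }
        ],
        top_n=3,
    )
    for prediction in response["predictions"]:
        print(prediction["objectId"], prediction["labels"])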
"""
self.log.debug(
"Submitting Inference request for model '%s' with '%s'"
" objects and top_n '%s' ",
model_name,
len(objects),
top_n,
)
endpoint = InferencePaths.format_inference_endpoint_by_name(model_name)
response = self.session.post_to_endpoint(
endpoint, payload={"topN": top_n, "objects": objects}, retry=retry
)
as_json = response.json()
self.log.debug("Inference response ID: %s", as_json["id"])
return as_json
def do_bulk_inference(
self,
model_name: str,
objects: List[dict],
top_n: int = TOP_N,
retry: bool = True,
worker_count: int = 4,
) -> List[Union[dict, None]]:
"""
Performs bulk inference for larger collections.
For *objects* collections larger than *LIMIT_OBJECTS_PER_CALL*, splits
the data into several smaller Inference requests.
Requests are executed in parallel.
Returns the aggregated values of the *predictions* of the original API response
as returned by :meth:`create_inference_request`. If one of the inference
requests to the service fails, an artificial prediction object is inserted with
the `labels` key set to `None` for each of the objects in the failing request.
Example of a prediction object which indicates an error:
.. code-block:: python
{
'objectId': 'b5cbcb34-7ab9-4da5-b7ec-654c90757eb9',
'labels': None,
'_sdk_error': 'RequestException: Request Error'
}
If the `objects` passed to this method do not contain an `objectId` field,
the value is set to `None` in the error prediction object:
.. code-block:: python
{
'objectId': None,
'labels': None,
'_sdk_error': 'RequestException: Request Error'
}
.. note::
This method calls the inference endpoint multiple times to process all data.
For non-trial service instances, each call will incur a cost.
To reduce the impact of a failed request, this method will retry failed
requests.
There is a small chance that even retried requests will be charged, e.g.
if a problem occurs with the request on the client side outside the
control of the service and after the service has processed the request.
To disable the `retry` behavior, simply pass `retry=False` to the method.
Typically, the default of `retry=True` is safe and greatly improves the
reliability of bulk inference.
.. versionchanged:: 0.7.0
The default for the `retry` parameter changed from `retry=False` to
`retry=True` for increased reliability in day-to-day operations.
.. versionchanged:: 0.12.0
Requests are now executed in parallel with up to four threads.
Errors are now handled in this method instead of raising an exception and
discarding inference results from previous requests. For objects where the
inference request did not succeed, a replacement `dict` object is placed in
the returned `list`.
This `dict` follows the format of the `ObjectPrediction` object sent by the
service. To indicate that this is a client-side generated placeholder, the
`labels` key for all ObjectPrediction dicts of the failed inference request
has value `None`.
A `_sdk_error` key is added with the Exception details.
.. versionadded:: 0.12.0
The `worker_count` parameter allows fine-tuning the number of concurrent
request threads. Set `worker_count` to `1` to disable concurrent execution of
requests.
:param model_name: name of the model used for inference
:param objects: Objects to be classified
:param top_n: How many predictions to return per object
:param retry: whether to retry on errors. Default: True
:param worker_count: maximum number of concurrent requests
:raises: InvalidWorkerCount if the worker_count parameter is invalid
:return: the aggregated ObjectPrediction dictionaries
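A usage sketch, assuming ``inference_client`` is a constructed instance of this
class and ``my_objects`` is a list of objects which may be larger than
*LIMIT_OBJECTS_PER_CALL* (all names are placeholders):

.. code-block:: python

    predictions = inference_client.do_bulk_inference(
        model_name="my-model",
        objects=my_objects,
        top_n=3,
        worker_count=2,  # at most 4 concurrent requests are allowed
    )
    # Failed chunks are returned as placeholder predictions with labels=None.
    failed = [p for p in predictions if p["labels"] is None]
    if failed:
        print("%d objects could not be classified" % len(failed))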
"""
if worker_count is None:
raise InvalidWorkerCount("worker_count cannot be None!")
if worker_count > 4:
msg = "worker_count too high: %s. Up to 4 allowed." % worker_count
raise InvalidWorkerCount(msg)
if worker_count <= 0:
msg = "worker_count must be greater than 0!"
raise InvalidWorkerCount(msg)
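# Worker function: classifies a single chunk of *objects* and converts request
# failures into placeholder predictions (labels=None, _sdk_error set).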
def predict_call(work_package):
try:
response = self.create_inference_request(
model_name, work_package, top_n=top_n, retry=retry
)
return response["predictions"]
except (DARHTTPException, RequestException) as exc:
self.log.warning(
"Caught %s during bulk inference. "
"Setting results to None for this batch!",
exc,
exc_info=True,
)
prediction_error = [
{
"objectId": inference_object.get("objectId", None),
"labels": None,
"_sdk_error": "{}: {}".format(exc.__class__.__name__, str(exc)),
}
for inference_object in work_package
]
return prediction_error
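# Split *objects* into chunks of LIMIT_OBJECTS_PER_CALL and classify them
# concurrently. pool.map preserves the input order, so the aggregated
# predictions line up with the original *objects*.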
results = []
with ThreadPoolExecutor(max_workers=worker_count) as pool:
results_iterator = pool.map(
predict_call, split_list(objects, LIMIT_OBJECTS_PER_CALL)
)
for predictions in results_iterator:
results.extend(predictions)
return results
def create_inference_request_with_url(
self,
url: str,
objects: List[dict],
top_n: int = TOP_N,
retry: bool = False,
) -> dict:
"""
Performs inference for the given *objects* against a fully-qualified URL.
A complete inference URL can be passed to this method directly instead of
constructing it from the base URL and the model name.
:param url: fully-qualified inference URL
:param objects: Objects to be classified
:param top_n: How many predictions to return per object
:param retry: whether to retry on errors. Default: False
:return: API response
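A usage sketch, assuming ``inference_client`` is a constructed instance of this
class; the URL below is a placeholder for the fully-qualified inference URL of
your service instance, and ``my_objects`` is a list of objects to classify:

.. code-block:: python

    response = inference_client.create_inference_request_with_url(
        url="https://example.com/full/inference/url",
        objects=my_objects,
    )
    print(response["predictions"])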
"""
self.log.debug(
"Submitting Inference request with '%s'"
" objects and top_n '%s' to url %s",
len(objects),
top_n,
url,
)
response = self.session.post_to_url(
url, payload={"topN": top_n, "objects": objects}, retry=retry
)
as_json = response.json()
self.log.debug("Inference response ID: %s", as_json["id"])
return as_json