v1_Gesture_Recognition_Serving.py
#!/usr/bin/env python3
"""Simple Flask example to serve Keras model."""
import glob
import json
import os
from http import HTTPStatus
from pathlib import Path

import numpy as np
import tensorflow as tf
from flask import Flask
from flask import abort
from flask import request
from flask_restplus import fields
from flask_restplus import Resource, Api, Namespace
from tensorflow import keras

_HERE = Path(__file__).parent
_ACCELEROMETER_DATA_PATH = Path(_HERE, 'data/accelerometer/xo')

_MODEL_DIR = Path(os.getenv("MODEL_DIR", _HERE / "models/hdf5"))
"""Path to model directory."""

_MODEL: keras.Model = None
"""Keras model (lazily loaded on the first prediction request)."""

_SESSION: tf.Session = None
"""TensorFlow session the model was loaded in."""

app = Flask(__name__)

probe_ns = Namespace('probe', description="Health checks.")
model_ns = Namespace('model', description="Model namespace.")

model_input = model_ns.model('Input', {
    'instances': fields.List(
        fields.List(fields.Float),
        required=True,
        description="Model input instances. Tensor of shape (N, 13).",
        example=json.loads(
            Path(_ACCELEROMETER_DATA_PATH, 'examples/example_instance.json').read_text()),
    ),
    'signature': fields.String(
        required=True,
        default="serving_default",
        description="Signature to be returned by the model.",
        example="serving_default")
})
model_ns.add_model('model_input', model_input)


@probe_ns.route('/liveness')
class Liveness(Resource):

    # noinspection PyMethodMayBeStatic
    def get(self):
        """Heartbeat."""
        return {'Status': "Running OK."}, HTTPStatus.OK


@probe_ns.route('/readiness')
class Readiness(Resource):

    # noinspection PyMethodMayBeStatic
    def get(self):
        """Readiness."""
        if _MODEL is not None:
            response = {'Status': "Ready."}, HTTPStatus.OK
        else:
            response = {'Status': "Model has not been loaded."}, \
                HTTPStatus.SERVICE_UNAVAILABLE

        return response
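
# The probes above are meant for health checks against a running instance,
# for example (assuming the default port of 5000):
#
#   curl http://localhost:5000/probe/liveness
#   curl http://localhost:5000/probe/readiness
#
# Readiness returns 503 until the model has been loaded by the first
# prediction request.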


@model_ns.route('/predict')
class Model(Resource):
    """Model API resource."""

    @model_ns.expect(model_input, validate=True)
    def post(self):
        """Return predictions from the trained model.

        Expected input: tensor of shape (N, 13), i.e. a batch of
        13-feature instances.
        """
        model = _load_keras_model()

        message = request.get_json(force=True)
        input_t: np.ndarray = np.array(message['instances'], dtype=np.float64)

        # Predict in the session the model was loaded in.
        tf.keras.backend.set_session(_SESSION)
        predictions: np.ndarray = model.predict_on_batch(input_t)

        response = {
            'predictions': predictions.tolist()
        }

        return response, HTTPStatus.OK


def _load_keras_model():
    """Load Keras model."""
    global _MODEL
    global _SESSION

    if _MODEL is not None:
        return _MODEL

    app.logger.info("Loading Keras model.")

    # TODO: glob by pattern according to our model file naming
    model_files = sorted(
        # This serving is made for .h5 models
        glob.iglob(str(_MODEL_DIR / '**/*.h5'), recursive=True),
        key=os.path.getctime)

    if not model_files:
        msg = f"Empty directory provided: {_MODEL_DIR}."
        app.logger.error(msg)
        # TODO: maybe BAD_REQUEST is not ideal here
        abort(HTTPStatus.BAD_REQUEST, "Failed. Model not found.")
    else:
        latest = model_files[-1]
        # set the global
        _MODEL = keras.models.load_model(latest)

        if isinstance(_MODEL, keras.Model):
            app.logger.info(f"Model '{latest}' successfully loaded.")
        else:
            msg = f"Expected model of type: {keras.Model}, got {type(_MODEL)}"
            app.logger.error(msg)
            abort(HTTPStatus.BAD_REQUEST, "Failed. Model not loaded.")

    _SESSION = tf.keras.backend.get_session()
    return _MODEL


if __name__ == '__main__':
    api = Api(title="Gestures model serving")
    api.init_app(app)

    api.add_namespace(probe_ns)
    api.add_namespace(model_ns)

    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 5000)), debug=False)
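
# A minimal client sketch (assumes the service runs locally on the default
# port; the instance values below are illustrative, only the (N, 13) shape
# matters):
#
#   import requests
#
#   payload = {
#       "instances": [[0.0] * 13],   # one 13-feature accelerometer instance
#       "signature": "serving_default",
#   }
#   resp = requests.post("http://localhost:5000/model/predict", json=payload)
#   print(resp.json())               # {"predictions": [[...]]}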