forked from vaaaaanquish/rust-machine-learning-api-example
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcreate_model.py
26 lines (21 loc) · 786 Bytes
/
create_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Fail fast, with an install hint, when TensorFlow 2.x or newer is not
# available. Only import failures and the explicit version check below are
# handled: a bare `except:` would also intercept KeyboardInterrupt/SystemExit
# and print a misleading install hint for them.
import sys

try:
    import tensorflow as tf

    # tf.__version__ is a dotted "MAJOR.MINOR.PATCH"-style string; compare the
    # major component numerically so every release from 2 upward is accepted,
    # which is what the error message promises.
    if int(tf.__version__.split(".", 1)[0]) < 2:
        raise RuntimeError("This script requires tensorflow >= 2")
except (ImportError, RuntimeError):
    # The requirement is quoted on purpose: unquoted, a POSIX shell parses
    # `>=2` as an output redirection to a file named "=2" and installs an
    # unpinned tensorflow instead.
    print(
        '// -----------------------\n[requirement] pip install "tensorflow>=2"\n// --------------------------',
        file=sys.stderr,
    )
    raise
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

# Build MobileNetV3-Small with ImageNet weights for a 224x224x3 input
# (the model's default input shape).
model = tf.keras.applications.MobileNetV3Small(
    weights="imagenet",
    input_shape=[224, 224, 3],
)

# Wrap the Keras model in a tf.function and trace it once. The spec's shape
# is None, so neither rank nor dimensions are fixed in the traced signature;
# the spec is float32 and named "input".
x = tf.TensorSpec(shape=None, dtype=tf.float32, name="input")
model_fn = tf.function(model).get_concrete_function(x)

# Replace the traced function's variables with constants so the exported
# graph carries its weights with it.
frozen_model = convert_variables_to_constants_v2(model_fn)

# Serialize the frozen graph as a binary protobuf at model/model.pb,
# relative to the current working directory.
directory = "model"
tf.io.write_graph(
    frozen_model.graph,
    directory,
    "model.pb",
    as_text=False,
)