diff --git a/model2onnx.py b/model2onnx.py new file mode 100644 index 000000000..256d1e25e --- /dev/null +++ b/model2onnx.py @@ -0,0 +1,32 @@ + +# import the necessary packages +from tensorflow.keras.models import load_model, save_model +import argparse +import tf2onnx +import onnx + +def model2onnx(): + # construct the argument parser and parse the arguments + ap = argparse.ArgumentParser() + ap.add_argument("-m", "--model", type=str, + default="mask_detector.model", + help="path to trained face mask detector model") + ap.add_argument("-o", "--output", type=str, + default='mask_detector.onnx', + help="path to trained face mask detector model") + args = vars(ap.parse_args()) + + + # load the face mask detector model from disk + print("[INFO] loading face mask detector model...") + model = load_model(args["model"]) + onnx_model, _ = tf2onnx.convert.from_keras(model, opset=13) + + onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = '?' + onnx_model.graph.output[0].type.tensor_type.shape.dim[0].dim_param = '?' + + onnx.save(onnx_model, args['output']) + + +if __name__ == "__main__": + model2onnx() diff --git a/requirements.txt b/requirements.txt index da3e0bb4f..8d98eb31c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,5 @@ scipy==1.6.2 scikit-learn==0.24.1 pillow>=8.3.2 streamlit==0.79.0 +onnx==1.10.1 +tf2onnx==1.9.3