diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..146a55a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+tensorflow_model/logs/
+tensorflow_model/MNIST_data/
+tensorflow_model/out/
+
diff --git a/MnistAndroid/.gitignore b/MnistAndroid/.gitignore
index 39fb081..f237317 100644
--- a/MnistAndroid/.gitignore
+++ b/MnistAndroid/.gitignore
@@ -7,3 +7,4 @@
 /build
 /captures
 .externalNativeBuild
+.idea/
diff --git a/MnistAndroid/.idea/compiler.xml b/MnistAndroid/.idea/compiler.xml
deleted file mode 100644
index 96cc43e..0000000
--- a/MnistAndroid/.idea/compiler.xml
+++ /dev/null
@@ -1,22 +0,0 @@
-
-
-  
-    
-    
-      
-      
-      
-      
-      
-      
-      
-      
-      
-    
-    
-      
-        
-      
-    
-  
-
\ No newline at end of file
diff --git a/MnistAndroid/.idea/copyright/profiles_settings.xml b/MnistAndroid/.idea/copyright/profiles_settings.xml
deleted file mode 100644
index e7bedf3..0000000
--- a/MnistAndroid/.idea/copyright/profiles_settings.xml
+++ /dev/null
@@ -1,3 +0,0 @@
-
-  
-
\ No newline at end of file
diff --git a/MnistAndroid/.idea/gradle.xml b/MnistAndroid/.idea/gradle.xml
deleted file mode 100644
index 7ac24c7..0000000
--- a/MnistAndroid/.idea/gradle.xml
+++ /dev/null
@@ -1,18 +0,0 @@
-
-
-  
-    
-  
-
\ No newline at end of file
diff --git a/MnistAndroid/.idea/misc.xml b/MnistAndroid/.idea/misc.xml
deleted file mode 100644
index 5d19981..0000000
--- a/MnistAndroid/.idea/misc.xml
+++ /dev/null
@@ -1,46 +0,0 @@
-
-
-  
-    
-  
-  
-    
-    
-    
-      
-        
-          -
-
-
-
-      
-    
-    
-      
-
-          -
-
-
-
-      
-    
-  
-  
-    
-    
-    
-    
-    
-    
-    
-    
-  
-  
-    
-  
-  
-    
-  
-
\ No newline at end of file
diff --git a/MnistAndroid/.idea/modules.xml b/MnistAndroid/.idea/modules.xml
deleted file mode 100644
index 02e6fb4..0000000
--- a/MnistAndroid/.idea/modules.xml
+++ /dev/null
@@ -1,9 +0,0 @@
-
-
-  
-    
-      
-      
-    
-  
-
\ No newline at end of file
diff --git a/MnistAndroid/.idea/runConfigurations.xml b/MnistAndroid/.idea/runConfigurations.xml
deleted file mode 100644
index 7f68460..0000000
--- a/MnistAndroid/.idea/runConfigurations.xml
+++ /dev/null
@@ -1,12 +0,0 @@
-
-
-  
-    
-      
-        
-        
-        
-      
-    
-  
-
\ No newline at end of file
diff --git a/MnistAndroid/app/CMakeLists.txt b/MnistAndroid/app/CMakeLists.txt
deleted file mode 100644
index f8e6e8b..0000000
--- a/MnistAndroid/app/CMakeLists.txt
+++ /dev/null
@@ -1,44 +0,0 @@
-# For more information about using CMake with Android Studio, read the
-# documentation: https://d.android.com/studio/projects/add-native-code.html
-
-# Sets the minimum version of CMake required to build the native library.
-
-cmake_minimum_required(VERSION 3.4.1)
-
-# Creates and names a library, sets it as either STATIC
-# or SHARED, and provides the relative paths to its source code.
-# You can define multiple libraries, and CMake builds them for you.
-# Gradle automatically packages shared libraries with your APK.
-
-add_library( # Sets the name of the library.
-             native-lib
-
-             # Sets the library as a shared library.
-             SHARED
-
-             # Provides a relative path to your source file(s).
-             src/main/cpp/native-lib.cpp )
-
-# Searches for a specified prebuilt library and stores the path as a
-# variable. Because CMake includes system libraries in the search path by
-# default, you only need to specify the name of the public NDK library
-# you want to add. CMake verifies that the library exists before
-# completing its build.
-
-find_library( # Sets the name of the path variable.
-              log-lib
-
-              # Specifies the name of the NDK library that
-              # you want CMake to locate.
-              log )
-
-# Specifies libraries CMake should link to your target library. You
-# can link multiple libraries, such as libraries you define in this
-# build script, prebuilt third-party libraries, or system libraries.
-
-target_link_libraries( # Specifies the target library.
-                       native-lib
-
-                       # Links the target library to the log library
-                       # included in the NDK.
-                       ${log-lib} )
\ No newline at end of file
diff --git a/MnistAndroid/app/build.gradle b/MnistAndroid/app/build.gradle
index a33663b..3992eef 100644
--- a/MnistAndroid/app/build.gradle
+++ b/MnistAndroid/app/build.gradle
@@ -9,12 +9,6 @@ android {
         targetSdkVersion 25
         versionCode 1
         versionName "1.0"
-        testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
-        externalNativeBuild {
-            cmake {
-                cppFlags ""
-            }
-        }
     }
     buildTypes {
         release {
@@ -22,20 +16,9 @@ android {
             proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
         }
     }
-    externalNativeBuild {
-        cmake {
-            path "CMakeLists.txt"
-        }
-    }
 }
 
 dependencies {
-    compile fileTree(include: ['*.jar'], dir: 'libs')
-    androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
-        exclude group: 'com.android.support', module: 'support-annotations'
-    })
     compile 'com.android.support:appcompat-v7:25.3.1'
-    compile 'com.android.support.constraint:constraint-layout:1.0.2'
-    testCompile 'junit:junit:4.12'
-    compile files('libs/libandroid_tensorflow_inference_java.jar')
+    compile 'org.tensorflow:tensorflow-android:1.2.0-rc0'
 }
diff --git a/MnistAndroid/app/libs/libandroid_tensorflow_inference_java.jar b/MnistAndroid/app/libs/libandroid_tensorflow_inference_java.jar
deleted file mode 100644
index 3b8d93b..0000000
Binary files a/MnistAndroid/app/libs/libandroid_tensorflow_inference_java.jar and /dev/null differ
diff --git a/MnistAndroid/app/src/androidTest/java/mariannelinhares/mnistandroid/ExampleInstrumentedTest.java b/MnistAndroid/app/src/androidTest/java/mariannelinhares/mnistandroid/ExampleInstrumentedTest.java
deleted file mode 100644
index fb64682..0000000
--- a/MnistAndroid/app/src/androidTest/java/mariannelinhares/mnistandroid/ExampleInstrumentedTest.java
+++ /dev/null
@@ -1,26 +0,0 @@
-package mariannelinhares.mnistandroid;
-
-import android.content.Context;
-import android.support.test.InstrumentationRegistry;
-import android.support.test.runner.AndroidJUnit4;
-
-import org.junit.Test;
-import org.junit.runner.RunWith;
-
-import static org.junit.Assert.*;
-
-/**
- * Instrumentation test, which will execute on an Android device.
- *
- * @see Testing documentation
- */
-@RunWith(AndroidJUnit4.class)
-public class ExampleInstrumentedTest {
-    @Test
-    public void useAppContext() throws Exception {
-        // Context of the app under test.
-        Context appContext = InstrumentationRegistry.getTargetContext();
-
-        assertEquals("mariannelinhares.mnistandroid", appContext.getPackageName());
-    }
-}
diff --git a/MnistAndroid/app/src/main/assets/expert-graph.pb b/MnistAndroid/app/src/main/assets/opt_mnist_convnet-keras.pb
similarity index 52%
rename from MnistAndroid/app/src/main/assets/expert-graph.pb
rename to MnistAndroid/app/src/main/assets/opt_mnist_convnet-keras.pb
index 3533274..80c7ecd 100644
Binary files a/MnistAndroid/app/src/main/assets/expert-graph.pb and b/MnistAndroid/app/src/main/assets/opt_mnist_convnet-keras.pb differ
diff --git a/MnistAndroid/app/src/main/assets/graph.pb b/MnistAndroid/app/src/main/assets/opt_mnist_convnet-tf.pb
similarity index 52%
rename from MnistAndroid/app/src/main/assets/graph.pb
rename to MnistAndroid/app/src/main/assets/opt_mnist_convnet-tf.pb
index dcd1321..279adec 100644
Binary files a/MnistAndroid/app/src/main/assets/graph.pb and b/MnistAndroid/app/src/main/assets/opt_mnist_convnet-tf.pb differ
diff --git a/MnistAndroid/app/src/main/cpp/native-lib.cpp b/MnistAndroid/app/src/main/cpp/native-lib.cpp
deleted file mode 100644
index cbb9c07..0000000
--- a/MnistAndroid/app/src/main/cpp/native-lib.cpp
+++ /dev/null
@@ -1,11 +0,0 @@
-#include 
-#include 
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_mariannelinhares_mnistandroid_MainActivity_stringFromJNI(
-        JNIEnv* env,
-        jobject /* this */) {
-    std::string hello = "Hello from C++";
-    return env->NewStringUTF(hello.c_str());
-}
diff --git a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/MainActivity.java b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/MainActivity.java
index eb086be..dd39deb 100644
--- a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/MainActivity.java
+++ b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/MainActivity.java
@@ -15,7 +15,9 @@
    See the License for the specific language governing permissions and
    limitations under the License.
 
-   From: https://raw.githubusercontent.com/miyosuda/TensorFlowAndroidMNIST/master/app/src/main/java/jp/narr/tensorflowmnist/DrawModel.java
+   From: https://raw.githubusercontent
+   .com/miyosuda/TensorFlowAndroidMNIST/master/app/src/main/java/jp/narr/tensorflowmnist
+   /DrawModel.java
 */
 
 import android.app.Activity;
@@ -25,41 +27,32 @@
 import android.view.View;
 import android.widget.Button;
 import android.widget.TextView;
-
-import java.util.concurrent.Executor;
-import java.util.concurrent.Executors;
-
+import java.util.ArrayList;
+import java.util.List;
+import mariannelinhares.mnistandroid.models.Classification;
+import mariannelinhares.mnistandroid.models.Classifier;
+import mariannelinhares.mnistandroid.models.TensorFlowClassifier;
 import mariannelinhares.mnistandroid.views.DrawModel;
 import mariannelinhares.mnistandroid.views.DrawView;
 
 /**
  * Changed by marianne-linhares on 21/04/17.
- * https://raw.githubusercontent.com/miyosuda/TensorFlowAndroidMNIST/master/app/src/main/java/jp/narr/tensorflowmnist/DrawModel.java
+ * https://raw.githubusercontent.com/miyosuda/TensorFlowAndroidMNIST/master/app/src/main/java/jp
+ * /narr/tensorflowmnist/DrawModel.java
  */
 
 public class MainActivity extends Activity implements View.OnClickListener, View.OnTouchListener {
 
+    private static final int PIXEL_WIDTH = 28;
+
     // ui related
     private Button clearBtn, classBtn;
     private TextView resText;
-
-    // tensorflow input and output
-    private static final int INPUT_SIZE = 28;
-    private static final String INPUT_NAME = "input";
-    private static final String OUTPUT_NAME = "output";
-
-    private static final String MODEL_FILE = "file:///android_asset/expert-graph.pb";
-    private static final String LABEL_FILE = "file:///android_asset/labels.txt";
-
-    private Classifier classifier;
-
-    private Executor executor = Executors.newSingleThreadExecutor();
+    private List mClassifiers = new ArrayList<>();
 
     // views related
     private DrawModel drawModel;
     private DrawView drawView;
-    private static final int PIXEL_WIDTH = 28;
-
     private PointF mTmpPiont = new PointF();
 
     private float mLastX;
@@ -71,97 +64,84 @@ protected void onCreate(Bundle savedInstanceState) {
         setContentView(R.layout.activity_main);
 
         //get drawing view
-        drawView = (DrawView)findViewById(R.id.draw);
+        drawView = (DrawView) findViewById(R.id.draw);
         drawModel = new DrawModel(PIXEL_WIDTH, PIXEL_WIDTH);
 
         drawView.setModel(drawModel);
         drawView.setOnTouchListener(this);
 
         //clear button
-        clearBtn = (Button)findViewById(R.id.btn_clear);
+        clearBtn = (Button) findViewById(R.id.btn_clear);
         clearBtn.setOnClickListener(this);
 
         //class button
-        classBtn = (Button)findViewById(R.id.btn_class);
+        classBtn = (Button) findViewById(R.id.btn_class);
         classBtn.setOnClickListener(this);
 
         // res text
-        resText = (TextView)findViewById(R.id.tfRes);
+        resText = (TextView) findViewById(R.id.tfRes);
 
         // tensorflow
         loadModel();
     }
 
+    @Override
+    protected void onResume() {
+        drawView.onResume();
+        super.onResume();
+    }
+
+    @Override
+    protected void onPause() {
+        drawView.onPause();
+        super.onPause();
+    }
+
     private void loadModel() {
-        executor.execute(new Runnable() {
+        new Thread(new Runnable() {
             @Override
             public void run() {
                 try {
-                    classifier = Classifier.create(getApplicationContext().getAssets(),
-                                                   MODEL_FILE,
-                                                   LABEL_FILE,
-                                                   INPUT_SIZE,
-                                                   INPUT_NAME,
-                                                   OUTPUT_NAME);
+                    mClassifiers.add(
+                            TensorFlowClassifier.create(getAssets(), "TensorFlow",
+                                    "opt_mnist_convnet-tf.pb", "labels.txt", PIXEL_WIDTH,
+                                    "input", "output", true));
+                    mClassifiers.add(
+                            TensorFlowClassifier.create(getAssets(), "Keras",
+                                    "opt_mnist_convnet-keras.pb", "labels.txt", PIXEL_WIDTH,
+                                    "conv2d_1_input", "dense_2/Softmax", false));
                 } catch (final Exception e) {
-                    throw new RuntimeException("Error initializing TensorFlow!", e);
+                    throw new RuntimeException("Error initializing classifiers!", e);
                 }
             }
-        });
-    }
-
-    /**
-     * A native method that is implemented by the 'native-lib' native library,
-     * which is packaged with this application.
-     */
-    public native String stringFromJNI();
-
-    // Used to load the 'native-lib' library on application startup.
-    static {
-        System.loadLibrary("native-lib");
+        }).start();
     }
 
-
     @Override
-    public void onClick(View view){
-
-        if(view.getId() == R.id.btn_clear) {
+    public void onClick(View view) {
+        if (view.getId() == R.id.btn_clear) {
             drawModel.clear();
             drawView.reset();
             drawView.invalidate();
 
-            resText.setText("Result: ");
-        }
-        else if(view.getId() == R.id.btn_class){
-
+            resText.setText("");
+        } else if (view.getId() == R.id.btn_class) {
             float pixels[] = drawView.getPixelData();
 
-            final Classification res = classifier.recognize(pixels);
-            String result = "Result: ";
-            if (res.getLabel() == null) {
-                resText.setText(result + "?");
-            }
-            else {
-                result += res.getLabel();
-                result += "\nwith probability: " + res.getConf();
-                resText.setText(result);
+            String text = "";
+            for (Classifier classifier : mClassifiers) {
+                final Classification res = classifier.recognize(pixels);
+                if (res.getLabel() == null) {
+                    text += classifier.name() + ": ?\n";
+                } else {
+                    text += String.format("%s: %s, %f\n", classifier.name(), res.getLabel(),
+                            res.getConf());
+                }
             }
+            resText.setText(text);
         }
     }
 
-    @Override
-    protected void onResume() {
-        drawView.onResume();
-        super.onResume();
-    }
-
-    @Override
-    protected void onPause() {
-        drawView.onPause();
-        super.onPause();
-    }
-
-
     @Override
     public boolean onTouch(View v, MotionEvent event) {
         int action = event.getAction() & MotionEvent.ACTION_MASK;
@@ -169,11 +149,9 @@ public boolean onTouch(View v, MotionEvent event) {
         if (action == MotionEvent.ACTION_DOWN) {
             processTouchDown(event);
             return true;
-
         } else if (action == MotionEvent.ACTION_MOVE) {
             processTouchMove(event);
             return true;
-
         } else if (action == MotionEvent.ACTION_UP) {
             processTouchUp();
             return true;
diff --git a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classification.java b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classification.java
similarity index 58%
rename from MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classification.java
rename to MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classification.java
index 95bfb1f..5ba6c0b 100644
--- a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classification.java
+++ b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classification.java
@@ -1,4 +1,4 @@
-package mariannelinhares.mnistandroid;
+package mariannelinhares.mnistandroid.models;
 
 /**
  * Created by marianne-linhares on 20/04/17.
@@ -9,16 +9,12 @@ public class Classification {
     private float conf;
     private String label;
 
-    public Classification(float conf, String label) {
-        update(conf, label);
-    }
-
-    public Classification() {
-        this.conf = (float)-1.0;
+    Classification() {
+        this.conf = -1.0F;
         this.label = null;
     }
 
-    public void update(float conf, String label) {
+    void update(float conf, String label) {
         this.conf = conf;
         this.label = label;
     }
@@ -30,5 +26,4 @@ public String getLabel() {
     public float getConf() {
         return conf;
     }
-
 }
diff --git a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classifier.java b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classifier.java
new file mode 100644
index 0000000..fd551d1
--- /dev/null
+++ b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classifier.java
@@ -0,0 +1,11 @@
+package mariannelinhares.mnistandroid.models;
+
+/**
+ * Created by Piasy{github.com/Piasy} on 29/05/2017.
+ */
+
+public interface Classifier {
+    String name();
+
+    Classification recognize(final float[] pixels);
+}
diff --git a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classifier.java b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/TensorFlowClassifier.java
similarity index 55%
rename from MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classifier.java
rename to MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/TensorFlowClassifier.java
index f9b34ed..a908f67 100644
--- a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classifier.java
+++ b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/TensorFlowClassifier.java
@@ -1,38 +1,38 @@
-package mariannelinhares.mnistandroid;
+package mariannelinhares.mnistandroid.models;
 
 import android.content.res.AssetManager;
-
 import java.io.BufferedReader;
 import java.io.IOException;
 import java.io.InputStreamReader;
 import java.util.ArrayList;
 import java.util.List;
-
 import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
 
 /**
- * Changed from https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample/blob/master/app/src/main/java/com/mindorks/tensorflowexample/TensorFlowImageClassifier.java
+ * Changed from https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample/blob/master
+ * /app/src/main/java/com/mindorks/tensorflowexample/TensorFlowImageClassifier.java
  * Created by marianne-linhares on 20/04/17.
  */
 
-public class Classifier {
+public class TensorFlowClassifier implements Classifier {
 
     // Only returns if at least this confidence
     private static final float THRESHOLD = 0.1f;
 
     private TensorFlowInferenceInterface tfHelper;
 
+    private String name;
     private String inputName;
     private String outputName;
     private int inputSize;
+    private boolean feedKeepProb;
 
     private List labels;
     private float[] output;
     private String[] outputNames;
 
-    static private List readLabels(Classifier c, AssetManager am, String fileName) throws IOException {
-        BufferedReader br = null;
-        br = new BufferedReader(new InputStreamReader(am.open(fileName)));
+    private static List readLabels(AssetManager am, String fileName) throws IOException {
+        BufferedReader br = new BufferedReader(new InputStreamReader(am.open(fileName)));
 
         String line;
         List labels = new ArrayList<>();
@@ -44,44 +44,49 @@ static private List readLabels(Classifier c, AssetManager am, String fil
         return labels;
     }
 
+    public static TensorFlowClassifier create(AssetManager assetManager, String name,
+            String modelPath, String labelFile, int inputSize, String inputName, String outputName,
+            boolean feedKeepProb) throws IOException {
+        TensorFlowClassifier c = new TensorFlowClassifier();
 
-    static public Classifier create(AssetManager assetManager, String modelPath, String labelPath,
-                             int inputSize, String inputName, String outputName)
-                             throws IOException {
-
-        Classifier c = new Classifier();
+        c.name = name;
 
         c.inputName = inputName;
         c.outputName = outputName;
 
-        // Read labels
-        String labelFile = labelPath.split("file:///android_asset/")[1];
-        c.labels = readLabels(c, assetManager, labelFile);
-
-        c.tfHelper = new TensorFlowInferenceInterface();
-        if (c.tfHelper.initializeTensorFlow(assetManager, modelPath) != 0) {
-            throw new RuntimeException("TF initialization failed");
-        }
+        c.labels = readLabels(assetManager, labelFile);
 
+        c.tfHelper = new TensorFlowInferenceInterface(assetManager, modelPath);
         int numClasses = 10;
 
         c.inputSize = inputSize;
 
         // Pre-allocate buffer.
-        c.outputNames = new String[]{ outputName };
+        c.outputNames = new String[] { outputName };
 
         c.outputName = outputName;
         c.output = new float[numClasses];
 
+        c.feedKeepProb = feedKeepProb;
+
         return c;
     }
 
+    @Override
+    public String name() {
+        return name;
+    }
+
+    @Override
     public Classification recognize(final float[] pixels) {
 
-        tfHelper.fillNodeFloat(inputName, new int[]{inputSize * inputSize}, pixels);
-        tfHelper.runInference(outputNames);
+        tfHelper.feed(inputName, pixels, 1, inputSize, inputSize, 1);
+        if (feedKeepProb) {
+            tfHelper.feed("keep_prob", new float[] { 1 });
+        }
+        tfHelper.run(outputNames);
 
-        tfHelper.readNodeFloat(outputName, output);
+        tfHelper.fetch(outputName, output);
 
         // Find the best classification
         Classification ans = new Classification();
diff --git a/MnistAndroid/app/src/main/jniLibs/armeabi-v7a/libtensorflow_inference.so b/MnistAndroid/app/src/main/jniLibs/armeabi-v7a/libtensorflow_inference.so
deleted file mode 100644
index 9390465..0000000
Binary files a/MnistAndroid/app/src/main/jniLibs/armeabi-v7a/libtensorflow_inference.so and /dev/null differ
diff --git a/MnistAndroid/app/src/main/jniLibs/x86/libtensorflow_mnist.so b/MnistAndroid/app/src/main/jniLibs/x86/libtensorflow_mnist.so
deleted file mode 100644
index d4572f3..0000000
Binary files a/MnistAndroid/app/src/main/jniLibs/x86/libtensorflow_mnist.so and /dev/null differ
diff --git a/MnistAndroid/app/src/main/res/layout/activity_main.xml b/MnistAndroid/app/src/main/res/layout/activity_main.xml
index e07fe80..b3f727e 100644
--- a/MnistAndroid/app/src/main/res/layout/activity_main.xml
+++ b/MnistAndroid/app/src/main/res/layout/activity_main.xml
@@ -1,47 +1,50 @@
 
 
+        xmlns:android="http://schemas.android.com/apk/res/android"
+        xmlns:tools="http://schemas.android.com/tools"
+        android:layout_width="match_parent"
+        android:layout_height="match_parent"
+        android:orientation="vertical"
+        android:paddingBottom="@dimen/activity_vertical_margin"
+        android:paddingLeft="@dimen/activity_horizontal_margin"
+        android:paddingRight="@dimen/activity_horizontal_margin"
+        android:paddingTop="@dimen/activity_vertical_margin"
+        tools:context="mariannelinhares.mnistandroid.MainActivity"
+        >
 
     
+            android:id="@+id/draw"
+            android:layout_width="match_parent"
+            android:layout_height="0dp"
+            android:layout_weight="1"
+            />
 
     
+            android:layout_width="match_parent"
+            android:layout_height="wrap_content"
+            android:orientation="horizontal"
+            >
 
         
+                android:id="@+id/btn_clear"
+                android:layout_width="wrap_content"
+                android:layout_height="wrap_content"
+                android:text="Clear"
+                />
 
         
+                android:id="@+id/btn_class"
+                android:layout_width="wrap_content"
+                android:layout_height="wrap_content"
+                android:text="Detect"
+                />
+    
 
-        
-    
-
+            android:textAppearance="?android:attr/textAppearanceMedium"
+            />
 
\ No newline at end of file
diff --git a/MnistAndroid/app/src/test/java/mariannelinhares/mnistandroid/ExampleUnitTest.java b/MnistAndroid/app/src/test/java/mariannelinhares/mnistandroid/ExampleUnitTest.java
deleted file mode 100644
index 2a0350d..0000000
--- a/MnistAndroid/app/src/test/java/mariannelinhares/mnistandroid/ExampleUnitTest.java
+++ /dev/null
@@ -1,17 +0,0 @@
-package mariannelinhares.mnistandroid;
-
-import org.junit.Test;
-
-import static org.junit.Assert.*;
-
-/**
- * Example local unit test, which will execute on the development machine (host).
- *
- * @see Testing documentation
- */
-public class ExampleUnitTest {
-    @Test
-    public void addition_isCorrect() throws Exception {
-        assertEquals(4, 2 + 2);
-    }
-}
\ No newline at end of file
diff --git a/MnistAndroid/build.gradle b/MnistAndroid/build.gradle
index b78a0b8..7033e6a 100644
--- a/MnistAndroid/build.gradle
+++ b/MnistAndroid/build.gradle
@@ -5,7 +5,7 @@ buildscript {
         jcenter()
     }
     dependencies {
-        classpath 'com.android.tools.build:gradle:2.3.1'
+        classpath 'com.android.tools.build:gradle:2.3.2'
 
         // NOTE: Do not place your application dependencies here; they belong
         // in the individual module build.gradle files
@@ -15,6 +15,9 @@ buildscript {
 allprojects {
     repositories {
         jcenter()
+        flatDir {
+            dirs "$rootProject.projectDir/aars"
+        }
     }
 }
 
diff --git a/MnistAndroid/gradle/wrapper/gradle-wrapper.properties b/MnistAndroid/gradle/wrapper/gradle-wrapper.properties
index 45120ad..0caac1b 100644
--- a/MnistAndroid/gradle/wrapper/gradle-wrapper.properties
+++ b/MnistAndroid/gradle/wrapper/gradle-wrapper.properties
@@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME
 distributionPath=wrapper/dists
 zipStoreBase=GRADLE_USER_HOME
 zipStorePath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-3.3-all.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-3.5-all.zip
diff --git a/README.md b/README.md
index 192b8d3..bd4ac5c 100644
--- a/README.md
+++ b/README.md
@@ -12,9 +12,6 @@ how to save your model and export it for Android or other devices check the
 very simple tutorial bellow.  
 
 The UI and expert-graph.pb model were taken from: https://github.com/miyosuda/TensorFlowAndroidMNIST, so thank you miyousuda.  
-The TensorFlow jar and so armeabi-v7a were taken from: https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample,
-so thank you MindorksOpenSource.  
-The Tensorflow so of x86 was taken from: https://github.com/cesardelgadof/TensorFlowAndroidMNIST, so thank you cesardelgadof.  
 
 If you have no ideia what I just said above, have a look on the instructions bellow.
 
@@ -32,7 +29,7 @@ A full example can be seen [here](https://github.com/mari-linhares/mnist-android
    Example: `_w = sess.eval(w)`, where w was learned from training.
 3. Rewrite your model changing the variables for constants with value = in memory copy of learned variables.
    Example: `w_save = tf.constant(_w)`  
-   
+
    Also make sure to put names in the input and output of the model, this will be needed for the model later.
    Example:  
    `x = tf.placeholder(tf.float32, [None, 1000], name='input')`  
@@ -42,23 +39,10 @@ A full example can be seen [here](https://github.com/mari-linhares/mnist-android
 
 ## How to run my model with Android?
 
-You need two things:
-
-1. [The TensorFlow jar](https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample/blob/master/app/libs/libandroid_tensorflow_inference_java.jar)  
-   Move it to the libs folder, right click and add as library.  
-
-2. The TensorFlow so file for the desired architecture:  
-[x86](https://github.com/cesardelgadof/TensorFlowAndroidMNIST/blob/master/app/src/main/jniLibs/x86/libtensorflow_mnist.so)  
-[armeabi-v7a](https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample/tree/master/app/src/main/jniLibs/armeabi-v7a)  
-
-Creat the jniLibs/x86 folder or the jniLibs/armeabi-v7a folder at the main folder.  
-Move it to app/src/main/jniLibs/x86/libtensorflow_inference.so or app/src/jniLibs/armeabi-v7a/libtensorflow_inference.so
-
-If you want to generate these files yourself, [here](https://blog.mindorks.com/android-tensorflow-machine-learning-example-ff0e9b2654cc) is a nice tutorial of how to do it.
+You need `tensorflow.aar`, which can be downloaded from [the nightly build artifact of TensorFlow CI](http://ci.tensorflow.org/view/Nightly/job/nightly-android/), here we use [the #124 build](http://ci.tensorflow.org/view/Nightly/job/nightly-android/124/artifact/).
 
 ## Interacting with TensorFlow
 
 To interact with TensorFlow you will need an instance of TensorFlowInferenceInterface, you can see more details about it [here](https://github.com/mari-linhares/mnist-android-tensorflow/blob/master/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classifier.java)
 
 Thank you, have fun!
-
diff --git a/tensorflow_model/convnet.py b/tensorflow_model/convnet.py
deleted file mode 100644
index 0ebdab1..0000000
--- a/tensorflow_model/convnet.py
+++ /dev/null
@@ -1,205 +0,0 @@
-# needed libraries
-import tensorflow as tf
-
-from tensorflow.examples.tutorials.mnist import input_data
-
-logs_path = '/tmp/tensorflow_logs/convnet'
-
-# mnist.train = 55,000 input data
-# mnist.test = 10,000 input data
-# mnist.validate = 5,000 input data
-mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
-
-# Implementing Convnet with TF
-def weight_variable(shape, name=None):
-    # break simmetry
-    if name:
-        w = tf.truncated_normal(shape, stddev=0.1, name=name)
-    else:
-        w = tf.truncated_normal(shape, stddev=0.1)
-
-    return tf.Variable(w)
-
-
-def bias_variable(shape, name=None):
-    # avoid dead neurons
-    if name:
-        b = tf.constant(0.1, shape=shape, name=name)
-    else:
-        b = tf.constant(0.1, shape=shape)
-    return tf.Variable(b)
-
-
-# pool
-def max_pool_2x2(x):
-    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
-                          strides=[1, 2, 2, 1], padding='SAME')
-
-def new_conv_layer(x, w):
-	return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')
-
-# our network!!!
-
-g = tf.Graph()
-
-with g.as_default():
-
-	# input data
-	x = tf.placeholder(tf.float32, shape=[None, 28*28], name='input_data')
-	x_image = tf.reshape(x, [-1, 28, 28, 1])
-	# correct labels
-	y_ = tf.placeholder(tf.float32, shape=[None, 10], name='correct_labels')
-
-	# fist conv layer
-	with tf.name_scope('convLayer1'):
-		w1 = weight_variable([5, 5, 1, 32])
-		b1 = bias_variable([32])
-		convlayer1 = tf.nn.relu(new_conv_layer(x_image, w1) + b1)
-		max_pool1 = max_pool_2x2(convlayer1)
-
-	# second conv layer
-	with tf.name_scope('convLayer2'):
-		w2 = weight_variable([5, 5, 32, 64])
-		b2 = bias_variable([64])
-		convlayer2 = tf.nn.relu(new_conv_layer(max_pool1, w2) + b2)
-		max_pool2 = max_pool_2x2(convlayer2)
-
-	# flat layer
-	with tf.name_scope('flattenLayer'):
-		flat_layer = tf.reshape(max_pool2, [-1, 7 * 7 * 64])
-
-	# fully connected layer
-	with tf.name_scope('FullyConnectedLayer'):
-		wfc1 = weight_variable([7 * 7 * 64, 1024])
-		bfc1 = bias_variable([1024])
-		fc1 = tf.nn.relu(tf.matmul(flat_layer, wfc1) + bfc1)
-
-	# DROPOUT
-	with tf.name_scope('Dropout'):
-		keep_prob = tf.placeholder(tf.float32)
-		drop_layer = tf.nn.dropout(fc1, keep_prob)
-
-	# final layer
-	with tf.name_scope('FinalLayer'):
-		w_f = weight_variable([1024, 10])
-		b_f = bias_variable([10])
-		y_f = tf.matmul(drop_layer, w_f) + b_f
-		y_f_softmax = tf.nn.softmax(y_f)
-
-	# loss
-	loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_,
-																  logits=y_f))
-
-	# train step
-	train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
-
-	# accuracy
-	correct_prediction = tf.equal(tf.argmax(y_f_softmax, 1), tf.argmax(y_, 1))
-	accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
-
-	# Create a summary to monitor loss tensor
-	tf.summary.scalar("loss", loss)
-	# Create a summary to monitor accuracy tensor
-	tf.summary.scalar("accuracy", accuracy)
-	# Merge all summaries into a single op
-	merged_summary_op = tf.summary.merge_all()
-
-	# init
-	init = tf.global_variables_initializer()
-
-	# Running the graph
-
-	num_steps = 3000
-	batch_size = 16
-	test_size = 10000
-	test_accuracy = 0.0
-
-	sess = tf.Session()
-
-	sess.run(init)
-	# op to write logs to Tensorboard
-	summary_writer = tf.summary.FileWriter(logs_path,
-										   graph=tf.get_default_graph())
-
-	for step in range(num_steps):
-		batch = mnist.train.next_batch(batch_size)
-
-		ts, error, acc, summary = sess.run([train_step, loss, accuracy,
-											merged_summary_op],
-										   feed_dict={x: batch[0],
-													  y_: batch[1],
-													  keep_prob: 0.5})
-		if step % 100 == 0:
-			train_accuracy = accuracy.eval({
-				x: batch[0], y_: batch[1], keep_prob: 1.0}, sess)
-			print('step %d, training accuracy %f' % (step, train_accuracy))
-        '''
-	print 'Done!'
-	print 'Evaluating...'
-	for i in xrange(test_size/50):
-		batch = mnist.test.next_batch(50)
-		acc = accuracy.eval({x: batch[0], y_: batch[1],
-									   keep_prob: 1.0}, sess)
-		if i % 10 == 0:
-			print('%d: test accuracy %f' % (i, acc))
-		test_accuracy += acc
-	print 'avg test accuracy:', test_accuracy/(test_size/50.0)
-        '''
-
-# copying variables as constants to export graph
-_w1 = w1.eval(sess)
-_b1 = b1.eval(sess)
-_w2 = w2.eval(sess)
-_b2 = b2.eval(sess)
-_wfc1 = wfc1.eval(sess)
-_bfc1 = bfc1.eval(sess)
-_w_f = w_f.eval(sess)
-_b_f = b_f.eval(sess)
-
-sess.close()
-
-g2 = tf.Graph()
-with g2.as_default():
-
-	# input data
-	x2 = tf.placeholder(tf.float32, shape=[None, 28*28], name='input')
-	x2_image = tf.reshape(x2, [-1, 28, 28, 1])
-	# correct labels
-	y2_ = tf.placeholder(tf.float32, shape=[None, 10])
-
-	w1_2 = tf.constant(_w1)
-	b1_2 = tf.constant(_b1)
-	convlayer1_2 = tf.nn.relu(new_conv_layer(x2_image, w1_2) + b1_2)
-	max_pool1_2 = max_pool_2x2(convlayer1_2)
-
-	w2_2 = tf.constant(_w2)
-	b2_2 = tf.constant(_b2)
-	convlayer2_2 = tf.nn.relu(new_conv_layer(max_pool1_2, w2_2) + b2_2)
-	max_pool2_2 = max_pool_2x2(convlayer2_2)
-
-	# flat layer
-	flat_layer_2 = tf.reshape(max_pool2_2, [-1, 7 * 7 * 64])
-
-	# fully connected layer
-	wfc1_2 = tf.constant(_wfc1)
-	bfc1_2 = tf.constant(_bfc1)
-	fc1_2 = tf.nn.relu(tf.matmul(flat_layer_2, wfc1_2) + bfc1_2)
-
-	# no dropout layer
-
-	# final layer
-	w_f_2 = tf.constant(_w_f)
-	b_f_2 = tf.constant(_b_f)
-	y_f_2 = tf.matmul(fc1_2, w_f_2) + b_f_2
-	y_f_softmax_2 = tf.nn.softmax(y_f_2, name='output')
-
-	# init
-	init_2 = tf.global_variables_initializer()
-
-	sess_2 = tf.Session()
-        init_2 = tf.initialize_all_variables()
-        sess_2.run(init_2)
-
-        graph_def = g2.as_graph_def()
-        tf.train.write_graph(graph_def, '', 'graph.pb', as_text=False)
-
diff --git a/tensorflow_model/graph.pb b/tensorflow_model/graph.pb
deleted file mode 100644
index dcd1321..0000000
Binary files a/tensorflow_model/graph.pb and /dev/null differ
diff --git a/tensorflow_model/mnist_convnet.py b/tensorflow_model/mnist_convnet.py
new file mode 100644
index 0000000..e38f874
--- /dev/null
+++ b/tensorflow_model/mnist_convnet.py
@@ -0,0 +1,140 @@
+# Python 3.6.0
+# tensorflow 1.1.0
+
+import os
+import os.path as path
+
+import tensorflow as tf
+from tensorflow.python.tools import freeze_graph
+from tensorflow.python.tools import optimize_for_inference_lib
+
+from tensorflow.examples.tutorials.mnist import input_data
+
+MODEL_NAME = 'mnist_convnet'
+NUM_STEPS = 3000
+BATCH_SIZE = 16
+
+def model_input(input_node_name, keep_prob_node_name):
+    x = tf.placeholder(tf.float32, shape=[None, 28*28], name=input_node_name)
+    keep_prob = tf.placeholder(tf.float32, name=keep_prob_node_name)
+    y_ = tf.placeholder(tf.float32, shape=[None, 10])
+    return x, keep_prob, y_
+
+def build_model(x, keep_prob, y_, output_node_name):
+    x_image = tf.reshape(x, [-1, 28, 28, 1])
+    # 28*28*1
+
+    conv1 = tf.layers.conv2d(x_image, 64, 3, 1, 'same', activation=tf.nn.relu)
+    # 28*28*64
+    pool1 = tf.layers.max_pooling2d(conv1, 2, 2, 'same')
+    # 14*14*64
+
+    conv2 = tf.layers.conv2d(pool1, 128, 3, 1, 'same', activation=tf.nn.relu)
+    # 14*14*128
+    pool2 = tf.layers.max_pooling2d(conv2, 2, 2, 'same')
+    # 7*7*128
+
+    conv3 = tf.layers.conv2d(pool2, 256, 3, 1, 'same', activation=tf.nn.relu)
+    # 7*7*256
+    pool3 = tf.layers.max_pooling2d(conv3, 2, 2, 'same')
+    # 4*4*256
+
+    flatten = tf.reshape(pool3, [-1, 4*4*256])
+    fc = tf.layers.dense(flatten, 1024, activation=tf.nn.relu)
+    dropout = tf.nn.dropout(fc, keep_prob)
+    logits = tf.layers.dense(dropout, 10)
+    outputs = tf.nn.softmax(logits, name=output_node_name)
+
+    # loss
+    loss = tf.reduce_mean(
+        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
+
+    # train step
+    train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
+
+    # accuracy
+    correct_prediction = tf.equal(tf.argmax(outputs, 1), tf.argmax(y_, 1))
+    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+
+    tf.summary.scalar("loss", loss)
+    tf.summary.scalar("accuracy", accuracy)
+    merged_summary_op = tf.summary.merge_all()
+
+    return train_step, loss, accuracy, merged_summary_op
+
+def train(x, keep_prob, y_, train_step, loss, accuracy,
+        merged_summary_op, saver):
+    print("training start...")
+
+    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
+
+    init_op = tf.global_variables_initializer()
+
+    with tf.Session() as sess:
+        sess.run(init_op)
+
+        tf.train.write_graph(sess.graph_def, 'out',
+            MODEL_NAME + '.pbtxt', True)
+
+        # op to write logs to Tensorboard
+        summary_writer = tf.summary.FileWriter('logs/',
+            graph=tf.get_default_graph())
+
+        for step in range(NUM_STEPS):
+            batch = mnist.train.next_batch(BATCH_SIZE)
+            if step % 100 == 0:
+                train_accuracy = accuracy.eval(feed_dict={
+                    x: batch[0], y_: batch[1], keep_prob: 1.0})
+                print('step %d, training accuracy %f' % (step, train_accuracy))
+            _, summary = sess.run([train_step, merged_summary_op],
+                feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
+            summary_writer.add_summary(summary, step)
+
+        saver.save(sess, 'out/' + MODEL_NAME + '.chkp')
+
+        test_accuracy = accuracy.eval(feed_dict={x: mnist.test.images,
+                                    y_: mnist.test.labels,
+                                    keep_prob: 1.0})
+        print('test accuracy %g' % test_accuracy)
+
+    print("training finished!")
+
+def export_model(input_node_names, output_node_name):
+    freeze_graph.freeze_graph('out/' + MODEL_NAME + '.pbtxt', None, False,
+        'out/' + MODEL_NAME + '.chkp', output_node_name, "save/restore_all",
+        "save/Const:0", 'out/frozen_' + MODEL_NAME + '.pb', True, "")
+
+    input_graph_def = tf.GraphDef()
+    with tf.gfile.Open('out/frozen_' + MODEL_NAME + '.pb', "rb") as f:
+        input_graph_def.ParseFromString(f.read())
+
+    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
+            input_graph_def, input_node_names, [output_node_name],
+            tf.float32.as_datatype_enum)
+
+    with tf.gfile.FastGFile('out/opt_' + MODEL_NAME + '.pb', "wb") as f:
+        f.write(output_graph_def.SerializeToString())
+
+    print("graph saved!")
+
+def main():
+    if not path.exists('out'):
+        os.mkdir('out')
+
+    input_node_name = 'input'
+    keep_prob_node_name = 'keep_prob'
+    output_node_name = 'output'
+
+    x, keep_prob, y_ = model_input(input_node_name, keep_prob_node_name)
+
+    train_step, loss, accuracy, merged_summary_op = build_model(x, keep_prob,
+        y_, output_node_name)
+    saver = tf.train.Saver()
+
+    train(x, keep_prob, y_, train_step, loss, accuracy,
+        merged_summary_op, saver)
+
+    export_model([input_node_name, keep_prob_node_name], output_node_name)
+
+if __name__ == '__main__':
+    main()
diff --git a/tensorflow_model/mnist_convnet_keras.py b/tensorflow_model/mnist_convnet_keras.py
new file mode 100644
index 0000000..8d1fdfb
--- /dev/null
+++ b/tensorflow_model/mnist_convnet_keras.py
@@ -0,0 +1,116 @@
+# Python 3.6.0
+# tensorflow 1.1.0
+# Keras 2.0.4
+
+import os
+import os.path as path
+
+import keras
+from keras.datasets import mnist
+from keras.models import Sequential
+from keras.layers import Input, Dense, Dropout, Flatten
+from keras.layers import Conv2D, MaxPooling2D
+from keras import backend as K
+
+import tensorflow as tf
+from tensorflow.python.tools import freeze_graph
+from tensorflow.python.tools import optimize_for_inference_lib
+
+MODEL_NAME = 'mnist_convnet'
+EPOCHS = 1
+BATCH_SIZE = 128
+
+
+def load_data():
+    (x_train, y_train), (x_test, y_test) = mnist.load_data()
+    x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
+    x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
+    x_train = x_train.astype('float32')
+    x_test = x_test.astype('float32')
+    x_train /= 255
+    x_test /= 255
+    y_train = keras.utils.to_categorical(y_train, 10)
+    y_test = keras.utils.to_categorical(y_test, 10)
+    return x_train, y_train, x_test, y_test
+
+
+def build_model():
+    model = Sequential()
+    model.add(Conv2D(filters=64, kernel_size=3, strides=1, \
+            padding='same', activation='relu', \
+            input_shape=[28, 28, 1]))
+    # 28*28*64
+    model.add(MaxPooling2D(pool_size=2, strides=2, padding='same'))
+    # 14*14*64
+
+    model.add(Conv2D(filters=128, kernel_size=3, strides=1, \
+            padding='same', activation='relu'))
+    # 14*14*128
+    model.add(MaxPooling2D(pool_size=2, strides=2, padding='same'))
+    # 7*7*128
+
+    model.add(Conv2D(filters=256, kernel_size=3, strides=1, \
+            padding='same', activation='relu'))
+    # 7*7*256
+    model.add(MaxPooling2D(pool_size=2, strides=2, padding='same'))
+    # 4*4*256
+
+    model.add(Flatten())
+    model.add(Dense(1024, activation='relu'))
+    #model.add(Dropout(0.5))
+    model.add(Dense(10, activation='softmax'))
+    return model
+
+
+def train(model, x_train, y_train, x_test, y_test):
+    model.compile(loss=keras.losses.categorical_crossentropy, \
+                  optimizer=keras.optimizers.Adadelta(), \
+                  metrics=['accuracy'])
+
+    model.fit(x_train, y_train, \
+              batch_size=BATCH_SIZE, \
+              epochs=EPOCHS, \
+              verbose=1, \
+              validation_data=(x_test, y_test))
+
+
+def export_model(saver, model, input_node_names, output_node_name):
+    tf.train.write_graph(K.get_session().graph_def, 'out', \
+        MODEL_NAME + '_graph.pbtxt')
+
+    saver.save(K.get_session(), 'out/' + MODEL_NAME + '.chkp')
+
+    freeze_graph.freeze_graph('out/' + MODEL_NAME + '_graph.pbtxt', None, \
+        False, 'out/' + MODEL_NAME + '.chkp', output_node_name, \
+        "save/restore_all", "save/Const:0", \
+        'out/frozen_' + MODEL_NAME + '.pb', True, "")
+
+    input_graph_def = tf.GraphDef()
+    with tf.gfile.Open('out/frozen_' + MODEL_NAME + '.pb', "rb") as f:
+        input_graph_def.ParseFromString(f.read())
+
+    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
+            input_graph_def, input_node_names, [output_node_name],
+            tf.float32.as_datatype_enum)
+
+    with tf.gfile.FastGFile('out/opt_' + MODEL_NAME + '.pb', "wb") as f:
+        f.write(output_graph_def.SerializeToString())
+
+    print("graph saved!")
+
+
+def main():
+    if not path.exists('out'):
+        os.mkdir('out')
+
+    x_train, y_train, x_test, y_test = load_data()
+
+    model = build_model()
+
+    train(model, x_train, y_train, x_test, y_test)
+
+    export_model(tf.train.Saver(), model, ["conv2d_1_input"], "dense_2/Softmax")
+
+
+if __name__ == '__main__':
+    main()