diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..146a55a --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +tensorflow_model/logs/ +tensorflow_model/MNIST_data/ +tensorflow_model/out/ + diff --git a/MnistAndroid/.gitignore b/MnistAndroid/.gitignore index 39fb081..f237317 100644 --- a/MnistAndroid/.gitignore +++ b/MnistAndroid/.gitignore @@ -7,3 +7,4 @@ /build /captures .externalNativeBuild +.idea/ diff --git a/MnistAndroid/.idea/compiler.xml b/MnistAndroid/.idea/compiler.xml deleted file mode 100644 index 96cc43e..0000000 --- a/MnistAndroid/.idea/compiler.xml +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/MnistAndroid/.idea/copyright/profiles_settings.xml b/MnistAndroid/.idea/copyright/profiles_settings.xml deleted file mode 100644 index e7bedf3..0000000 --- a/MnistAndroid/.idea/copyright/profiles_settings.xml +++ /dev/null @@ -1,3 +0,0 @@ - - - \ No newline at end of file diff --git a/MnistAndroid/.idea/gradle.xml b/MnistAndroid/.idea/gradle.xml deleted file mode 100644 index 7ac24c7..0000000 --- a/MnistAndroid/.idea/gradle.xml +++ /dev/null @@ -1,18 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/MnistAndroid/.idea/misc.xml b/MnistAndroid/.idea/misc.xml deleted file mode 100644 index 5d19981..0000000 --- a/MnistAndroid/.idea/misc.xml +++ /dev/null @@ -1,46 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/MnistAndroid/.idea/modules.xml b/MnistAndroid/.idea/modules.xml deleted file mode 100644 index 02e6fb4..0000000 --- a/MnistAndroid/.idea/modules.xml +++ /dev/null @@ -1,9 +0,0 @@ - - - - - - - - - \ No newline at end of file diff --git a/MnistAndroid/.idea/runConfigurations.xml b/MnistAndroid/.idea/runConfigurations.xml deleted file mode 100644 index 7f68460..0000000 --- a/MnistAndroid/.idea/runConfigurations.xml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/MnistAndroid/app/CMakeLists.txt b/MnistAndroid/app/CMakeLists.txt deleted file mode 100644 index f8e6e8b..0000000 --- a/MnistAndroid/app/CMakeLists.txt +++ /dev/null @@ -1,44 +0,0 @@ -# For more information about using CMake with Android Studio, read the -# documentation: https://d.android.com/studio/projects/add-native-code.html - -# Sets the minimum version of CMake required to build the native library. - -cmake_minimum_required(VERSION 3.4.1) - -# Creates and names a library, sets it as either STATIC -# or SHARED, and provides the relative paths to its source code. -# You can define multiple libraries, and CMake builds them for you. -# Gradle automatically packages shared libraries with your APK. - -add_library( # Sets the name of the library. - native-lib - - # Sets the library as a shared library. - SHARED - - # Provides a relative path to your source file(s). - src/main/cpp/native-lib.cpp ) - -# Searches for a specified prebuilt library and stores the path as a -# variable. Because CMake includes system libraries in the search path by -# default, you only need to specify the name of the public NDK library -# you want to add. CMake verifies that the library exists before -# completing its build. - -find_library( # Sets the name of the path variable. - log-lib - - # Specifies the name of the NDK library that - # you want CMake to locate. - log ) - -# Specifies libraries CMake should link to your target library. You -# can link multiple libraries, such as libraries you define in this -# build script, prebuilt third-party libraries, or system libraries. - -target_link_libraries( # Specifies the target library. - native-lib - - # Links the target library to the log library - # included in the NDK. - ${log-lib} ) \ No newline at end of file diff --git a/MnistAndroid/app/build.gradle b/MnistAndroid/app/build.gradle index a33663b..3992eef 100644 --- a/MnistAndroid/app/build.gradle +++ b/MnistAndroid/app/build.gradle @@ -9,12 +9,6 @@ android { targetSdkVersion 25 versionCode 1 versionName "1.0" - testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" - externalNativeBuild { - cmake { - cppFlags "" - } - } } buildTypes { release { @@ -22,20 +16,9 @@ android { proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' } } - externalNativeBuild { - cmake { - path "CMakeLists.txt" - } - } } dependencies { - compile fileTree(include: ['*.jar'], dir: 'libs') - androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { - exclude group: 'com.android.support', module: 'support-annotations' - }) compile 'com.android.support:appcompat-v7:25.3.1' - compile 'com.android.support.constraint:constraint-layout:1.0.2' - testCompile 'junit:junit:4.12' - compile files('libs/libandroid_tensorflow_inference_java.jar') + compile 'org.tensorflow:tensorflow-android:1.2.0-rc0' } diff --git a/MnistAndroid/app/libs/libandroid_tensorflow_inference_java.jar b/MnistAndroid/app/libs/libandroid_tensorflow_inference_java.jar deleted file mode 100644 index 3b8d93b..0000000 Binary files a/MnistAndroid/app/libs/libandroid_tensorflow_inference_java.jar and /dev/null differ diff --git a/MnistAndroid/app/src/androidTest/java/mariannelinhares/mnistandroid/ExampleInstrumentedTest.java b/MnistAndroid/app/src/androidTest/java/mariannelinhares/mnistandroid/ExampleInstrumentedTest.java deleted file mode 100644 index fb64682..0000000 --- a/MnistAndroid/app/src/androidTest/java/mariannelinhares/mnistandroid/ExampleInstrumentedTest.java +++ /dev/null @@ -1,26 +0,0 @@ -package mariannelinhares.mnistandroid; - -import android.content.Context; -import android.support.test.InstrumentationRegistry; -import android.support.test.runner.AndroidJUnit4; - -import org.junit.Test; -import org.junit.runner.RunWith; - -import static org.junit.Assert.*; - -/** - * Instrumentation test, which will execute on an Android device. - * - * @see Testing documentation - */ -@RunWith(AndroidJUnit4.class) -public class ExampleInstrumentedTest { - @Test - public void useAppContext() throws Exception { - // Context of the app under test. - Context appContext = InstrumentationRegistry.getTargetContext(); - - assertEquals("mariannelinhares.mnistandroid", appContext.getPackageName()); - } -} diff --git a/MnistAndroid/app/src/main/assets/expert-graph.pb b/MnistAndroid/app/src/main/assets/opt_mnist_convnet-keras.pb similarity index 52% rename from MnistAndroid/app/src/main/assets/expert-graph.pb rename to MnistAndroid/app/src/main/assets/opt_mnist_convnet-keras.pb index 3533274..80c7ecd 100644 Binary files a/MnistAndroid/app/src/main/assets/expert-graph.pb and b/MnistAndroid/app/src/main/assets/opt_mnist_convnet-keras.pb differ diff --git a/MnistAndroid/app/src/main/assets/graph.pb b/MnistAndroid/app/src/main/assets/opt_mnist_convnet-tf.pb similarity index 52% rename from MnistAndroid/app/src/main/assets/graph.pb rename to MnistAndroid/app/src/main/assets/opt_mnist_convnet-tf.pb index dcd1321..279adec 100644 Binary files a/MnistAndroid/app/src/main/assets/graph.pb and b/MnistAndroid/app/src/main/assets/opt_mnist_convnet-tf.pb differ diff --git a/MnistAndroid/app/src/main/cpp/native-lib.cpp b/MnistAndroid/app/src/main/cpp/native-lib.cpp deleted file mode 100644 index cbb9c07..0000000 --- a/MnistAndroid/app/src/main/cpp/native-lib.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include -#include - -extern "C" -JNIEXPORT jstring JNICALL -Java_mariannelinhares_mnistandroid_MainActivity_stringFromJNI( - JNIEnv* env, - jobject /* this */) { - std::string hello = "Hello from C++"; - return env->NewStringUTF(hello.c_str()); -} diff --git a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/MainActivity.java b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/MainActivity.java index eb086be..dd39deb 100644 --- a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/MainActivity.java +++ b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/MainActivity.java @@ -15,7 +15,9 @@ See the License for the specific language governing permissions and limitations under the License. - From: https://raw.githubusercontent.com/miyosuda/TensorFlowAndroidMNIST/master/app/src/main/java/jp/narr/tensorflowmnist/DrawModel.java + From: https://raw.githubusercontent + .com/miyosuda/TensorFlowAndroidMNIST/master/app/src/main/java/jp/narr/tensorflowmnist + /DrawModel.java */ import android.app.Activity; @@ -25,41 +27,32 @@ import android.view.View; import android.widget.Button; import android.widget.TextView; - -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; - +import java.util.ArrayList; +import java.util.List; +import mariannelinhares.mnistandroid.models.Classification; +import mariannelinhares.mnistandroid.models.Classifier; +import mariannelinhares.mnistandroid.models.TensorFlowClassifier; import mariannelinhares.mnistandroid.views.DrawModel; import mariannelinhares.mnistandroid.views.DrawView; /** * Changed by marianne-linhares on 21/04/17. - * https://raw.githubusercontent.com/miyosuda/TensorFlowAndroidMNIST/master/app/src/main/java/jp/narr/tensorflowmnist/DrawModel.java + * https://raw.githubusercontent.com/miyosuda/TensorFlowAndroidMNIST/master/app/src/main/java/jp + * /narr/tensorflowmnist/DrawModel.java */ public class MainActivity extends Activity implements View.OnClickListener, View.OnTouchListener { + private static final int PIXEL_WIDTH = 28; + // ui related private Button clearBtn, classBtn; private TextView resText; - - // tensorflow input and output - private static final int INPUT_SIZE = 28; - private static final String INPUT_NAME = "input"; - private static final String OUTPUT_NAME = "output"; - - private static final String MODEL_FILE = "file:///android_asset/expert-graph.pb"; - private static final String LABEL_FILE = "file:///android_asset/labels.txt"; - - private Classifier classifier; - - private Executor executor = Executors.newSingleThreadExecutor(); + private List mClassifiers = new ArrayList<>(); // views related private DrawModel drawModel; private DrawView drawView; - private static final int PIXEL_WIDTH = 28; - private PointF mTmpPiont = new PointF(); private float mLastX; @@ -71,97 +64,84 @@ protected void onCreate(Bundle savedInstanceState) { setContentView(R.layout.activity_main); //get drawing view - drawView = (DrawView)findViewById(R.id.draw); + drawView = (DrawView) findViewById(R.id.draw); drawModel = new DrawModel(PIXEL_WIDTH, PIXEL_WIDTH); drawView.setModel(drawModel); drawView.setOnTouchListener(this); //clear button - clearBtn = (Button)findViewById(R.id.btn_clear); + clearBtn = (Button) findViewById(R.id.btn_clear); clearBtn.setOnClickListener(this); //class button - classBtn = (Button)findViewById(R.id.btn_class); + classBtn = (Button) findViewById(R.id.btn_class); classBtn.setOnClickListener(this); // res text - resText = (TextView)findViewById(R.id.tfRes); + resText = (TextView) findViewById(R.id.tfRes); // tensorflow loadModel(); } + @Override + protected void onResume() { + drawView.onResume(); + super.onResume(); + } + + @Override + protected void onPause() { + drawView.onPause(); + super.onPause(); + } + private void loadModel() { - executor.execute(new Runnable() { + new Thread(new Runnable() { @Override public void run() { try { - classifier = Classifier.create(getApplicationContext().getAssets(), - MODEL_FILE, - LABEL_FILE, - INPUT_SIZE, - INPUT_NAME, - OUTPUT_NAME); + mClassifiers.add( + TensorFlowClassifier.create(getAssets(), "TensorFlow", + "opt_mnist_convnet-tf.pb", "labels.txt", PIXEL_WIDTH, + "input", "output", true)); + mClassifiers.add( + TensorFlowClassifier.create(getAssets(), "Keras", + "opt_mnist_convnet-keras.pb", "labels.txt", PIXEL_WIDTH, + "conv2d_1_input", "dense_2/Softmax", false)); } catch (final Exception e) { - throw new RuntimeException("Error initializing TensorFlow!", e); + throw new RuntimeException("Error initializing classifiers!", e); } } - }); - } - - /** - * A native method that is implemented by the 'native-lib' native library, - * which is packaged with this application. - */ - public native String stringFromJNI(); - - // Used to load the 'native-lib' library on application startup. - static { - System.loadLibrary("native-lib"); + }).start(); } - @Override - public void onClick(View view){ - - if(view.getId() == R.id.btn_clear) { + public void onClick(View view) { + if (view.getId() == R.id.btn_clear) { drawModel.clear(); drawView.reset(); drawView.invalidate(); - resText.setText("Result: "); - } - else if(view.getId() == R.id.btn_class){ - + resText.setText(""); + } else if (view.getId() == R.id.btn_class) { float pixels[] = drawView.getPixelData(); - final Classification res = classifier.recognize(pixels); - String result = "Result: "; - if (res.getLabel() == null) { - resText.setText(result + "?"); - } - else { - result += res.getLabel(); - result += "\nwith probability: " + res.getConf(); - resText.setText(result); + String text = ""; + for (Classifier classifier : mClassifiers) { + final Classification res = classifier.recognize(pixels); + if (res.getLabel() == null) { + text += classifier.name() + ": ?\n"; + } else { + text += String.format("%s: %s, %f\n", classifier.name(), res.getLabel(), + res.getConf()); + } } + resText.setText(text); } } - @Override - protected void onResume() { - drawView.onResume(); - super.onResume(); - } - - @Override - protected void onPause() { - drawView.onPause(); - super.onPause(); - } - - @Override public boolean onTouch(View v, MotionEvent event) { int action = event.getAction() & MotionEvent.ACTION_MASK; @@ -169,11 +149,9 @@ public boolean onTouch(View v, MotionEvent event) { if (action == MotionEvent.ACTION_DOWN) { processTouchDown(event); return true; - } else if (action == MotionEvent.ACTION_MOVE) { processTouchMove(event); return true; - } else if (action == MotionEvent.ACTION_UP) { processTouchUp(); return true; diff --git a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classification.java b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classification.java similarity index 58% rename from MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classification.java rename to MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classification.java index 95bfb1f..5ba6c0b 100644 --- a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classification.java +++ b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classification.java @@ -1,4 +1,4 @@ -package mariannelinhares.mnistandroid; +package mariannelinhares.mnistandroid.models; /** * Created by marianne-linhares on 20/04/17. @@ -9,16 +9,12 @@ public class Classification { private float conf; private String label; - public Classification(float conf, String label) { - update(conf, label); - } - - public Classification() { - this.conf = (float)-1.0; + Classification() { + this.conf = -1.0F; this.label = null; } - public void update(float conf, String label) { + void update(float conf, String label) { this.conf = conf; this.label = label; } @@ -30,5 +26,4 @@ public String getLabel() { public float getConf() { return conf; } - } diff --git a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classifier.java b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classifier.java new file mode 100644 index 0000000..fd551d1 --- /dev/null +++ b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/Classifier.java @@ -0,0 +1,11 @@ +package mariannelinhares.mnistandroid.models; + +/** + * Created by Piasy{github.com/Piasy} on 29/05/2017. + */ + +public interface Classifier { + String name(); + + Classification recognize(final float[] pixels); +} diff --git a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classifier.java b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/TensorFlowClassifier.java similarity index 55% rename from MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classifier.java rename to MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/TensorFlowClassifier.java index f9b34ed..a908f67 100644 --- a/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/Classifier.java +++ b/MnistAndroid/app/src/main/java/mariannelinhares/mnistandroid/models/TensorFlowClassifier.java @@ -1,38 +1,38 @@ -package mariannelinhares.mnistandroid; +package mariannelinhares.mnistandroid.models; import android.content.res.AssetManager; - import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.List; - import org.tensorflow.contrib.android.TensorFlowInferenceInterface; /** - * Changed from https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample/blob/master/app/src/main/java/com/mindorks/tensorflowexample/TensorFlowImageClassifier.java + * Changed from https://github.com/MindorksOpenSource/AndroidTensorFlowMNISTExample/blob/master + * /app/src/main/java/com/mindorks/tensorflowexample/TensorFlowImageClassifier.java * Created by marianne-linhares on 20/04/17. */ -public class Classifier { +public class TensorFlowClassifier implements Classifier { // Only returns if at least this confidence private static final float THRESHOLD = 0.1f; private TensorFlowInferenceInterface tfHelper; + private String name; private String inputName; private String outputName; private int inputSize; + private boolean feedKeepProb; private List labels; private float[] output; private String[] outputNames; - static private List readLabels(Classifier c, AssetManager am, String fileName) throws IOException { - BufferedReader br = null; - br = new BufferedReader(new InputStreamReader(am.open(fileName))); + private static List readLabels(AssetManager am, String fileName) throws IOException { + BufferedReader br = new BufferedReader(new InputStreamReader(am.open(fileName))); String line; List labels = new ArrayList<>(); @@ -44,44 +44,49 @@ static private List readLabels(Classifier c, AssetManager am, String fil return labels; } + public static TensorFlowClassifier create(AssetManager assetManager, String name, + String modelPath, String labelFile, int inputSize, String inputName, String outputName, + boolean feedKeepProb) throws IOException { + TensorFlowClassifier c = new TensorFlowClassifier(); - static public Classifier create(AssetManager assetManager, String modelPath, String labelPath, - int inputSize, String inputName, String outputName) - throws IOException { - - Classifier c = new Classifier(); + c.name = name; c.inputName = inputName; c.outputName = outputName; - // Read labels - String labelFile = labelPath.split("file:///android_asset/")[1]; - c.labels = readLabels(c, assetManager, labelFile); - - c.tfHelper = new TensorFlowInferenceInterface(); - if (c.tfHelper.initializeTensorFlow(assetManager, modelPath) != 0) { - throw new RuntimeException("TF initialization failed"); - } + c.labels = readLabels(assetManager, labelFile); + c.tfHelper = new TensorFlowInferenceInterface(assetManager, modelPath); int numClasses = 10; c.inputSize = inputSize; // Pre-allocate buffer. - c.outputNames = new String[]{ outputName }; + c.outputNames = new String[] { outputName }; c.outputName = outputName; c.output = new float[numClasses]; + c.feedKeepProb = feedKeepProb; + return c; } + @Override + public String name() { + return name; + } + + @Override public Classification recognize(final float[] pixels) { - tfHelper.fillNodeFloat(inputName, new int[]{inputSize * inputSize}, pixels); - tfHelper.runInference(outputNames); + tfHelper.feed(inputName, pixels, 1, inputSize, inputSize, 1); + if (feedKeepProb) { + tfHelper.feed("keep_prob", new float[] { 1 }); + } + tfHelper.run(outputNames); - tfHelper.readNodeFloat(outputName, output); + tfHelper.fetch(outputName, output); // Find the best classification Classification ans = new Classification(); diff --git a/MnistAndroid/app/src/main/jniLibs/armeabi-v7a/libtensorflow_inference.so b/MnistAndroid/app/src/main/jniLibs/armeabi-v7a/libtensorflow_inference.so deleted file mode 100644 index 9390465..0000000 Binary files a/MnistAndroid/app/src/main/jniLibs/armeabi-v7a/libtensorflow_inference.so and /dev/null differ diff --git a/MnistAndroid/app/src/main/jniLibs/x86/libtensorflow_mnist.so b/MnistAndroid/app/src/main/jniLibs/x86/libtensorflow_mnist.so deleted file mode 100644 index d4572f3..0000000 Binary files a/MnistAndroid/app/src/main/jniLibs/x86/libtensorflow_mnist.so and /dev/null differ diff --git a/MnistAndroid/app/src/main/res/layout/activity_main.xml b/MnistAndroid/app/src/main/res/layout/activity_main.xml index e07fe80..b3f727e 100644 --- a/MnistAndroid/app/src/main/res/layout/activity_main.xml +++ b/MnistAndroid/app/src/main/res/layout/activity_main.xml @@ -1,47 +1,50 @@ + xmlns:android="http://schemas.android.com/apk/res/android" + xmlns:tools="http://schemas.android.com/tools" + android:layout_width="match_parent" + android:layout_height="match_parent" + android:orientation="vertical" + android:paddingBottom="@dimen/activity_vertical_margin" + android:paddingLeft="@dimen/activity_horizontal_margin" + android:paddingRight="@dimen/activity_horizontal_margin" + android:paddingTop="@dimen/activity_vertical_margin" + tools:context="mariannelinhares.mnistandroid.MainActivity" + > + android:id="@+id/draw" + android:layout_width="match_parent" + android:layout_height="0dp" + android:layout_weight="1" + /> + android:layout_width="match_parent" + android:layout_height="wrap_content" + android:orientation="horizontal" + >