18
18
19
19
import android .content .res .AssetManager ;
20
20
import android .graphics .Bitmap ;
21
- import android .os .Trace ;
21
+ import android .support . v4 . os .TraceCompat ;
22
22
import android .util .Log ;
23
23
24
24
import org .tensorflow .contrib .android .TensorFlowInferenceInterface ;
41
41
*/
42
42
public class TensorFlowImageClassifier implements Classifier {
43
43
44
- private static final String TAG = "TensorFlowImageClassifier " ;
44
+ private static final String TAG = "ImageClassifier " ;
45
45
46
46
// Only return this many results with at least this confidence.
47
47
private static final int MAX_RESULTS = 3 ;
@@ -63,6 +63,8 @@ public class TensorFlowImageClassifier implements Classifier {
63
63
64
64
private TensorFlowInferenceInterface inferenceInterface ;
65
65
66
+ private boolean runStats = false ;
67
+
66
68
private TensorFlowImageClassifier () {
67
69
}
68
70
@@ -105,10 +107,7 @@ public static Classifier create(
105
107
}
106
108
br .close ();
107
109
108
- c .inferenceInterface = new TensorFlowInferenceInterface ();
109
- if (c .inferenceInterface .initializeTensorFlow (assetManager , modelFilename ) != 0 ) {
110
- throw new RuntimeException ("TF initialization failed" );
111
- }
110
+ c .inferenceInterface = new TensorFlowInferenceInterface (assetManager , modelFilename );
112
111
// The shape of the output is [N, NUM_CLASSES], where N is the batch size.
113
112
int numClasses =
114
113
(int ) c .inferenceInterface .graph ().operation (outputName ).output (0 ).shape ().size (1 );
@@ -133,9 +132,9 @@ public static Classifier create(
133
132
@ Override
134
133
public List <Recognition > recognizeImage (final Bitmap bitmap ) {
135
134
// Log this method so that it can be analyzed with systrace.
136
- Trace .beginSection ("recognizeImage" );
135
+ TraceCompat .beginSection ("recognizeImage" );
137
136
138
- Trace .beginSection ("preprocessBitmap" );
137
+ TraceCompat .beginSection ("preprocessBitmap" );
139
138
// Preprocess the image data from 0-255 int to normalized float based
140
139
// on the provided parameters.
141
140
bitmap .getPixels (intValues , 0 , bitmap .getWidth (), 0 , 0 , bitmap .getWidth (), bitmap .getHeight ());
@@ -145,23 +144,23 @@ public List<Recognition> recognizeImage(final Bitmap bitmap) {
145
144
floatValues [i * 3 + 1 ] = (((val >> 8 ) & 0xFF ) - imageMean ) / imageStd ;
146
145
floatValues [i * 3 + 2 ] = ((val & 0xFF ) - imageMean ) / imageStd ;
147
146
}
148
- Trace .endSection ();
147
+ TraceCompat .endSection ();
149
148
150
149
// Copy the input data into TensorFlow.
151
- Trace .beginSection ("fillNodeFloat " );
152
- inferenceInterface .fillNodeFloat (
153
- inputName , new int []{1 , inputSize , inputSize , 3 }, floatValues );
154
- Trace .endSection ();
150
+ TraceCompat .beginSection ("feed " );
151
+ inferenceInterface .feed (
152
+ inputName , floatValues , new long []{1 , inputSize , inputSize , 3 });
153
+ TraceCompat .endSection ();
155
154
156
155
// Run the inference call.
157
- Trace .beginSection ("runInference " );
158
- inferenceInterface .runInference (outputNames );
159
- Trace .endSection ();
156
+ TraceCompat .beginSection ("run " );
157
+ inferenceInterface .run (outputNames , runStats );
158
+ TraceCompat .endSection ();
160
159
161
160
// Copy the output Tensor back into the output array.
162
- Trace .beginSection ("readNodeFloat " );
163
- inferenceInterface .readNodeFloat (outputName , outputs );
164
- Trace .endSection ();
161
+ TraceCompat .beginSection ("fetch " );
162
+ inferenceInterface .fetch (outputName , outputs );
163
+ TraceCompat .endSection ();
165
164
166
165
// Find the best classifications.
167
166
PriorityQueue <Recognition > pq =
@@ -186,13 +185,13 @@ public int compare(Recognition lhs, Recognition rhs) {
186
185
for (int i = 0 ; i < recognitionsSize ; ++i ) {
187
186
recognitions .add (pq .poll ());
188
187
}
189
- Trace .endSection (); // "recognizeImage"
188
+ TraceCompat .endSection (); // "recognizeImage"
190
189
return recognitions ;
191
190
}
192
191
193
192
@ Override
194
193
public void enableStatLogging (boolean debug ) {
195
- inferenceInterface . enableStatLogging ( debug ) ;
194
+ runStats = debug ;
196
195
}
197
196
198
197
@ Override
0 commit comments