Skip to content

Commit 480f8f6

Browse files
ecalubaquibtensorflower-gardener
authored andcommitted
Update tensorflow_models usage of tf.lite.interpreter to run ai-edge-litert.interpreter
PiperOrigin-RevId: 681625530
1 parent c651073 commit 480f8f6

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

official/projects/edgetpu/vision/serving/tflite_imagenet_evaluator.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
import numpy as np
2222
import tensorflow as tf, tf_keras
2323

24+
# pylint: disable=g-direct-tensorflow-import
25+
from tensorflow.lite.python import interpreter as tfl_interpreter
26+
# pylint: enable=g-direct-tensorflow-import
27+
2428

2529
@dataclasses.dataclass
2630
class EvaluationInput():
@@ -58,8 +62,9 @@ def evaluate_single_image(self, eval_input: EvaluationInput) -> bool:
5862
Returns:
5963
Whether the estimation is correct.
6064
"""
61-
interpreter = tf.lite.Interpreter(
62-
model_content=self._model_content, num_threads=1)
65+
interpreter = tfl_interpreter.Interpreter(
66+
model_content=self._model_content, num_threads=1
67+
)
6368
interpreter.allocate_tensors()
6469
# Get input and output tensors and quantization details.
6570
input_details = interpreter.get_input_details()

official/projects/movinet/tools/export_saved_model_test.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@
1313
# limitations under the License.
1414

1515
"""Tests for export_saved_model."""
16-
1716
from absl import flags
1817
import tensorflow as tf, tf_keras
1918
import tensorflow_hub as hub
2019

20+
# pylint: disable=g-direct-tensorflow-import
21+
from tensorflow.lite.python import interpreter as tfl_interpreter
22+
# pylint: enable=g-direct-tensorflow-import
2123
from official.projects.movinet.tools import export_saved_model
2224

25+
2326
FLAGS = flags.FLAGS
2427

2528

@@ -120,7 +123,7 @@ def test_movinet_export_a0_stream_with_tflite(self):
120123
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
121124
tflite_model = converter.convert()
122125

123-
interpreter = tf.lite.Interpreter(model_content=tflite_model)
126+
interpreter = tfl_interpreter.Interpreter(model_content=tflite_model)
124127
runner = interpreter.get_signature_runner('serving_default')
125128

126129
def state_name(name: str) -> str:

official/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ sentencepiece
2727
sacrebleu
2828
# Projects/vit dependencies
2929
immutabledict
30+
ai-edge-litert>=1.0.1

0 commit comments

Comments
 (0)