Skip to content

Commit 8868d8d

Browse files
authored
feat(ml): Adding Firebase ML support for AutoML models (#489)
Added support for AutoML models. RELEASE NOTES: Added support for creating, updating, getting, listing, publishing, unpublishing, and deleting Firebase-hosted custom ML models created with AutoML.
1 parent 9acaff9 commit 8868d8d

File tree

4 files changed

+213
-40
lines changed

4 files changed

+213
-40
lines changed

firebase_admin/ml.py

+59-14
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@
5353
_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,32}$')
5454
_GCS_TFLITE_URI_PATTERN = re.compile(
5555
r'^gs://(?P<bucket_name>[a-z0-9_.-]{3,63})/(?P<blob_name>.+)$')
56+
_AUTO_ML_MODEL_PATTERN = re.compile(
57+
r'^projects/(?P<project_id>[a-z0-9-]{6,30})/locations/(?P<location_id>[^/]+)/' +
58+
r'models/(?P<model_id>[A-Za-z0-9]+)$')
5659
_RESOURCE_NAME_PATTERN = re.compile(
5760
r'^projects/(?P<project_id>[a-z0-9-]{6,30})/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
5861
_OPERATION_NAME_PATTERN = re.compile(
@@ -75,7 +78,7 @@ def _get_ml_service(app):
7578

7679

7780
def create_model(model, app=None):
78-
"""Creates a model in Firebase ML.
81+
"""Creates a model in the current Firebase project.
7982
8083
Args:
8184
model: An ml.Model to create.
@@ -89,7 +92,7 @@ def create_model(model, app=None):
8992

9093

9194
def update_model(model, app=None):
92-
"""Updates a model in Firebase ML.
95+
"""Updates a model's metadata or model file.
9396
9497
Args:
9598
model: The ml.Model to update.
@@ -103,7 +106,9 @@ def update_model(model, app=None):
103106

104107

105108
def publish_model(model_id, app=None):
106-
"""Publishes a model in Firebase ML.
109+
"""Publishes a Firebase ML model.
110+
111+
A published model can be downloaded to client apps.
107112
108113
Args:
109114
model_id: The id of the model to publish.
@@ -117,7 +122,7 @@ def publish_model(model_id, app=None):
117122

118123

119124
def unpublish_model(model_id, app=None):
120-
"""Unpublishes a model in Firebase ML.
125+
"""Unpublishes a Firebase ML model.
121126
122127
Args:
123128
model_id: The id of the model to unpublish.
@@ -131,7 +136,7 @@ def unpublish_model(model_id, app=None):
131136

132137

133138
def get_model(model_id, app=None):
134-
"""Gets a model from Firebase ML.
139+
"""Gets the model specified by the given ID.
135140
136141
Args:
137142
model_id: The id of the model to get.
@@ -145,7 +150,7 @@ def get_model(model_id, app=None):
145150

146151

147152
def list_models(list_filter=None, page_size=None, page_token=None, app=None):
148-
"""Lists models from Firebase ML.
153+
"""Lists the current project's models.
149154
150155
Args:
151156
list_filter: a list filter string such as ``tags:'tag_1'``. None will return all models.
@@ -164,7 +169,7 @@ def list_models(list_filter=None, page_size=None, page_token=None, app=None):
164169

165170

166171
def delete_model(model_id, app=None):
167-
"""Deletes a model from Firebase ML.
172+
"""Deletes a model from the current project.
168173
169174
Args:
170175
model_id: The id of the model you wish to delete.
@@ -363,15 +368,10 @@ def __init__(self, model_source=None):
363368
def from_dict(cls, data):
364369
"""Create an instance of the object from a dict."""
365370
data_copy = dict(data)
366-
model_source = None
367-
gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None)
368-
if gcs_tflite_uri:
369-
model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri)
370-
tflite_format = TFLiteFormat(model_source=model_source)
371+
tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy))
371372
tflite_format._data = data_copy # pylint: disable=protected-access
372373
return tflite_format
373374

374-
375375
def __eq__(self, other):
376376
if isinstance(other, self.__class__):
377377
# pylint: disable=protected-access
@@ -381,6 +381,16 @@ def __eq__(self, other):
381381
def __ne__(self, other):
382382
return not self.__eq__(other)
383383

384+
@staticmethod
def _init_model_source(data):
    """Builds the model source described by a TFLite format dict.

    The recognised source key is popped from ``data``, so the caller's dict
    is mutated. A GCS URI takes precedence; the AutoML key is only consulted
    (and popped) when no truthy GCS URI is present.

    Args:
        data: Mutable dict holding the serialized TFLite format fields.

    Returns:
        A TFLiteGCSModelSource, a TFLiteAutoMlSource, or None when neither
        source key carries a truthy value.
    """
    gcs_uri = data.pop('gcsTfliteUri', None)
    if gcs_uri:
        return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri)

    automl_ref = data.pop('automlModel', None)
    return TFLiteAutoMlSource(auto_ml_model=automl_ref) if automl_ref else None
393+
384394
@property
385395
def model_source(self):
386396
"""The TF Lite model's location."""
@@ -593,8 +603,38 @@ def as_dict(self, for_upload=False):
593603
return {'gcsTfliteUri': self._gcs_tflite_uri}
594604

595605

606+
class TFLiteAutoMlSource(TFLiteModelSource):
    """TFLite model source representing a tflite model created with AutoML.

    Args:
        auto_ml_model: Resource name of the AutoML model. It is validated on
            assignment and must look like
            ``projects/{project_id}/locations/{location_id}/models/{model_id}``.
        app: Optional Firebase app instance; it is only stored on the source.

    Raises:
        ValueError: If ``auto_ml_model`` is not a valid model resource name.
    """

    def __init__(self, auto_ml_model, app=None):
        self._app = app
        # Assign through the property so the resource name gets validated.
        self.auto_ml_model = auto_ml_model

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.auto_ml_model == other.auto_ml_model

    def __ne__(self, other):
        is_equal = self.__eq__(other)
        return not is_equal

    @property
    def auto_ml_model(self):
        """Resource name of the model, created by the AutoML API or Cloud console."""
        return self._auto_ml_model

    @auto_ml_model.setter
    def auto_ml_model(self, auto_ml_model):
        validated_name = _validate_auto_ml_model(auto_ml_model)
        self._auto_ml_model = validated_name

    def as_dict(self, for_upload=False):
        """Returns a serializable representation of the object."""
        # ``for_upload`` is accepted but ignored: upload is irrelevant for
        # auto_ml models, so the output is the same either way.
        serialized = {'automlModel': self._auto_ml_model}
        return serialized
634+
635+
596636
class ListModelsPage:
597-
"""Represents a page of models in a firebase project.
637+
"""Represents a page of models in a Firebase project.
598638
599639
Provides methods for traversing the models included in this page, as well as
600640
retrieving subsequent pages of models. The iterator returned by
@@ -740,6 +780,11 @@ def _validate_gcs_tflite_uri(uri):
740780
raise ValueError('GCS TFLite URI format is invalid.')
741781
return uri
742782

783+
def _validate_auto_ml_model(model):
    """Checks that ``model`` is a well-formed AutoML model resource name.

    Args:
        model: The resource name string to check against
            ``_AUTO_ML_MODEL_PATTERN``.

    Returns:
        The same ``model`` value, unchanged, when it matches the pattern.

    Raises:
        ValueError: If ``model`` does not match the expected format.
    """
    if _AUTO_ML_MODEL_PATTERN.match(model) is None:
        raise ValueError('Model resource name format is invalid.')
    return model
787+
743788

744789
def _validate_model_format(model_format):
745790
if not isinstance(model_format, ModelFormat):

integration/test_ml.py

+88-24
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import pytest
2424

25+
import firebase_admin
2526
from firebase_admin import exceptions
2627
from firebase_admin import ml
2728
from tests import testutils
@@ -34,6 +35,11 @@
3435
except ImportError:
3536
_TF_ENABLED = False
3637

38+
try:
39+
from google.cloud import automl_v1
40+
_AUTOML_ENABLED = True
41+
except ImportError:
42+
_AUTOML_ENABLED = False
3743

3844
def _random_identifier(prefix):
3945
#pylint: disable=unused-variable
@@ -62,7 +68,6 @@ def _random_identifier(prefix):
6268
'file_name': 'invalid_model.tflite'
6369
}
6470

65-
6671
@pytest.fixture
6772
def firebase_model(request):
6873
args = request.param
@@ -101,6 +106,7 @@ def _clean_up_model(model):
101106
try:
102107
# Try to delete the model.
103108
# Some tests delete the model as part of the test.
109+
model.wait_for_unlocked()
104110
ml.delete_model(model.model_id)
105111
except exceptions.NotFoundError:
106112
pass
@@ -132,35 +138,45 @@ def check_model(model, args):
132138
assert model.locked is False
133139
assert model.etag is not None
134140

141+
# Model Format Checks
135142

136-
def check_model_format(model, has_model_format=False, validation_error=None):
137-
if has_model_format:
138-
assert model.validation_error == validation_error
139-
assert model.published is False
140-
assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://')
141-
if validation_error:
142-
assert model.model_format.size_bytes is None
143-
assert model.model_hash is None
144-
else:
145-
assert model.model_format.size_bytes is not None
146-
assert model.model_hash is not None
147-
else:
148-
assert model.model_format is None
149-
assert model.validation_error == 'No model file has been uploaded.'
150-
assert model.published is False
143+
def check_no_model_format(model):
    """Asserts that ``model`` has no model file attached.

    Such a model has no format, reports the "no model file" validation
    error, is unpublished, and has no model hash.
    """
    assert model.model_format is None
    expected_error = 'No model file has been uploaded.'
    assert model.validation_error == expected_error
    assert model.published is False
    assert model.model_hash is None
148+
149+
150+
def check_tflite_gcs_format(model, validation_error=None):
    """Asserts that ``model`` carries a GCS-backed TFLite format.

    Args:
        model: The model under inspection.
        validation_error: The validation error the model is expected to
            report. When truthy, the size and hash must be absent; otherwise
            both must be populated.
    """
    assert model.validation_error == validation_error
    assert model.published is False
    tflite_format = model.model_format
    assert tflite_format.model_source.gcs_tflite_uri.startswith('gs://')

    if not validation_error:
        # A valid model file has been processed, so size and hash exist.
        assert model.model_format.size_bytes is not None
        assert model.model_hash is not None
        return

    # An invalid model file yields neither a size nor a hash.
    assert model.model_format.size_bytes is None
    assert model.model_hash is None
160+
161+
162+
def check_tflite_automl_format(model):
    """Asserts that ``model`` carries an AutoML-backed TFLite format."""
    # Automl models don't have validation errors since they are references
    # to valid automl models.
    assert model.validation_error is None
    assert model.published is False
    automl_source = model.model_format.model_source
    assert automl_source.auto_ml_model.startswith('projects/')
152168

153169

154170
@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
155171
def test_create_simple_model(firebase_model):
156172
check_model(firebase_model, NAME_AND_TAGS_ARGS)
157-
check_model_format(firebase_model)
173+
check_no_model_format(firebase_model)
158174

159175

160176
@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
161177
def test_create_full_model(firebase_model):
162178
check_model(firebase_model, FULL_MODEL_ARGS)
163-
check_model_format(firebase_model, True)
179+
check_tflite_gcs_format(firebase_model)
164180

165181

166182
@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True)
@@ -175,14 +191,14 @@ def test_create_already_existing_fails(firebase_model):
175191
@pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True)
176192
def test_create_invalid_model(firebase_model):
177193
check_model(firebase_model, INVALID_FULL_MODEL_ARGS)
178-
check_model_format(firebase_model, True, 'Invalid flatbuffer format')
194+
check_tflite_gcs_format(firebase_model, 'Invalid flatbuffer format')
179195

180196

181197
@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
182198
def test_get_model(firebase_model):
183199
get_model = ml.get_model(firebase_model.model_id)
184200
check_model(get_model, NAME_AND_TAGS_ARGS)
185-
check_model_format(get_model)
201+
check_no_model_format(get_model)
186202

187203

188204
@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
@@ -201,12 +217,12 @@ def test_update_model(firebase_model):
201217
firebase_model.display_name = new_model_name
202218
updated_model = ml.update_model(firebase_model)
203219
check_model(updated_model, NAME_ONLY_ARGS_UPDATED)
204-
check_model_format(updated_model)
220+
check_no_model_format(updated_model)
205221

206222
# Second call with same model does not cause error
207223
updated_model2 = ml.update_model(updated_model)
208224
check_model(updated_model2, NAME_ONLY_ARGS_UPDATED)
209-
check_model_format(updated_model2)
225+
check_no_model_format(updated_model2)
210226

211227

212228
@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True)
@@ -290,7 +306,7 @@ def test_delete_model(firebase_model):
290306

291307
# Test TensorFlow conversion functions if TensorFlow is enabled.
292308
# Run 'pip install tensorflow' in the environment if you want _TF_ENABLED = True
293-
#'pip install tensorflow==2.0.0b' for version 2 etc.
309+
#'pip install tensorflow==2.2.0' for version 2.2.0 etc.
294310

295311

296312
def _clean_up_directory(save_dir):
@@ -334,6 +350,7 @@ def saved_model_dir(keras_model):
334350
_clean_up_directory(parent)
335351

336352

353+
337354
@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.')
338355
def test_from_keras_model(keras_model):
339356
source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite')
@@ -348,7 +365,7 @@ def test_from_keras_model(keras_model):
348365

349366
try:
350367
check_model(created_model, {'display_name': model.display_name})
351-
check_model_format(created_model, True)
368+
check_tflite_gcs_format(created_model)
352369
finally:
353370
_clean_up_model(created_model)
354371

@@ -371,3 +388,50 @@ def test_from_saved_model(saved_model_dir):
371388
assert created_model.validation_error is None
372389
finally:
373390
_clean_up_model(created_model)
391+
392+
393+
# Test AutoML functionality if AutoML is enabled.
394+
# Run 'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True
395+
# You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the
396+
# successful test. (Test is skipped otherwise)
397+
398+
@pytest.fixture
def automl_model():
    """Yields a Firebase ML model that references a predefined AutoML model.

    The AutoML model is looked up by display name in ``us-central1`` of the
    default app's project. The test is skipped when none is found. The
    Firebase model created here is cleaned up after the test.
    """
    assert _AUTOML_ENABLED

    # It takes > 20 minutes to train a model, so we expect a predefined AutoMl
    # model named 'admin_sdk_integ_test1' to exist in the project, or we skip
    # the test.
    client = automl_v1.AutoMlClient()
    project_id = firebase_admin.get_app().project_id
    location = client.location_path(project_id, 'us-central1')
    candidates = client.list_models(location, filter_="display_name=admin_sdk_integ_test1")

    # Expecting exactly one. (Ok to use last one if somehow more than 1)
    automl_ref = None
    for candidate in candidates:
        automl_ref = candidate.name

    # Skip if no pre-defined model. (It takes min > 20 minutes to train a model)
    if automl_ref is None:
        pytest.skip("No pre-existing AutoML model found. Skipping test")

    automl_source = ml.TFLiteAutoMlSource(automl_ref)
    automl_format = ml.TFLiteFormat(model_source=automl_source)
    pending_model = ml.Model(
        display_name=_random_identifier('TestModel_automl_'),
        tags=['test_automl'],
        model_format=automl_format)
    created_model = ml.create_model(model=pending_model)
    yield created_model
    _clean_up_model(created_model)
427+
428+
@pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.')
def test_automl_model(automl_model):
    """Checks a Firebase ML model created from a predefined AutoML model.

    This test looks for a predefined automl model with
    display_name = 'admin_sdk_integ_test1' (see the ``automl_model`` fixture).
    """
    # Wait for the model to be unlocked before inspecting it.
    automl_model.wait_for_unlocked()

    expected_args = {
        'display_name': automl_model.display_name,
        'tags': ['test_automl'],
    }
    check_model(automl_model, expected_args)
    check_tflite_automl_format(automl_model)

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ pytest-localserver >= 0.4.1
77
cachecontrol >= 0.12.6
88
google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != 'PyPy'
99
google-api-python-client >= 1.7.8
10+
google-auth == 1.18.0 # temporary workaround
1011
google-cloud-firestore >= 1.4.0; platform.python_implementation != 'PyPy'
1112
google-cloud-storage >= 1.18.0

0 commit comments

Comments
 (0)