22
22
23
23
import pytest
24
24
25
+ import firebase_admin
25
26
from firebase_admin import exceptions
26
27
from firebase_admin import ml
27
28
from tests import testutils
34
35
except ImportError :
35
36
_TF_ENABLED = False
36
37
38
+ try :
39
+ from google .cloud import automl_v1
40
+ _AUTOML_ENABLED = True
41
+ except ImportError :
42
+ _AUTOML_ENABLED = False
37
43
38
44
def _random_identifier (prefix ):
39
45
#pylint: disable=unused-variable
@@ -62,7 +68,6 @@ def _random_identifier(prefix):
62
68
'file_name' : 'invalid_model.tflite'
63
69
}
64
70
65
-
66
71
@pytest .fixture
67
72
def firebase_model (request ):
68
73
args = request .param
@@ -101,6 +106,7 @@ def _clean_up_model(model):
101
106
try :
102
107
# Try to delete the model.
103
108
# Some tests delete the model as part of the test.
109
+ model .wait_for_unlocked ()
104
110
ml .delete_model (model .model_id )
105
111
except exceptions .NotFoundError :
106
112
pass
@@ -132,35 +138,45 @@ def check_model(model, args):
132
138
assert model .locked is False
133
139
assert model .etag is not None
134
140
141
+ # Model Format Checks
135
142
136
- def check_model_format (model , has_model_format = False , validation_error = None ):
137
- if has_model_format :
138
- assert model .validation_error == validation_error
139
- assert model .published is False
140
- assert model .model_format .model_source .gcs_tflite_uri .startswith ('gs://' )
141
- if validation_error :
142
- assert model .model_format .size_bytes is None
143
- assert model .model_hash is None
144
- else :
145
- assert model .model_format .size_bytes is not None
146
- assert model .model_hash is not None
147
- else :
148
- assert model .model_format is None
149
- assert model .validation_error == 'No model file has been uploaded.'
150
- assert model .published is False
143
+ def check_no_model_format (model ):
144
+ assert model .model_format is None
145
+ assert model .validation_error == 'No model file has been uploaded.'
146
+ assert model .published is False
147
+ assert model .model_hash is None
148
+
149
+
150
+ def check_tflite_gcs_format (model , validation_error = None ):
151
+ assert model .validation_error == validation_error
152
+ assert model .published is False
153
+ assert model .model_format .model_source .gcs_tflite_uri .startswith ('gs://' )
154
+ if validation_error :
155
+ assert model .model_format .size_bytes is None
151
156
assert model .model_hash is None
157
+ else :
158
+ assert model .model_format .size_bytes is not None
159
+ assert model .model_hash is not None
160
+
161
+
162
+ def check_tflite_automl_format (model ):
163
+ assert model .validation_error is None
164
+ assert model .published is False
165
+ assert model .model_format .model_source .auto_ml_model .startswith ('projects/' )
166
+ # Automl models don't have validation errors since they are references
167
+ # to valid automl models.
152
168
153
169
154
170
@pytest .mark .parametrize ('firebase_model' , [NAME_AND_TAGS_ARGS ], indirect = True )
155
171
def test_create_simple_model (firebase_model ):
156
172
check_model (firebase_model , NAME_AND_TAGS_ARGS )
157
- check_model_format (firebase_model )
173
+ check_no_model_format (firebase_model )
158
174
159
175
160
176
@pytest .mark .parametrize ('firebase_model' , [FULL_MODEL_ARGS ], indirect = True )
161
177
def test_create_full_model (firebase_model ):
162
178
check_model (firebase_model , FULL_MODEL_ARGS )
163
- check_model_format (firebase_model , True )
179
+ check_tflite_gcs_format (firebase_model )
164
180
165
181
166
182
@pytest .mark .parametrize ('firebase_model' , [FULL_MODEL_ARGS ], indirect = True )
@@ -175,14 +191,14 @@ def test_create_already_existing_fails(firebase_model):
175
191
@pytest .mark .parametrize ('firebase_model' , [INVALID_FULL_MODEL_ARGS ], indirect = True )
176
192
def test_create_invalid_model (firebase_model ):
177
193
check_model (firebase_model , INVALID_FULL_MODEL_ARGS )
178
- check_model_format (firebase_model , True , 'Invalid flatbuffer format' )
194
+ check_tflite_gcs_format (firebase_model , 'Invalid flatbuffer format' )
179
195
180
196
181
197
@pytest .mark .parametrize ('firebase_model' , [NAME_AND_TAGS_ARGS ], indirect = True )
182
198
def test_get_model (firebase_model ):
183
199
get_model = ml .get_model (firebase_model .model_id )
184
200
check_model (get_model , NAME_AND_TAGS_ARGS )
185
- check_model_format (get_model )
201
+ check_no_model_format (get_model )
186
202
187
203
188
204
@pytest .mark .parametrize ('firebase_model' , [NAME_ONLY_ARGS ], indirect = True )
@@ -201,12 +217,12 @@ def test_update_model(firebase_model):
201
217
firebase_model .display_name = new_model_name
202
218
updated_model = ml .update_model (firebase_model )
203
219
check_model (updated_model , NAME_ONLY_ARGS_UPDATED )
204
- check_model_format (updated_model )
220
+ check_no_model_format (updated_model )
205
221
206
222
# Second call with same model does not cause error
207
223
updated_model2 = ml .update_model (updated_model )
208
224
check_model (updated_model2 , NAME_ONLY_ARGS_UPDATED )
209
- check_model_format (updated_model2 )
225
+ check_no_model_format (updated_model2 )
210
226
211
227
212
228
@pytest .mark .parametrize ('firebase_model' , [NAME_ONLY_ARGS ], indirect = True )
@@ -290,7 +306,7 @@ def test_delete_model(firebase_model):
290
306
291
307
# Test tensor flow conversion functions if tensor flow is enabled.
292
308
#'pip install tensorflow' in the environment if you want _TF_ENABLED = True
293
- #'pip install tensorflow==2.0.0b ' for version 2 etc.
309
+ #'pip install tensorflow==2.2.0 ' for version 2.2.0 etc.
294
310
295
311
296
312
def _clean_up_directory (save_dir ):
@@ -334,6 +350,7 @@ def saved_model_dir(keras_model):
334
350
_clean_up_directory (parent )
335
351
336
352
353
+
337
354
@pytest .mark .skipif (not _TF_ENABLED , reason = 'Tensor flow is required for this test.' )
338
355
def test_from_keras_model (keras_model ):
339
356
source = ml .TFLiteGCSModelSource .from_keras_model (keras_model , 'model2.tflite' )
@@ -348,7 +365,7 @@ def test_from_keras_model(keras_model):
348
365
349
366
try :
350
367
check_model (created_model , {'display_name' : model .display_name })
351
- check_model_format (created_model , True )
368
+ check_tflite_gcs_format (created_model )
352
369
finally :
353
370
_clean_up_model (created_model )
354
371
@@ -371,3 +388,50 @@ def test_from_saved_model(saved_model_dir):
371
388
assert created_model .validation_error is None
372
389
finally :
373
390
_clean_up_model (created_model )
391
+
392
+
393
+ # Test AutoML functionality if AutoML is enabled.
394
+ #'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True
395
+ # You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the
396
+ # successful test. (Test is skipped otherwise)
397
+
398
+ @pytest .fixture
399
+ def automl_model ():
400
+ assert _AUTOML_ENABLED
401
+
402
+ # It takes > 20 minutes to train a model, so we expect a predefined AutoMl
403
+ # model named 'admin_sdk_integ_test1' to exist in the project, or we skip
404
+ # the test.
405
+ automl_client = automl_v1 .AutoMlClient ()
406
+ project_id = firebase_admin .get_app ().project_id
407
+ parent = automl_client .location_path (project_id , 'us-central1' )
408
+ models = automl_client .list_models (parent , filter_ = "display_name=admin_sdk_integ_test1" )
409
+ # Expecting exactly one. (Ok to use last one if somehow more than 1)
410
+ automl_ref = None
411
+ for model in models :
412
+ automl_ref = model .name
413
+
414
+ # Skip if no pre-defined model. (It takes min > 20 minutes to train a model)
415
+ if automl_ref is None :
416
+ pytest .skip ("No pre-existing AutoML model found. Skipping test" )
417
+
418
+ source = ml .TFLiteAutoMlSource (automl_ref )
419
+ tflite_format = ml .TFLiteFormat (model_source = source )
420
+ ml_model = ml .Model (
421
+ display_name = _random_identifier ('TestModel_automl_' ),
422
+ tags = ['test_automl' ],
423
+ model_format = tflite_format )
424
+ model = ml .create_model (model = ml_model )
425
+ yield model
426
+ _clean_up_model (model )
427
+
428
+ @pytest .mark .skipif (not _AUTOML_ENABLED , reason = 'AutoML is required for this test.' )
429
+ def test_automl_model (automl_model ):
430
+ # This test looks for a predefined automl model with display_name = 'admin_sdk_integ_test1'
431
+ automl_model .wait_for_unlocked ()
432
+
433
+ check_model (automl_model , {
434
+ 'display_name' : automl_model .display_name ,
435
+ 'tags' : ['test_automl' ],
436
+ })
437
+ check_tflite_automl_format (automl_model )
0 commit comments