
Commit c1780b7

juanuribe28 authored and Tensorflow Cloud maintainers committed
Add integration tests for run_experiment_cloud wrapper.
PiperOrigin-RevId: 383893019
1 parent f1ae448 commit c1780b7

File tree

8 files changed: +427 −123 lines changed

Diff for: src/python/dependencies.py

+3
@@ -27,6 +27,8 @@ def make_required_install_packages():
         "tensorflow>=1.15.0,<3.0",
         "tensorflow_datasets",
         "tensorflow_transform",
+        "tf-models-official",
+        "importlib_resources ; python_version<'3.7'"
     ]


@@ -38,4 +40,5 @@ def make_required_test_packages():
         "numpy",
         "nbconvert",
         "tf-models-official",
+        "importlib_resources ; python_version<'3.7'"
     ]

Diff for: src/python/tensorflow_cloud/core/containerize.py

+1 −1
@@ -285,7 +285,7 @@ def _get_file_path_map(self):
             self.entry_point = sys.argv[0]

         # Map entry_point directory to the dst directory.
-        if not self.called_from_notebook:
+        if not self.called_from_notebook or self.entry_point is not None:
             entry_point_dir, _ = os.path.split(self.entry_point)
             if not entry_point_dir:  # Current directory
                 entry_point_dir = "."

Diff for: src/python/tensorflow_cloud/core/experimental/models.py

+48 −44
@@ -15,17 +15,27 @@
 """Module that contains the `run_models` wrapper for training models from TF Model Garden."""

 import os
+import pickle
+import shutil
 from typing import Any, Dict, Optional
+import uuid

 from .. import machine_config
 from .. import run
 import tensorflow as tf
 import tensorflow_datasets as tfds

-from official.core import train_lib
 from official.vision.image_classification.efficientnet import efficientnet_model
 from official.vision.image_classification.resnet import resnet_model

+# pylint: disable=g-import-not-at-top
+try:
+    import importlib.resources as pkg_resources
+except ImportError:
+    # Backported for python<3.7
+    import importlib_resources as pkg_resources
+# pylint: enable=g-import-not-at-top
+

 def run_models(dataset_name: str,
                model_name: str,
@@ -251,48 +261,42 @@ def run_experiment_cloud(run_experiment_kwargs: Dict[str, Any],
     """
     if run_kwargs is None:
         run_kwargs = dict()
-
-    if run.remote():
-        default_machine_config = machine_config.COMMON_MACHINE_CONFIGS['T4_1X']
-        if 'chief_config' in run_kwargs:
-            chief_config = run_kwargs['chief_config']
-        else:
-            chief_config = default_machine_config
-        if 'worker_count' in run_kwargs:
-            worker_count = run_kwargs['worker_count']
+    distribution_strategy = get_distribution_strategy_str(run_kwargs)
+    run_experiment_kwargs.update(
+        dict(distribution_strategy=distribution_strategy))
+    file_id = str(uuid.uuid4())
+    params_file = save_params(run_experiment_kwargs, file_id)
+
+    with pkg_resources.path(__package__, 'models_entry_point.py') as path:
+        entry_point = f'{file_id}.py'
+        shutil.copyfile(str(path), entry_point)
+    run_kwargs.update(dict(entry_point=entry_point,
+                           distribution_strategy=None))
+    info = run.run(**run_kwargs)
+    os.remove(entry_point)
+    os.remove(params_file)
+    return info
+
+
+def get_distribution_strategy_str(run_kwargs):
+    """Gets the name of a distribution strategy based on cloud run config."""
+    if ('worker_count' in run_kwargs
+            and run_kwargs['worker_count'] > 0):
+        if ('worker_config' in run_kwargs
+                and machine_config.is_tpu_config(run_kwargs['worker_config'])):
+            return 'tpu'
         else:
-            worker_count = 0
-        if 'worker_config' in run_kwargs:
-            worker_config = run_kwargs['worker_config']
-        else:
-            worker_config = default_machine_config
-        distribution_strategy = get_distribution_strategy(chief_config,
-                                                          worker_count,
-                                                          worker_config)
-        run_experiment_kwargs.update(
-            dict(distribution_strategy=distribution_strategy))
-        model, _ = train_lib.run_experiment(**run_experiment_kwargs)
-        model.save(run_experiment_kwargs['model_dir'])
-
-    run_kwargs.update(dict(entry_point=None,
-                           distribution_strategy=None))
-    return run.run(**run_kwargs)
-
-
-def get_distribution_strategy(chief_config, worker_count, worker_config):
-    """Gets a tf distribution strategy based on the cloud run config."""
-    if worker_count > 0:
-        if machine_config.is_tpu_config(worker_config):
-            # TODO(b/194857231) Dependency conflict for using TPUs
-            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
-                tpu='local')
-            tf.config.experimental_connect_to_cluster(resolver)
-            tf.tpu.experimental.initialize_tpu_system(resolver)
-            return tf.distribute.TPUStrategy(resolver)
-        else:
-            # TODO(b/148619319) Saving model currently failing
-            return tf.distribute.MultiWorkerMirroredStrategy()
-    elif chief_config.accelerator_count > 1:
-        return tf.distribute.MirroredStrategy()
+            return 'multi_mirror'
+    elif ('chief_config' in run_kwargs
+          and run_kwargs['chief_config'].accelerator_count > 1):
+        return 'mirror'
     else:
-        return tf.distribute.OneDeviceStrategy(device='/gpu:0')
+        return 'one_device'
+
+
+def save_params(params, file_id):
+    """Pickles the params object using the file_id as prefix."""
+    file_name = f'{file_id}_params'
+    with open(file_name, 'xb') as f:
+        pickle.dump(params, f)
+    return file_name
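With this change run_experiment_cloud no longer builds tf.distribute strategy objects locally; it only selects a strategy name that the copied entry point resolves on the remote worker. A minimal sketch of the new name selection, using made-up run_kwargs dicts rather than values from this commit:

# Minimal sketch of the strategy-name mapping introduced above; the
# run_kwargs dicts below are illustrative only.
from tensorflow_cloud.core import machine_config
from tensorflow_cloud.core.experimental import models

# No workers and no multi-accelerator chief: single device.
assert models.get_distribution_strategy_str({}) == 'one_device'

# A worker pool without a TPU worker_config: multi-worker mirrored,
# resolved to MultiWorkerMirroredStrategy later by the entry point.
assert models.get_distribution_strategy_str({'worker_count': 1}) == 'multi_mirror'

# TPU workers map to 'tpu'.
assert models.get_distribution_strategy_str({
    'worker_count': 1,
    'worker_config': machine_config.COMMON_MACHINE_CONFIGS['TPU'],
}) == 'tpu'

# A multi-GPU chief and no workers maps to 'mirror'.
assert models.get_distribution_strategy_str({
    'chief_config': machine_config.COMMON_MACHINE_CONFIGS['P100_4X'],
}) == 'mirror'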
Diff for: src/python/tensorflow_cloud/core/experimental/models_entry_point.py (new file)

+63

@@ -0,0 +1,63 @@
+# Lint as: python3
+# Copyright 2021 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Entry point file for run_experiment_cloud."""
+
+import os
+import pickle
+
+import tensorflow as tf
+
+from official.core import train_lib
+
+
+def load_params(file_name):
+    with open(file_name, 'rb') as f:
+        params = pickle.load(f)
+    return params
+
+
+def get_tpu_strategy():
+    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
+        tpu='local')
+    tf.config.experimental_connect_to_cluster(resolver)
+    tf.tpu.experimental.initialize_tpu_system(resolver)
+    return tf.distribute.TPUStrategy(resolver)
+
+
+def get_one_device():
+    return tf.distribute.OneDeviceStrategy(device='/gpu:0')
+
+_DISTRIBUTION_STRATEGIES = dict(
+    # TODO(b/194857231) Dependency conflict for using TPUs
+    tpu=get_tpu_strategy,
+    # TODO(b/148619319) Saving model currently failing for multi_mirror
+    multi_mirror=tf.distribute.MultiWorkerMirroredStrategy,
+    mirror=tf.distribute.MirroredStrategy,
+    one_device=get_one_device)
+
+
+def main():
+    prefix, _ = os.path.splitext(os.path.basename(__file__))
+    run_experiment_kwargs = load_params(f'{prefix}_params')
+    strategy_str = run_experiment_kwargs['distribution_strategy']
+    strategy = _DISTRIBUTION_STRATEGIES[strategy_str]()
+    run_experiment_kwargs.update(dict(
+        distribution_strategy=strategy))
+    model, _ = train_lib.run_experiment(**run_experiment_kwargs)
+    model.save(run_experiment_kwargs['model_dir'])
+
+
+if __name__ == '__main__':
+    main()
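The entry point takes no arguments; it relies on the file-naming handshake set up by run_experiment_cloud, which pickles the kwargs to '<file_id>_params' and copies models_entry_point.py to '<file_id>.py'. A minimal sketch of that handshake (the uuid and file names are illustrative; no job is submitted):

# Minimal sketch of the file-naming handshake between the wrapper and the
# copied entry point.
import os
import uuid

file_id = str(uuid.uuid4())
params_file = f'{file_id}_params'   # what models.save_params writes
entry_point = f'{file_id}.py'       # local copy of models_entry_point.py

# The copied entry point recovers the prefix from its own file name, which is
# exactly the prefix of the pickled params file.
prefix, _ = os.path.splitext(os.path.basename(entry_point))
assert f'{prefix}_params' == params_file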
Diff for: (new file: integration tests for run_experiment_cloud)

+166

@@ -0,0 +1,166 @@
+# Lint as: python3
+# Copyright 2021 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Integration tests for calling run_experiment_cloud."""
+
+import os
+
+import tensorflow as tf
+import tensorflow_cloud as tfc
+from tensorflow_cloud.core.experimental import models
+from tensorflow_cloud.utils import google_api_client
+from official.core import task_factory
+from official.utils.testing import mock_task
+
+# The staging bucket to use for cloudbuild as well as save the model and data.
+_TEST_BUCKET = os.environ["TEST_BUCKET"]
+_PROJECT_ID = os.environ["PROJECT_ID"]
+_PARENT_IMAGE = "gcr.io/deeplearning-platform-release/tf2-gpu.2-5"
+_BASE_PATH = f"gs://{_TEST_BUCKET}"
+
+
+class RunExperimentCloudTest(tf.test.TestCase):
+
+    def setUp(self):
+        super(RunExperimentCloudTest, self).setUp()
+        self.test_data_path = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)), "../testdata/"
+        )
+        self.requirements_txt = os.path.join(self.test_data_path,
+                                             "requirements.txt")
+
+        self._test_config = {
+            "trainer": {
+                "checkpoint_interval": 10,
+                "steps_per_loop": 10,
+                "summary_interval": 10,
+                "train_steps": 10,
+                "validation_steps": 5,
+                "validation_interval": 10,
+                "continuous_eval_timeout": 1,
+                "validation_summary_subdir": "validation",
+                "optimizer_config": {
+                    "optimizer": {
+                        "type": "sgd",
+                    },
+                    "learning_rate": {
+                        "type": "constant"
+                    }
+                }
+            },
+        }
+
+        self.params = mock_task.mock_experiment()
+        self.params.override(self._test_config, is_strict=False)
+        self.run_experiment_kwargs = dict(
+            params=self.params,
+            task=task_factory.get_task(self.params.task),
+            mode="train_and_eval",
+        )
+        self.docker_config = tfc.DockerConfig(
+            parent_image=_PARENT_IMAGE,
+            image_build_bucket=_TEST_BUCKET
+        )
+
+    def tpu_strategy(self):
+        run_kwargs = dict(
+            chief_config=tfc.COMMON_MACHINE_CONFIGS["CPU"],
+            worker_count=1,
+            worker_config=tfc.COMMON_MACHINE_CONFIGS["TPU"],
+            requirements_txt=self.requirements_txt,
+            job_labels={
+                "job": "tpu_strategy",
+                "team": "run_experiment_cloud_tests",
+            },
+            docker_config=self.docker_config,
+        )
+        run_experiment_kwargs = dict(
+            model_dir=os.path.join(_BASE_PATH, "tpu", "saved_model"),
+            **self.run_experiment_kwargs,
+        )
+        return models.run_experiment_cloud(run_experiment_kwargs,
+                                           run_kwargs)
+
+    def multi_mirror_strategy(self):
+        run_kwargs = dict(
+            chief_config=tfc.COMMON_MACHINE_CONFIGS["P100_1X"],
+            worker_count=1,
+            worker_config=tfc.COMMON_MACHINE_CONFIGS["P100_1X"],
+            requirements_txt=self.requirements_txt,
+            job_labels={
+                "job": "multi_mirror_strategy",
+                "team": "run_experiment_cloud_tests",
+            },
+            docker_config=self.docker_config,
+        )
+        run_experiment_kwargs = dict(
+            model_dir=os.path.join(_BASE_PATH, "multi_mirror", "saved_model"),
+            **self.run_experiment_kwargs,
+        )
+        return models.run_experiment_cloud(run_experiment_kwargs,
+                                           run_kwargs)
+
+    def mirror_strategy(self):
+        run_kwargs = dict(
+            chief_config=tfc.COMMON_MACHINE_CONFIGS["P100_4X"],
+            requirements_txt=self.requirements_txt,
+            job_labels={
+                "job": "mirror",
+                "team": "run_experiment_cloud_tests",
+            },
+            docker_config=self.docker_config,
+        )
+        run_experiment_kwargs = dict(
+            model_dir=os.path.join(_BASE_PATH, "mirror", "saved_model"),
+            **self.run_experiment_kwargs,
+        )
+        return models.run_experiment_cloud(run_experiment_kwargs,
+                                           run_kwargs)
+
+    def one_device_strategy(self):
+        run_kwargs = dict(
+            requirements_txt=self.requirements_txt,
+            job_labels={
+                "job": "one_device",
+                "team": "run_experiment_cloud_tests",
+            },
+            docker_config=self.docker_config,
+        )
+        run_experiment_kwargs = dict(
+            model_dir=os.path.join(_BASE_PATH, "one_device", "saved_model"),
+            **self.run_experiment_kwargs,
+        )
+        # Using the default T4 GPU for this test.
+        return models.run_experiment_cloud(run_experiment_kwargs,
+                                           run_kwargs)
+
+    def test_run_experiment_cloud(self):
+        track_status = {
+            "one_device_strategy": self.one_device_strategy(),
+            "mirror_strategy": self.mirror_strategy(),
+            # TODO(b/148619319) Enable when bug is solved
+            # "multi_mirror_strategy": self.multi_mirror_strategy(),
+            # TODO(b/194857231) Enable when bug is solved
+            # "tpu_strategy": self.tpu_strategy(),
+        }
+
+        for test_name, ret_val in track_status.items():
+            self.assertTrue(
+                google_api_client.wait_for_aip_training_job_completion(
+                    ret_val["job_id"], _PROJECT_ID),
+                "Job {} generated from the test: {} has failed".format(
+                    ret_val["job_id"], test_name))
+
+if __name__ == "__main__":
+    tf.test.main()
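One operational note on the tests above: _TEST_BUCKET and _PROJECT_ID are read from the environment at import time, so both variables must be set before the module loads. A minimal sketch with placeholder values:

# The test module does os.environ["TEST_BUCKET"] / os.environ["PROJECT_ID"] at
# module level; unset variables raise KeyError before any test runs.
# Both values below are placeholders, not values from this commit.
import os

os.environ['TEST_BUCKET'] = 'my-staging-bucket'
os.environ['PROJECT_ID'] = 'my-gcp-project'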
Diff for: (new file: requirements.txt used by the integration tests)

+2

@@ -0,0 +1,2 @@
+git+https://github.com/tensorflow/cloud.git@refs/pull/360/head#egg=tensorflow-cloud&subdirectory=src/python
+tf-models-official
