Skip to content

Commit fccc229

Browse files
Change location of TLT saved objects for distributed PYT (#344)
* Changed behaviour of saving torch dist. objects * Review changes * Removed val_data as it is not used anywhere
1 parent 4167440 commit fccc229

File tree

5 files changed

+115
-60
lines changed

5 files changed

+115
-60
lines changed

tlt/distributed/pytorch/run_train_pyt.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# SPDX-License-Identifier: Apache-2.0
1919
#
2020

21+
import os
2122
import argparse
2223

2324
from tlt.distributed.pytorch.utils.pyt_distributed_utils import (
@@ -28,7 +29,12 @@
2829

2930
if __name__ == "__main__":
3031

31-
# Program arguments
32+
def directory_path(path):
33+
if os.path.isdir(path):
34+
return path
35+
else:
36+
raise argparse.ArgumentTypeError("'{}' is not a valid directory path.".format(path))
37+
3238
print("******Distributed Training*****")
3339

3440
description = 'Distributed training with PyTorch.'
@@ -46,23 +52,25 @@
4652
help='Global batch size to distribute data (default: 128)')
4753
parser.add_argument('--disable_ipex', action='store_true', required=False, help="Disables IPEX optimization to "
4854
"the model")
55+
parser.add_argument('--tlt_saved_objects_dir', type=directory_path, required=False, help='Path to TLT saved '
56+
'distributed objects. The path must be accessible to all the nodes. For example: mounted '
57+
'NFS drive. This arg is helpful when using TLT API/CLI. '
58+
'See DistributedTorch.load_saved_objects() for more information.')
4959

5060
args = parser.parse_args()
5161

52-
# Load the saved dataset and model objects
53-
loaded_objects = DistributedTorch.load_saved_objects(use_case=args.use_case)
62+
if args.tlt_saved_objects_dir is not None:
63+
# Load the saved dataset and model objects
64+
loaded_objects = DistributedTorch.load_saved_objects(args.tlt_saved_objects_dir)
5465

55-
dataset = loaded_objects['dataset']
56-
train_subset = loaded_objects.get('train_subset', dataset)
57-
test_subset = loaded_objects.get('test_subset', dataset)
58-
validation_subset = loaded_objects.get('validation_subset', dataset)
59-
model = loaded_objects['model']
60-
loss = loaded_objects['loss']
61-
optimizer = loaded_objects['optimizer']
66+
train_data = loaded_objects.get('train_data')
67+
model = loaded_objects['model']
68+
loss = loaded_objects['loss']
69+
optimizer = loaded_objects['optimizer']
6270

6371
# Launch distributed job
6472
training_args = DistributedTrainingArguments(
65-
dataset=train_subset,
73+
dataset=train_data,
6674
model=model,
6775
criterion=loss,
6876
optimizer=optimizer,

tlt/distributed/pytorch/utils/pyt_distributed_utils.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from random import Random
2828
from torch.utils.data import DataLoader
2929
from torch.nn.parallel import DistributedDataParallel as DDP
30-
from tlt.distributed import TLT_DISTRIBUTED_DIR
3130

3231
import oneccl_bindings_for_pytorch # noqa # pylint: disable=unused-import
3332
import intel_extension_for_pytorch as ipex
@@ -245,7 +244,7 @@ def cleanup_ddp(cls):
245244
dist.destroy_process_group()
246245

247246
@classmethod
248-
def load_saved_objects(cls, use_case: str):
247+
def load_saved_objects(cls, saved_objects_dir):
249248
"""
250249
Helper function to load saved dataset and model objects
251250
@@ -255,11 +254,6 @@ def load_saved_objects(cls, use_case: str):
255254
Returns:
256255
dict with loaded dataset and model objects
257256
"""
258-
if use_case == 'text_classification':
259-
saved_objects_file = 'hf_saved_objects.obj'
260-
elif use_case == 'image_classification':
261-
saved_objects_file = 'torch_saved_objects.obj'
262-
else:
263-
raise ValueError("Distributed PyTorch for {} is not implemented yet".format(use_case))
257+
saved_objects_file = 'torch_saved_objects.obj'
264258

265-
return torch.load(os.path.join(TLT_DISTRIBUTED_DIR, saved_objects_file))
259+
return torch.load(os.path.join(saved_objects_dir, saved_objects_file))

tlt/models/image_classification/pytorch_image_classification_model.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import time
2424
import dill
2525
import subprocess
26+
import tempfile
27+
import shutil
2628

2729
from tqdm import tqdm
2830

@@ -244,7 +246,7 @@ def _fit(self, output_dir, dataset, epochs, do_eval, early_stopping, lr_decay):
244246
'loss': train_epoch_loss,
245247
}, os.path.join(checkpoint_dir, 'checkpoint.pt'))
246248

247-
def _fit_distributed(self, hostfile, nnodes, nproc_per_node, epochs, batch_size, ipex_optimize):
249+
def _fit_distributed(self, saved_objects_dir, hostfile, nnodes, nproc_per_node, epochs, batch_size, ipex_optimize):
248250
distributed_vision_script = os.path.join(TLT_DISTRIBUTED_DIR, "pytorch", "run_train_pyt.py")
249251

250252
default_port = '29500'
@@ -286,6 +288,7 @@ def _fit_distributed(self, hostfile, nnodes, nproc_per_node, epochs, batch_size,
286288
bash_command += ' --master_addr {}'.format(default_master_addr)
287289
bash_command += ' --master_port {}'.format(default_port)
288290
bash_command += ' --backend {}'.format('ccl')
291+
bash_command += ' --tlt_saved_objects_dir {}'.format(saved_objects_dir)
289292
bash_command += ' --use_case {}'.format('image_classification')
290293
bash_command += ' --epochs {}'.format(epochs)
291294
bash_command += ' --batch_size {}'.format(batch_size)
@@ -346,9 +349,19 @@ def train(self, dataset: ImageClassificationDataset, output_dir, epochs=1, initi
346349
self._optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
347350

348351
if distributed:
349-
self.export_for_distributed(TLT_DISTRIBUTED_DIR, dataset)
350-
batch_size = dataset._preprocessed['batch_size']
351-
self._fit_distributed(hostfile, nnodes, nproc_per_node, epochs, batch_size, ipex_optimize)
352+
try:
353+
saved_objects_dir = self.export_for_distributed(
354+
export_dir=os.path.join(output_dir, 'tlt_saved_objects'),
355+
train_data=dataset.train_subset,
356+
val_data=dataset.validation_subset
357+
)
358+
batch_size = dataset._preprocessed['batch_size']
359+
self._fit_distributed(saved_objects_dir, hostfile, nnodes, nproc_per_node, epochs, batch_size,
360+
ipex_optimize)
361+
except Exception as err:
362+
print("Error: \'{}\' occured while distributed training".format(err))
363+
finally:
364+
self.cleanup_saved_objects_for_distributed()
352365

353366
else:
354367
# Call ipex.optimize
@@ -467,26 +480,37 @@ def export(self, output_dir):
467480
else:
468481
raise ValueError("Unable to export the model, because it hasn't been trained yet")
469482

470-
def export_for_distributed(self, output_dir, dataset):
483+
def export_for_distributed(self, export_dir=None, train_data=None, val_data=None):
471484
"""
472-
Helper function to export dataset and model objects to disk for distributed job
485+
Exports the model, optimizer, loss, train data and validation data to the export_dir for distributed
486+
script to access. Note that the export_dir must be accessible to all the nodes. For example: NFS shared
487+
systems. Note that the export_dir is created using mkdtemp which reults in a unique dir name. For
488+
example: "<export_dir_Am83Iw". If the export_dir is None, the default name is "saved_objects"
473489
474490
Args:
475-
output_dir (str): Path to a directory where the dataset and model objects are saved.
476-
Default file name for saving the objects is "torch_saved_objects.obj"
477-
dataset (ImageClassificationDataset): Dataset object to save. It must be an object of
478-
ImageClassificationDataset so that the dataset info, train, test, and validation
479-
subsets can be accessed.
491+
export_dir (str): Directory name to export the model, optimizer, loss, train data and validation
492+
data. export_dir must be accessible to all the nodes. For example: NFS shared systems. export_dir
493+
is created using mkdtemp which reults in a unique dir name. For example: "<export_dir_Am83Iw".
494+
If the export_dir is None, the default name is "saved_objects"
495+
train_data (PyTorchDataset): Train dataset
496+
val_data (PyTorchDataset): Validation dataset
480497
"""
481498

499+
temp_dir_prefix = os.path.join(os.environ['HOME'], "saved_objects_") if export_dir is None else export_dir + "_"
500+
self._temp_dir = tempfile.mkdtemp(prefix=temp_dir_prefix)
501+
482502
objects_to_save = {
483-
"dataset": dataset.dataset,
484-
"info": dataset.info,
485-
"train_subset": dataset.train_subset,
486-
"test_subset": dataset.test_subset,
487-
"validation_subset": dataset.validation_subset,
503+
"train_data": train_data,
488504
"model": self._model,
489505
"optimizer": self._optimizer,
490506
"loss": self._loss
491507
}
492-
torch.save(objects_to_save, os.path.join(output_dir, "torch_saved_objects.obj"))
508+
torch.save(objects_to_save, os.path.join(self._temp_dir, "torch_saved_objects.obj"))
509+
return self._temp_dir
510+
511+
def cleanup_saved_objects_for_distributed(self):
512+
try:
513+
print('Cleaning saved objects...')
514+
shutil.rmtree(self._temp_dir)
515+
except OSError as ose:
516+
print('Error while cleaning the saved objects: {}'.format(ose))

tlt/models/image_classification/torchvision_image_classification_model.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
from downloader.models import ModelDownloader
2828
from tlt import TLT_BASE_DIR
29-
from tlt.distributed import TLT_DISTRIBUTED_DIR
3029
from tlt.models.image_classification.pytorch_image_classification_model import PyTorchImageClassificationModel
3130
from tlt.datasets.image_classification.image_classification_dataset import ImageClassificationDataset
3231
from tlt.utils.file_utils import read_json_file
@@ -186,9 +185,19 @@ def train(self, dataset: ImageClassificationDataset, output_dir, epochs=1, initi
186185
self._model, self._optimizer = ipex.optimize(self._model, optimizer=self._optimizer)
187186

188187
if distributed:
189-
self.export_for_distributed(TLT_DISTRIBUTED_DIR, dataset)
190-
batch_size = dataset._preprocessed['batch_size']
191-
self._fit_distributed(hostfile, nnodes, nproc_per_node, epochs, batch_size, ipex_optimize)
188+
try:
189+
saved_objects_dir = self.export_for_distributed(
190+
export_dir=os.path.join(output_dir, 'tlt_saved_objects'),
191+
train_data=dataset.train_subset,
192+
val_data=dataset.validation_subset
193+
)
194+
batch_size = dataset._preprocessed['batch_size']
195+
self._fit_distributed(saved_objects_dir, hostfile, nnodes, nproc_per_node, epochs, batch_size,
196+
ipex_optimize)
197+
except Exception as err:
198+
print("Error: \'{}\' occured while distributed training".format(err))
199+
finally:
200+
self.cleanup_saved_objects_for_distributed()
192201
else:
193202
self._model.train()
194203
self._fit(output_dir, dataset, epochs, do_eval, early_stopping, lr_decay)

tlt/models/text_classification/pytorch_hf_text_classification_model.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from tqdm import tqdm
3131
from torch.utils.data import DataLoader
3232
import yaml
33+
import tempfile
34+
import shutil
3335

3436
# Hugging Face imports
3537
from transformers import (
@@ -101,29 +103,39 @@ def __init__(self, model_name: str, model=None, optimizer=None, loss=None, **kwa
101103
self._trainer = None
102104
self._history = None
103105

104-
def export_for_distributed(self, output_dir, dataset):
106+
def export_for_distributed(self, export_dir, train_data=None, val_data=None):
105107
"""
106-
Helper function to export dataset and model objects to disk for distributed job
108+
Exports the model, optimizer, loss, train data and validation data to the export_dir for distributed
109+
script to access. Note that the export_dir must be accessible to all the nodes. For example: NFS shared
110+
systems. Note that the export_dir is created using mkdtemp which reults in a unique dir name. For
111+
example: "<export_dir_Am83Iw". If the export_dir is None, the default name is "saved_objects"
107112
108113
Args:
109-
output_dir (str): Path to a directory where the dataset and model objects are saved.
110-
Default file name for saving the objects is "hf_saved_objects.obj"
111-
dataset (HFTextClassificationDataset): Dataset object to save. It must be an object of
112-
HFTextClassificationDataset so that the dataset info, train, test, and validation
113-
subsets can be accessed.
114+
export_dir (str): Directory name to export the model, optimizer, loss, train data and validation
115+
data. export_dir must be accessible to all the nodes. For example: NFS shared systems. export_dir
116+
is created using mkdtemp which reults in a unique dir name. For example: "<export_dir_Am83Iw".
117+
If the export_dir is None, the default name is "saved_objects"
118+
train_data (PyTorchDataset): Train dataset
119+
val_data (PyTorchDataset): Validation dataset
114120
"""
121+
temp_dir_prefix = os.path.join(os.environ['HOME'], "saved_objects_") if export_dir is None else export_dir + "_"
122+
self._temp_dir = tempfile.mkdtemp(prefix=temp_dir_prefix)
115123

116124
objects_to_save = {
117-
"dataset": dataset.dataset,
118-
"info": dataset.info,
119-
"train_subset": dataset.train_subset,
120-
"test_subset": dataset.test_subset,
121-
"validation_subset": dataset.validation_subset,
125+
"train_data": train_data,
122126
"model": self._model,
123127
"optimizer": self._optimizer,
124128
"loss": self._loss
125129
}
126-
torch.save(objects_to_save, os.path.join(output_dir, "hf_saved_objects.obj"))
130+
torch.save(objects_to_save, os.path.join(self._temp_dir, "torch_saved_objects.obj"))
131+
return self._temp_dir
132+
133+
def cleanup_saved_objects_for_distributed(self):
134+
try:
135+
print('Cleaning saved objects...')
136+
shutil.rmtree(self._temp_dir)
137+
except OSError as ose:
138+
print('Error while cleaning the saved objects: {}'.format(ose))
127139

128140
@property
129141
def num_classes(self):
@@ -272,7 +284,7 @@ def _fit(self, output_dir, dataset, epochs, do_eval, early_stopping, lr_decay):
272284
'loss': train_epoch_loss,
273285
}, os.path.join(checkpoint_dir, 'checkpoint.pt'))
274286

275-
def _fit_distributed(self, hostfile, nnodes, nproc_per_node, epochs, batch_size, ipex_optimize):
287+
def _fit_distributed(self, saved_objects_dir, hostfile, nnodes, nproc_per_node, epochs, batch_size, ipex_optimize):
276288
distributed_text_script = os.path.join(TLT_DISTRIBUTED_DIR, "pytorch", "run_train_pyt.py")
277289

278290
default_port = '29500'
@@ -314,6 +326,7 @@ def _fit_distributed(self, hostfile, nnodes, nproc_per_node, epochs, batch_size,
314326
bash_command += ' --master_addr {}'.format(default_master_addr)
315327
bash_command += ' --master_port {}'.format(default_port)
316328
bash_command += ' --backend {}'.format('ccl')
329+
bash_command += ' --tlt_saved_objects_dir {}'.format(saved_objects_dir)
317330
bash_command += ' --use_case {}'.format('text_classification')
318331
bash_command += ' --epochs {}'.format(epochs)
319332
bash_command += ' --batch_size {}'.format(batch_size)
@@ -467,11 +480,18 @@ def compute_metrics(p: EvalPrediction):
467480
self._history = self._trainer.evaluate()
468481
print("Val Acc: {:.5f}".format(self._history.get("eval_accuracy")))
469482
elif distributed:
470-
self.export_for_distributed(
471-
output_dir=TLT_DISTRIBUTED_DIR, dataset=dataset
472-
)
473-
self._fit_distributed(hostfile, nnodes, nproc_per_node, epochs, dataset._preprocessed["batch_size"],
474-
ipex_optimize)
483+
try:
484+
saved_objects_dir = self.export_for_distributed(
485+
export_dir=os.path.join(output_dir, 'tlt_saved_objects'),
486+
train_data=dataset.train_subset,
487+
val_data=dataset.validation_subset
488+
)
489+
self._fit_distributed(saved_objects_dir, hostfile, nnodes, nproc_per_node, epochs,
490+
dataset._preprocessed["batch_size"], ipex_optimize)
491+
except Exception as err:
492+
print("Error: \'{}\' occured while distributed training".format(err))
493+
finally:
494+
self.cleanup_saved_objects_for_distributed()
475495
else:
476496
self._trainer = None
477497
self._model.train()

0 commit comments

Comments
 (0)