diff --git a/main_with_runtime.py b/main_with_runtime.py index 9744817..39b8c50 100644 --- a/main_with_runtime.py +++ b/main_with_runtime.py @@ -28,7 +28,7 @@ from optim import adamw from optim import nadamw -parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser = argparse.ArgumentParser(description='AsyncPP Language Model Training') parser.add_argument('--data_dir', '-dd', type=str, default='~/data', help='path to dataset') parser.add_argument('--dataset_name', '-d', type=str, @@ -257,7 +257,8 @@ def main(): 'stage_to_depth_map': None } if args.config_path is not None: - json_config_file = json.load(open(args.config_path, 'r')) + with open(args.config_path, 'r') as f: + json_config_file = json.load(f) configuration_maps['module_to_stage_map'] = json_config_file.get("module_to_stage_map", None) configuration_maps['stage_to_rank_map'] = json_config_file.get("stage_to_rank_map", None) configuration_maps['stage_to_rank_map'] = { @@ -318,7 +319,7 @@ def main(): checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, r.stage) assert os.path.isfile(checkpoint_file_path) print("=> loading checkpoint '{}'".format(checkpoint_file_path)) - checkpoint = torch.load(checkpoint_file_path) + checkpoint = torch.load(checkpoint_file_path, weights_only=False) args.start_epoch = checkpoint['epoch'] best_loss = checkpoint['best_loss'] r.load_state_dict(checkpoint['state_dict']) diff --git a/optim/__init__.py b/optim/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/optim/optimizer.py b/optim/optimizer.py index 8db279e..ca564aa 100644 --- a/optim/optimizer.py +++ b/optim/optimizer.py @@ -84,7 +84,7 @@ def append_to_queue(self, data): def get_from_queue(self, index): if self.save_dir is not None: fname = self.queue[index] - d = torch.load(fname) + d = torch.load(fname, weights_only=False) return d["state_dicts"], d["version"] else: return self.queue[index] diff --git a/run.bash b/run.bash index e54e1d7..d87df30 100644 --- a/run.bash +++ b/run.bash @@ -27,7 +27,7 @@ done wait # gpipe -basecmdstr="python sync_main.py $model $batch -d $d $dd --master_addr localhost --distributed_backend nccl +basecmdstr="python sync_main.py $model $batch -d $d --master_addr localhost --distributed_backend nccl $lr $epochs $minibatches $cg $logtb --recompute --lr_policy cosine --optimizer adamw" # gpipe diff --git a/runtime/__init__.py b/runtime/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/runtime/threadsafe_queue.py b/runtime/threadsafe_queue.py index fae3c10..7c32775 100644 --- a/runtime/threadsafe_queue.py +++ b/runtime/threadsafe_queue.py @@ -2,13 +2,14 @@ # Licensed under the MIT license. import threading +from collections import deque """ Implementation of a thread-safe queue with one producer and one consumer. """ class Queue: def __init__(self): - self.queue = [] + self.queue = deque() self.cv = threading.Condition() def add(self, tensor): @@ -21,6 +22,6 @@ def remove(self): self.cv.acquire() while len(self.queue) == 0: self.cv.wait() - tensor = self.queue.pop(0) + tensor = self.queue.popleft() self.cv.release() return tensor diff --git a/sync_main.py b/sync_main.py index 708d0f8..bb2cfe5 100644 --- a/sync_main.py +++ b/sync_main.py @@ -32,7 +32,7 @@ from data_utils import ShakespeareDataset, WikiTextDataset, OpenWebTextDataset, BookCorpusDataset, DataUtil from transformers import AutoTokenizer -parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser = argparse.ArgumentParser(description='AsyncPP Sync Baseline Training') parser.add_argument('--data_dir', '-dd', type=str, default='~/data', help='path to dataset') parser.add_argument('--dataset_name', '-d', type=str, @@ -300,7 +300,8 @@ def main(): 'stage_to_depth_map': None } if args.config_path is not None: - json_config_file = json.load(open(args.config_path, 'r')) + with open(args.config_path, 'r') as f: + json_config_file = json.load(f) configuration_maps['module_to_stage_map'] = json_config_file.get("module_to_stage_map", None) configuration_maps['stage_to_rank_map'] = json_config_file.get("stage_to_rank_map", None) configuration_maps['stage_to_rank_map'] = { @@ -369,7 +370,7 @@ def main(): checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, args.stage) assert os.path.isfile(checkpoint_file_path) print("=> loading checkpoint '{}'".format(checkpoint_file_path)) - checkpoint = torch.load(checkpoint_file_path) + checkpoint = torch.load(checkpoint_file_path, weights_only=False) args.start_epoch = checkpoint['epoch'] best_loss = checkpoint['best_loss'] pp_stage.submod.load_state_dict(checkpoint['state_dict'])