Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions main_with_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from optim import adamw
from optim import nadamw

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser = argparse.ArgumentParser(description='AsyncPP Language Model Training')
parser.add_argument('--data_dir', '-dd', type=str, default='~/data',
help='path to dataset')
parser.add_argument('--dataset_name', '-d', type=str,
Expand Down Expand Up @@ -257,7 +257,8 @@ def main():
'stage_to_depth_map': None
}
if args.config_path is not None:
json_config_file = json.load(open(args.config_path, 'r'))
with open(args.config_path, 'r') as f:
json_config_file = json.load(f)
configuration_maps['module_to_stage_map'] = json_config_file.get("module_to_stage_map", None)
configuration_maps['stage_to_rank_map'] = json_config_file.get("stage_to_rank_map", None)
configuration_maps['stage_to_rank_map'] = {
Expand Down Expand Up @@ -318,7 +319,7 @@ def main():
checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, r.stage)
assert os.path.isfile(checkpoint_file_path)
print("=> loading checkpoint '{}'".format(checkpoint_file_path))
checkpoint = torch.load(checkpoint_file_path)
checkpoint = torch.load(checkpoint_file_path, weights_only=False)
args.start_epoch = checkpoint['epoch']
best_loss = checkpoint['best_loss']
r.load_state_dict(checkpoint['state_dict'])
Expand Down
Empty file added optim/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def append_to_queue(self, data):
def get_from_queue(self, index):
if self.save_dir is not None:
fname = self.queue[index]
d = torch.load(fname)
d = torch.load(fname, weights_only=False)
return d["state_dicts"], d["version"]
else:
return self.queue[index]
Expand Down
2 changes: 1 addition & 1 deletion run.bash
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ done
wait

# gpipe
basecmdstr="python sync_main.py $model $batch -d $d $dd --master_addr localhost --distributed_backend nccl
basecmdstr="python sync_main.py $model $batch -d $d --master_addr localhost --distributed_backend nccl
$lr $epochs $minibatches $cg $logtb --recompute --lr_policy cosine --optimizer adamw"

# gpipe
Expand Down
Empty file added runtime/__init__.py
Empty file.
5 changes: 3 additions & 2 deletions runtime/threadsafe_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
# Licensed under the MIT license.

import threading
from collections import deque

"""
Implementation of a thread-safe queue with one producer and one consumer.
"""
class Queue:
def __init__(self):
self.queue = []
self.queue = deque()
self.cv = threading.Condition()

def add(self, tensor):
Expand All @@ -21,6 +22,6 @@ def remove(self):
self.cv.acquire()
while len(self.queue) == 0:
self.cv.wait()
tensor = self.queue.pop(0)
tensor = self.queue.popleft()
self.cv.release()
return tensor
7 changes: 4 additions & 3 deletions sync_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from data_utils import ShakespeareDataset, WikiTextDataset, OpenWebTextDataset, BookCorpusDataset, DataUtil
from transformers import AutoTokenizer

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser = argparse.ArgumentParser(description='AsyncPP Sync Baseline Training')
parser.add_argument('--data_dir', '-dd', type=str, default='~/data',
help='path to dataset')
parser.add_argument('--dataset_name', '-d', type=str,
Expand Down Expand Up @@ -300,7 +300,8 @@ def main():
'stage_to_depth_map': None
}
if args.config_path is not None:
json_config_file = json.load(open(args.config_path, 'r'))
with open(args.config_path, 'r') as f:
json_config_file = json.load(f)
configuration_maps['module_to_stage_map'] = json_config_file.get("module_to_stage_map", None)
configuration_maps['stage_to_rank_map'] = json_config_file.get("stage_to_rank_map", None)
configuration_maps['stage_to_rank_map'] = {
Expand Down Expand Up @@ -369,7 +370,7 @@ def main():
checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, args.stage)
assert os.path.isfile(checkpoint_file_path)
print("=> loading checkpoint '{}'".format(checkpoint_file_path))
checkpoint = torch.load(checkpoint_file_path)
checkpoint = torch.load(checkpoint_file_path, weights_only=False)
args.start_epoch = checkpoint['epoch']
best_loss = checkpoint['best_loss']
pp_stage.submod.load_state_dict(checkpoint['state_dict'])
Expand Down