Update transformers version and BERT script #89

Open
wants to merge 1 commit into master
@@ -69,6 +69,13 @@
import inspect
import requests
import gc


from torch_xla import __version__
version_tuple = tuple(map(int, __version__.split(".")[:2]))
is_pt21_plus = version_tuple >= (2,1)
is_pt20 = version_tuple == (2,0)

os.environ["NEURON_CC_FLAGS"] = os.environ.get('NEURON_CC_FLAGS', '') + " --model-type=transformer"

# For PT autocast.
@@ -221,16 +228,29 @@ def sequence_length(self) -> int:


def get_model(flags):
base_model = BertForPreTraining.from_pretrained('bert-large-uncased', force_download=True)
base_model = BertForPreTraining.from_pretrained('bert-large-uncased', use_safetensors=False)
# medium BERT size L12_A12_H768. Large BERT L24_A16_H1024 causes OOM on GPU V100
my_config = copy.deepcopy(base_model.config)
if flags.disable_dropout or flags.snapshot_step_list:
my_config.hidden_dropout_prob = 0.0
my_config.attention_probs_dropout_prob = 0.0
if flags.debug:
my_config.num_hidden_layers = 1
my_config.num_attention_heads = 2
my_config.hidden_size = 16
my_model = BertForPreTraining(my_config)
return my_model

def extract_mfu(num_layers, hidden_size, sequence_len, batch_size, average_throughput, world_size):
flops_per_seq = 12 * num_layers * hidden_size * sequence_len * (6 * hidden_size + sequence_len)
tflops_per_seq = flops_per_seq / 10**12
tflops_per_sec_per_worker = tflops_per_seq * average_throughput/world_size
if '--auto-cast=none' in os.getenv('NEURON_CC_FLAGS', default=''):
hw_tflops_per_worker = 760/32
else:
hw_tflops_per_worker = 3040/32
return tflops_per_sec_per_worker/hw_tflops_per_worker * 100
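Side note (not part of the diff): the MFU estimate can be sanity-checked by hand. The hard-coded 760 and 3040 appear to be the aggregate FP32 and BF16 peak TFLOPS of a trn1.32xlarge divided across its 32 workers; taking them as such, a rough check with assumed BERT-large phase-1 numbers looks like this:

# Illustrative back-of-envelope check; every number below is an assumption, not taken from the diff.
num_layers, hidden_size, sequence_len = 24, 1024, 128
flops_per_seq = 12 * num_layers * hidden_size * sequence_len * (6 * hidden_size + sequence_len)  # ~2.4e11 FLOPs
tflops_per_seq = flops_per_seq / 10**12
average_throughput, world_size = 10000, 32  # whole-job seq/s and worker count (assumed)
tflops_per_sec_per_worker = tflops_per_seq * average_throughput / world_size  # ~74 TFLOPS per worker
mfu = tflops_per_sec_per_worker / (3040 / 32) * 100  # ~78% of the assumed BF16 peak per worker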

# fix NVidia checkpoint param names to match HF
def fix_ckpt_params(state_dict):
keys = [k for k in state_dict.keys() if 'dense_act' in k]
@@ -335,6 +355,32 @@ def train_bert_hdf5(flags):
}

def train_loop_fn(model, optimizer, train_loader, epoch, global_step, training_ustep, running_loss):

# Add snapshot callback here in order to track total_steps
total_steps = 0
capture_steps = []
# Turn off snapshotting for all ranks/steps by default; enable it only for the ranks/steps in the specified lists
def callback(name, addressable_device_index, execution_count):
return ''
# Enable snapshotting for the ranks/steps specified in the lists
if flags.snapshot_step_list:
if flags.snapshot_rank_list == "all":
capture_ranks = [] # empty list means all ranks
else:
capture_ranks = [int(i) for i in flags.snapshot_rank_list.split(",")]
if capture_ranks == [] or xm.get_ordinal() in capture_ranks:
capture_steps = [int(i) for i in flags.snapshot_step_list.split(",")]
if is_pt21_plus:
print(f"Enabling snapshotting for rank{xm.get_ordinal()} and steps {capture_steps}")
def callback(name, addressable_device_index, execution_count):
if total_steps in capture_steps:
return 'inputs outputs'
else:
return ''
if is_pt21_plus:
import libneuronxla
libneuronxla.register_hlo_snapshot_callback(callback)

max_grad_norm = 1.0
running_loss_reduced_detached = torch.zeros(1, device=device)
for i, data in enumerate(train_loader):
@@ -504,6 +550,7 @@ def _print_logs(running_loss_reduced_detached, total_norm):
else:
chkpt_file = os.path.join(flags.output_dir, "ckpt_{}.pt".format(global_step))
files_info = [f] + files

print('Checkpointing...', flush=True)
model_to_save = model.module if hasattr(model, 'module') else model # unwrap model if needed (DDP)
if flags.minimal_ckpt:
@@ -526,32 +573,37 @@ def _print_logs(running_loss_reduced_detached, total_norm):
if os.path.isfile(old_file):
print('Keeping only {} checkpoints. Deleting {}'.format(flags.num_ckpts_to_keep, old_file))
os.remove(old_file)

if global_step >= flags.steps_this_run:
xm.rendezvous("before_throughput_check") # avoid multi-node hang due to throughput threshold assert by root worker
if is_root and not extract_graphs_only:
compile_time = 0.0
compile_time_file="compile_time.txt"
if os.path.exists(compile_time_file):
with open(compile_time_file, "r") as f:
compile_time = float(f.readline())
# record aggregate & final statistics in the metrics file
additional_data = {
"Epoch": epoch, "Global step": global_step, "Microstep": training_ustep
}
average_throughput = round(sum(logger.throughputs)/len(logger.throughputs), 4)
model_flops_utilization = extract_mfu(len(model.bert.encoder.layer), model.bert.config.hidden_size, train_dataloader.dataset.sequence_length, flags.batch_size, average_throughput, world_size)
metric_data = [
Metric("FinalLoss", final_loss, units="", additional=additional_data),
Metric("TimeToTrain", round(time_diff/60, 4), units="minutes", additional=additional_data),
Metric("Compile Time", compile_time, units="sec", additional=additional_data),
Metric("PeakThroughput", max(logger.throughputs), units="seq/s", additional=additional_data),
Metric("MFU", model_flops_utilization, units="%", additional=additional_data),
]
if(flags.expected_average_throughput > 0):
derived_expected_throughput = (0.95*flags.expected_average_throughput)
metric_data = [
Metric("FinalLoss", final_loss, units="", additional=additional_data),
Metric("TimeToTrain", round(time_diff/60, 4), units="minutes", additional=additional_data),
Metric("AverageThroughput", average_throughput, units="seq/s", expected=flags.expected_average_throughput, derived=(0.95*flags.expected_average_throughput) ,additional=additional_data),
Metric("PeakThroughput", max(logger.throughputs), units="seq/s", additional=additional_data)
]
metric_data.append(
Metric("AverageThroughput", average_throughput, units="seq/s", expected=flags.expected_average_throughput, derived=derived_expected_throughput, additional=additional_data))
post_metrics(metric_data, parameters=parameters)
assert( average_throughput >= derived_expected_throughput), "Average throughput :{} is below derived expected threshold: {}".format(average_throughput, derived_expected_throughput)
assert(average_throughput >= derived_expected_throughput), "Average throughput :{} is below derived expected threshold: {}".format(average_throughput, derived_expected_throughput)
else:

metric_data = [
Metric("FinalLoss", final_loss, units="", additional=additional_data),
Metric("TimeToTrain", round(time_diff/60, 4), units="minutes", additional=additional_data),
Metric("AverageThroughput", average_throughput, units="seq/s", additional=additional_data),
Metric("PeakThroughput", max(logger.throughputs), units="seq/s", additional=additional_data)
]
metric_data.append(
Metric("AverageThroughput", average_throughput, units="seq/s", additional=additional_data))
post_metrics(metric_data, parameters=parameters)
return
del train_device_loader
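Side note (my reading of the logic above, not text from the PR): when --expected_average_throughput is set, the root worker posts the metrics and then asserts that the measured average throughput is at least 95% of the expected value; for example, an expected throughput of 10000 seq/s tolerates anything down to 9500 seq/s before the assert fails the run.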
@@ -611,6 +663,11 @@ def _mp_fn(index, flags):
parser.add_argument('--phase2', default=False, action='store_true', help="Whether to train with seq len 512")
parser.add_argument('--print_grad_norm', default=False, action='store_true', help="Whether to print grad norm")
parser.add_argument('--expected_average_throughput', type=float, default=0.0, help="Expected average throughput")
parser.add_argument('--disable_dropout', default=False, action='store_true', help="Disable dropout")
parser.add_argument("--snapshot_step_list", default=None, help="comma separated list of steps to take snapshot; also used to enable snapshotting with dropout disabled (WARNNG: can take lots of disk space, esp with grad accum.)")
parser.add_argument("--snapshot_rank_list", default="0", help="comma separated list of ranks to take snapshot, or 'all' for all ranks (WARNNG: can take lots of disk space, esp with grad accum.)")
parser.add_argument("--snapshot_dump_dir", default="./snapshots", help="directory to dump snapshots; snapshot_step_list must be specified")

args = parser.parse_args(sys.argv[1:])

if args.steps_this_run < 0:
@@ -619,6 +676,18 @@ def _mp_fn(index, flags):
if args.enable_pt_autocast:
os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"


# Enable HLO snapshot dump before device init (first use of 'xla' device)
# Will do more fine-grained enablement in the training function to track global step
if args.snapshot_step_list:
if is_pt21_plus:
print("Enabling snapshotting in dir: ", args.snapshot_dump_dir)
os.environ["XLA_FLAGS"] = f"--xla_dump_hlo_snapshots --xla_dump_to={args.snapshot_dump_dir}"
elif is_pt20:
print("WARNING: Snapshotting is not enabled for torch-neuronx 2.0beta; snapshot options are ignored.")
else:
print("WARNING: For torch-neuronx 1.13, please follow instructions in documentation to enable snapshotting.")

# WORLD_SIZE is set by torchrun
if os.environ.get("WORLD_SIZE"):
init_process_group()
12 changes: 6 additions & 6 deletions torch-neuronx/training/dp_bert_hf_pretrain/requirements.txt
@@ -1,12 +1,12 @@
graphviz
tensorboard==2.6
transformers==4.26.0
tensorboard==2.14
transformers==4.44.0
evaluate
pillow
pytest
accelerate
datasets >= 1.8.0
sentencepiece != 0.1.92
datasets==2.19.1
sentencepiece==0.2.0
h5py
requests
huggingface-hub<0.23
requests==2.31.0
huggingface-hub==0.24.5
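Side note (my summary, not from the PR description): the version-sensitive packages (tensorboard, transformers, datasets, sentencepiece, requests, huggingface-hub) are now pinned to exact versions rather than ranges, so a plain pip install -r requirements.txt in the existing torch-neuronx environment yields a consistent set (transformers 4.44.0, datasets 2.19.1, huggingface-hub 0.24.5, etc.).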
@@ -52,7 +52,11 @@ if [ ! -z "$SLURM_NTASKS" ]; then
CACHE_DIR=$HOME/neuron_cache/bert/`hostname`
export NEURON_CC_FLAGS="--cache_dir=$CACHE_DIR"
fi
export TRANSFORMERS_CACHE=$HOME/hf_cache/`hostname`/hub
export HF_HOME=/tmp/hf_cache/
mkdir -p $HF_HOME
if [ -e $HOME/.cache/huggingface ]; then
rsync -av $HOME/.cache/huggingface/ $HF_HOME
fi
# HF ver > 4.22: move the cache ahead of time to prevent multiple workers from moving it at the same time
python -c "import transformers.utils as utils; utils.move_cache()"
fi
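Side note (my reading of this hunk, not text from the PR): HF_HOME replaces the per-hostname TRANSFORMERS_CACHE export used previously, the rsync seeds the node-local /tmp cache from any existing ~/.cache/huggingface contents so workers do not re-download the model, and the up-front transformers.utils.move_cache() call migrates the pre-4.22 cache layout once instead of letting multiple workers race to do it on first use. The same change is applied to each of the remaining run scripts below.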
@@ -53,7 +53,11 @@ if [ ! -z "$SLURM_NTASKS" ]; then
CACHE_DIR=$HOME/neuron_cache/bert/`hostname`
export NEURON_CC_FLAGS="--cache_dir=$CACHE_DIR"
fi
export TRANSFORMERS_CACHE=$HOME/hf_cache/`hostname`/hub
export HF_HOME=/tmp/hf_cache/
mkdir -p $HF_HOME
if [ -e $HOME/.cache/huggingface ]; then
rsync -av $HOME/.cache/huggingface/ $HF_HOME
fi
# HF ver > 4.22: move the cache ahead of time to prevent multiple workers from moving it at the same time
python -c "import transformers.utils as utils; utils.move_cache()"
fi
@@ -60,7 +60,11 @@ if [ ! -z "$SLURM_NTASKS" ]; then
CACHE_DIR=$HOME/neuron_cache/bert/`hostname`
export NEURON_CC_FLAGS="--cache_dir=$CACHE_DIR"
fi
export TRANSFORMERS_CACHE=$HOME/hf_cache/`hostname`/hub
export HF_HOME=/tmp/hf_cache/
mkdir -p $HF_HOME
if [ -e $HOME/.cache/huggingface ]; then
rsync -av $HOME/.cache/huggingface/ $HF_HOME
fi
# HF ver > 4.22: move the cache ahead of time to prevent multiple workers from moving it at the same time
python -c "import transformers.utils as utils; utils.move_cache()"
fi
@@ -60,7 +60,11 @@ if [ ! -z "$SLURM_NTASKS" ]; then
CACHE_DIR=$HOME/neuron_cache/bert/`hostname`
export NEURON_CC_FLAGS="--cache_dir=$CACHE_DIR"
fi
export TRANSFORMERS_CACHE=$HOME/hf_cache/`hostname`/hub
export HF_HOME=/tmp/hf_cache/
mkdir -p $HF_HOME
if [ -e $HOME/.cache/huggingface ]; then
rsync -av $HOME/.cache/huggingface/ $HF_HOME
fi
# HF ver > 4.22: move the cache ahead of time to prevent multiple workers from moving it at the same time
python -c "import transformers.utils as utils; utils.move_cache()"
fi