# User Guide: Running HuggingFace Llama Training on Cloud TPUs

This user guide provides a concise overview of the essential steps required to run HuggingFace (HF) Llama training on Cloud TPUs.

## Environment Setup

The following setup assumes the training job runs Llama 2 7B on GCE TPUs. Please follow the corresponding TPU generation's user guide to set up the GCE TPUs. For GKE users, most of the commands below also apply.

### Set Up the Environment of Your TPUs
Replace all `your-*` placeholders with your TPUs' information.
```
export TPU_NAME=your-tpu-name
export ZONE=your-tpu-zone
export PROJECT=your-tpu-project
```
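
Optionally, you can sanity-check these values before proceeding by describing the TPU VM they point at:
```
gcloud compute tpus tpu-vm describe ${TPU_NAME} --zone ${ZONE} --project ${PROJECT}
```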

### HF Llama 2 7B Environment Setup

Here both PyTorch and PyTorch/XLA nightly builds are used, along with our fork of HuggingFace Transformers.
```
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--zone ${ZONE} \
--project ${PROJECT} \
--worker=all \
--command='
# Step 1: install torch, torch-xla, libtpu, pallas
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl
pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
pip3 install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

# Step 2: install HF
git clone -b alanwaketan/flash_attention https://github.com/pytorch-tpu/transformers.git
cd transformers
pip3 install git+file://$PWD
pip3 install accelerate datasets evaluate scikit-learn huggingface-hub
'
```
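
As a quick, optional sanity check, confirm that the nightly wheels import cleanly on every worker:
```
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--zone ${ZONE} \
--project ${PROJECT} \
--worker=all \
--command='python3 -c "import torch, torch_xla; print(torch.__version__, torch_xla.__version__)"'
```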

The next step is to sign in to HF so that you can access the tokenizer and model checkpoints. Please replace `your_token` with your HF token.
```
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--zone ${ZONE} \
--project ${PROJECT} \
--worker=all \
--command='
export PATH=$PATH:/home/$USER/.local/bin
huggingface-cli login --token your_token
'
```
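
To confirm the login took effect on every worker, `huggingface-cli whoami` should print your account name:
```
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--zone ${ZONE} \
--project ${PROJECT} \
--worker=all \
--command='
export PATH=$PATH:/home/$USER/.local/bin
huggingface-cli whoami
'
```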

The next step for HF setup is to copy your [Llama config](https://huggingface.co/meta-llama/Llama-2-7b-hf) into the TPU VM.
```
gcloud compute tpus tpu-vm scp Llama7B.json $TPU_NAME:~/config.json --worker all --project $PROJECT --zone=$ZONE
```
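
Here `Llama7B.json` is just the model's `config.json`. For reference, a local copy sketched from the published Llama 2 7B config might look like the following (abridged from memory; verify the values against the model card before use):
```
{
  "architectures": ["LlamaForCausalLM"],
  "model_type": "llama",
  "hidden_size": 4096,
  "intermediate_size": 11008,
  "num_hidden_layers": 32,
  "num_attention_heads": 32,
  "num_key_value_heads": 32,
  "max_position_embeddings": 4096,
  "rms_norm_eps": 1e-05,
  "hidden_act": "silu",
  "vocab_size": 32000,
  "bos_token_id": 1,
  "eos_token_id": 2
}
```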

The last step for HF setup is to copy your `fsdp_config.json` into the TPU VM. This config wraps each `LlamaDecoderLayer` as an FSDP unit, enables the SPMD-based FSDPv2 implementation, and turns on gradient checkpointing.
```
{
    "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
    ],
    "xla": true,
    "xla_fsdp_v2": true,
    "xla_fsdp_grad_ckpt": true
}
```
And the command to copy the config:
```
gcloud compute tpus tpu-vm scp fsdp_config.json $TPU_NAME:~/fsdp_config.json --worker all --project $PROJECT --zone=$ZONE
```

## Steps to Run HF Llama 2 7B
The following is the gcloud ssh command to run the training job from the host:
```
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--zone ${ZONE} \
--project ${PROJECT} \
--worker=all \
--command='
# Setup envs
export PJRT_DEVICE=TPU
export XLA_USE_SPMD=1
export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1

export PROFILE_EPOCH=0
export PROFILE_STEP=3
export PROFILE_DURATION_MS=20000
export PROFILE_LOGDIR=/tmp/home/

# Run
cd transformers
python3 examples/pytorch/language-modeling/run_clm.py \
  --dataset_name wikitext \
  --dataset_config_name wikitext-2-raw-v1 \
  --per_device_train_batch_size 8 \
  --do_train \
  --output_dir /tmp/test-clm \
  --overwrite_output_dir \
  --config_name ~/config.json \
  --cache_dir /tmp \
  --tokenizer_name meta-llama/Llama-2-7b-hf \
  --block_size 4096 \
  --optim adafactor \
  --save_strategy no \
  --logging_strategy no \
  --fsdp "full_shard" \
  --fsdp_config ~/fsdp_config.json \
  --torch_dtype bfloat16 \
  --dataloader_drop_last yes \
  --flash_attention \
  --max_steps 5
'
```

### Environment Variables Explained

* `PJRT_DEVICE`: Specify the XLA device.
* `XLA_USE_SPMD`: Turn on GSPMD.
* `XLA_IR_DEBUG`: Capture the Python stack trace in Lazy IRs.
* `XLA_HLO_DEBUG`: Capture the Python stack trace in HLOs.
* `PROFILE_EPOCH`: Specify in which epoch to start taking the profile.
* `PROFILE_STEP`: Specify at which step to start taking the profile.
* `PROFILE_DURATION_MS`: Specify how long the profiling will last.
* `PROFILE_LOGDIR`: Specify where to put the profiling results.
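
For context, here is a minimal sketch of how a training script might consume the `PROFILE_*` variables, assuming the `torch_xla.debug.profiler` API (`xp.start_server` / `xp.trace`); the fork's `run_clm.py` has its own implementation of this hook:
```
import os
import threading

import torch_xla.debug.profiler as xp

# Start the profiler server once at startup (the port is arbitrary).
server = xp.start_server(9012)

PROFILE_EPOCH = int(os.environ.get('PROFILE_EPOCH', -1))
PROFILE_STEP = int(os.environ.get('PROFILE_STEP', -1))
PROFILE_DURATION_MS = int(os.environ.get('PROFILE_DURATION_MS', 20000))
PROFILE_LOGDIR = os.environ.get('PROFILE_LOGDIR', '/tmp/home/')

def maybe_capture_profile(epoch, step):
    # Kick off a one-shot, non-blocking trace once the target
    # epoch/step is reached; results land in PROFILE_LOGDIR.
    if epoch == PROFILE_EPOCH and step == PROFILE_STEP:
        threading.Thread(
            target=xp.trace,
            args=('localhost:9012', PROFILE_LOGDIR),
            kwargs={'duration_ms': PROFILE_DURATION_MS},
        ).start()
```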

### HF Llama Arguments Explained

* `--flash_attention`: [bool] Enable Pallas FlashAttention. Default: False.
* `--per_device_train_batch_size`: [int] Specify the global batch size. GSPMD treats the program as a single-device program, so this value is not multiplied by the number of devices.
* `--num_train_epochs`: [int] Specify the total number of epochs. If the total number of steps is too large, try setting `--max_steps` instead to speed up the experiment.

## How to measure the step time?
A profile will be captured in `/tmp/home/`. Just use TensorBoard to open the profile and measure the step time from the "Trace View."
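
For example, assuming TensorBoard and its profile plugin are installed where you inspect the results (locally, or on the VM with port forwarding):
```
pip3 install tensorboard tensorboard-plugin-profile
tensorboard --logdir /tmp/home/ --port 6006
```
Then open `http://localhost:6006`, navigate to the Profile tab, and select the trace viewer.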