Commit cae5c3c: Add user guide (USER_GUIDE.md, +144 lines)

# User Guide: Running HuggingFace Llama Training on Cloud TPUs

This user guide provides a concise overview of the essential steps required to run HuggingFace (HF) Llama training on Cloud TPUs.

## Environment Setup

The following setup assumes the training job runs Llama 2 7B on GCE TPUs. Please follow the corresponding TPU generation's user guide to set up the GCE TPUs. For GKE users, most of the commands below also apply.

### Setup Environment of Your TPUs
Please replace all `your-*` placeholders with your TPUs' information.
```
export TPU_NAME=your-tpu-name
export ZONE=your-tpu-zone
export PROJECT=your-tpu-project
```
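
Before installing anything, you can optionally confirm that these values point at a reachable TPU VM. This is just a sanity check and assumes the gcloud CLI is installed and authenticated on your workstation:
```
# Describe the TPU VM; the output should show its state and topology.
gcloud compute tpus tpu-vm describe ${TPU_NAME} \
  --zone ${ZONE} \
  --project ${PROJECT}
```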

### HF Llama 2 7B Environment Setup

Here, both the PyTorch and PyTorch/XLA nightly builds are used, together with our fork of HuggingFace transformers.
```
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone ${ZONE} \
  --project ${PROJECT} \
  --worker=all \
  --command='
# Step 1: install torch, torch-xla, libtpu, pallas
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl
pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
pip3 install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

# Step 2: install HF
git clone -b alanwaketan/flash_attention https://github.com/pytorch-tpu/transformers.git
cd transformers
pip3 install git+file://$PWD
pip3 install accelerate datasets evaluate scikit-learn huggingface-hub
'
```

The next step is to sign into HF so that you can access the tokenizer and model checkpoints. Please replace `your_token` with your HF token.
```
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone ${ZONE} \
  --project ${PROJECT} \
  --worker=all \
  --command='
export PATH=$PATH:/home/$USER/.local/bin
huggingface-cli login --token your_token
'
```
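
If you want to verify that the login succeeded on every worker, a quick optional check (reusing the same `PATH` export as above) is:
```
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone ${ZONE} \
  --project ${PROJECT} \
  --worker=all \
  --command='
export PATH=$PATH:/home/$USER/.local/bin
# Prints the account the token is associated with on each worker.
huggingface-cli whoami
'
```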

The next step for HF setup is to copy your [Llama config](https://huggingface.co/meta-llama/Llama-2-7b-hf) into the TPU VM.
```
gcloud compute tpus tpu-vm scp Llama7B.json $TPU_NAME:~/config.json --worker all --project $PROJECT --zone=$ZONE
```
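
If you do not have the config locally yet, one possible way to fetch it on your workstation is sketched below. This assumes `huggingface_hub` is installed locally and your account has been granted access to the gated Llama 2 repository; the `Llama7B.json` filename is simply chosen to match the `scp` command above.
```
# Download only config.json from the Llama 2 7B repo into the current directory,
# then rename it to match the scp command above.
huggingface-cli download meta-llama/Llama-2-7b-hf config.json --local-dir .
mv config.json Llama7B.json
```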

The last step for HF setup is to create the following `fsdp_config.json` and copy it into the TPU VM.
```
{
  "fsdp_transformer_layer_cls_to_wrap": [
    "LlamaDecoderLayer"
  ],
  "xla": true,
  "xla_fsdp_v2": true,
  "xla_fsdp_grad_ckpt": true
}
```
Use the following command to copy the config to all workers:
```
gcloud compute tpus tpu-vm scp fsdp_config.json $TPU_NAME:~/fsdp_config.json --worker all --project $PROJECT --zone=$ZONE
```
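
Optionally, you can confirm that both files landed on every worker before kicking off training (a simple sanity check against the paths used above):
```
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone ${ZONE} \
  --project ${PROJECT} \
  --worker=all \
  --command='ls -l ~/config.json ~/fsdp_config.json'
```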

## Steps to Run HF Llama 2 7B

The following gcloud ssh command runs the training job from the host:
```
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone ${ZONE} \
  --project ${PROJECT} \
  --worker=all \
  --command='
# Setup envs
export PJRT_DEVICE=TPU
export XLA_USE_SPMD=1
export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1

export PROFILE_EPOCH=0
export PROFILE_STEP=3
export PROFILE_DURATION_MS=20000
export PROFILE_LOGDIR=/tmp/home/

# Run
cd transformers
python3 examples/pytorch/language-modeling/run_clm.py \
  --dataset_name wikitext \
  --dataset_config_name wikitext-2-raw-v1 \
  --per_device_train_batch_size 8 \
  --do_train \
  --output_dir /tmp/test-clm \
  --overwrite_output_dir \
  --config_name ~/config.json \
  --cache_dir /tmp \
  --tokenizer_name meta-llama/Llama-2-7b-hf \
  --block_size 4096 \
  --optim adafactor \
  --save_strategy no \
  --logging_strategy no \
  --fsdp "full_shard" \
  --fsdp_config ~/fsdp_config.json \
  --torch_dtype bfloat16 \
  --dataloader_drop_last yes \
  --flash_attention \
  --max_steps 5
'
```

### Environment Variables Explained

* `PJRT_DEVICE`: Specify the XLA device.
* `XLA_USE_SPMD`: Turn on GSPMD.
* `XLA_IR_DEBUG`: Capture Python stack trace in Lazy IRs.
* `XLA_HLO_DEBUG`: Capture Python stack trace in HLOs.
* `PROFILE_EPOCH`: Specify which epoch to start taking the profile.
* `PROFILE_STEP`: Specify which step to start taking the profile.
* `PROFILE_DURATION_MS`: Specify how long the profiling will last.
* `PROFILE_LOGDIR`: Specify where to put the profiling results.

### HF Llama Arguments Explained

* `--flash_attention`: [bool] Enable Pallas FlashAttention. Default: False.
* `--per_device_train_batch_size`: [int] Specify the global batch size. GSPMD treats the program as a single-device program, so `--per_device_train_batch_size 8` yields a global batch size of 8 regardless of how many TPU chips are in the slice.
* `--num_train_epochs`: [int] Specify the total number of epochs. If the total number of steps is too large, try setting `--max_steps` instead to speed up the experiment.

## How to measure the step time?
A profile will be captured in `/tmp/home/` on the TPU VM. Use TensorBoard to open the profile and measure the step time from the "Trace View."
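
A minimal sketch of how this might look, assuming you copy the profile from worker 0 to your local machine and have TensorBoard plus the profiler plugin installed there (`./profile` is just a hypothetical local directory):
```
# Copy the captured profile off the TPU VM (worker 0) into a local directory.
gcloud compute tpus tpu-vm scp --recurse ${TPU_NAME}:/tmp/home/ ./profile --worker=0 --project ${PROJECT} --zone=${ZONE}

# Install TensorBoard and the profiler plugin, then point TensorBoard at the logs;
# the step time can be read from the "Trace View" under the Profile tab.
pip3 install tensorboard tensorboard-plugin-profile
tensorboard --logdir ./profile --port 6006
```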
