This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit c40de9b

y-sq authored and facebook-github-bot committed
Compile test fsdp (#160)
Summary:
Added an option in test_fsdp.py to compile FSDP. With compile mode, the numerics check still passes. However, note that compile currently only works with a workaround, and fullgraph needs to be False; we still need to fix the underlying issue.

When running "./test/test_fsdp.sh", three settings will be tested:
1. Fp8 = False
2. Fp8 = True, Compile = False
3. Fp8 = True, Compile = True (with fullgraph = False)

For example:
```
$ ./test/test_fsdp.sh
launching IS_FP8 False, compile_fsdp False, fullgraph False
-------------------------------------------Mode: generate-------------------------------------------
Success: ✅
------------------------------------------Mode: single_gpu------------------------------------------
Success: ✅
---------------------------------------------Mode: fsdp---------------------------------------------
NCCL version 2.19.3+cuda12.1
-------------------------------------------Mode: analyze--------------------------------------------
output testing single_gpu vs FSDP success
state dict testing single_gpu vs FSDP success
Success: ✅
✅ All Tests Passed ✅

launching IS_FP8 True, compile_fsdp False, fullgraph False
-------------------------------------------Mode: generate-------------------------------------------
Success: ✅
------------------------------------------Mode: single_gpu------------------------------------------
Success: ✅
---------------------------------------------Mode: fsdp---------------------------------------------
NCCL version 2.19.3+cuda12.1
-------------------------------------------Mode: analyze--------------------------------------------
output testing single_gpu vs FSDP success
state dict testing single_gpu vs FSDP success
Success: ✅
✅ All Tests Passed ✅

launching IS_FP8 True, compile_fsdp True, fullgraph False
-------------------------------------------Mode: generate-------------------------------------------
Success: ✅
------------------------------------------Mode: single_gpu------------------------------------------
Success: ✅
---------------------------------------------Mode: fsdp---------------------------------------------
NCCL version 2.19.3+cuda12.1
[rank0]:[2023-12-15 14:49:02,616] [0/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2023-12-15 14:49:02,618] [0/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2023-12-15 14:49:02,706] [0/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2023-12-15 14:49:02,708] [0/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
-------------------------------------------Mode: analyze--------------------------------------------
output testing single_gpu vs FSDP success
state dict testing single_gpu vs FSDP success
Success: ✅
✅ All Tests Passed ✅
```

Pull Request resolved: #160
Reviewed By: vkuzo
Differential Revision: D52224302
Pulled By: y-sq
fbshipit-source-id: 4c29479771f4cd100b8c5a9549d321eb13b49739
1 parent b41006b commit c40de9b
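For context, the workaround described in the summary can be condensed into a short sketch. This is not the test itself, only an illustration of the pattern added in test/test_fsdp.py: wrap the float8 model in FSDP with use_orig_params=True, compile the amax/scale sync function separately, run the first iteration eagerly so the float8 layers initialize their amax history, and only then call torch.compile on the model with fullgraph=False. The toy model, optimizer, and input below are assumptions for illustration; the real test builds its model via get_model and loads reference inputs from disk.

```python
# Minimal sketch of the compile-FSDP workaround (assumes torch.distributed is
# already initialized and float8_experimental is installed; import paths follow
# the repository layout at the time of this commit and may have moved since).
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)


def train_fp8_fsdp_compiled(rank: int, n_iter: int = 5, fullgraph: bool = False):
    # Toy model standing in for get_model() in test/test_fsdp.py.
    model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32)).to(rank)
    swap_linear_with_float8_linear(model, Float8Linear, emulate=False)

    # use_orig_params=True is required to compile an FSDP-wrapped module;
    # FSDP(torch.compile(model)) does not work yet (see the TODO in the diff).
    model = FSDP(model, use_orig_params=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # The amax/scale sync pass is compiled separately from the model.
    sync_func = torch.compile(sync_float8_amax_and_scale_history, fullgraph=fullgraph)

    for it in range(n_iter):
        if it == 1:
            # Workaround: iteration 0 runs eagerly so the float8 layers leave
            # their "is_amax_initialized == False" branches; compile only
            # afterwards, and only with fullgraph=False for now.
            model = torch.compile(model, fullgraph=fullgraph)
        optimizer.zero_grad()
        y = model(torch.randn(4, 32, device=rank))
        y.sum().backward()
        sync_func(model)
        optimizer.step()
    return y
```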

File tree

2 files changed (+46, -16 lines)


test/test_fsdp.py

Lines changed: 37 additions & 10 deletions
@@ -46,8 +46,7 @@
 
 B, M, K, N = 8, 8, 32, 32
 lr = 0.01
-N_ITER = 3
-N_ITER = 1
+N_ITER = 5
 
 
 def setup(rank, world_size):
@@ -65,7 +64,9 @@ def cleanup():
 def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
     m = nn.Sequential(
         nn.Linear(K, N, dtype=base_dtype),
+        nn.ReLU(),
         nn.Linear(N, N, dtype=base_dtype),
+        nn.ReLU(),
     )
     if is_fp8:
         swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
@@ -78,31 +79,52 @@ def fsdp_main(rank, world_size, args):
     setup(rank, world_size)
     torch.cuda.set_device(rank)
 
-    is_fp8, emulate, base_dtype = args
+    # TODO: We set fullgraph as an option. However, fullgraph compile currently doesn't work.
+    # We can investigate and fix it later.
+    is_fp8, emulate, base_dtype, compile, fullgraph = args
     model = get_model(K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype).to(
         rank
     )
     model.load_state_dict(torch.load(sd_in_fname))
-    model = FSDP(model)
+    # To compile FSDP, we need to set use_orig_params to True
+    model = FSDP(model, use_orig_params=True)
+    # TODO: The following line doesn't work. We should fix it.
+    # model = FSDP(torch.compile(model), use_orig_params=True)
+
     # Note: we need to multiply by world_size here to match single GPU
     # optimizer update
     optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size)
 
     ref_input_global = torch.load(input_fname).to(base_dtype)
 
     # basic distributed data sampling
-    bsz_global = ref_input_global.shape[0]
     assert B % world_size == 0
     bsz_local_start = int(rank / world_size * B)
     bsz_local_end = int((rank + 1) / world_size * B)
     ref_input_local = ref_input_global[bsz_local_start:bsz_local_end].to(rank)
 
-    for _ in range(N_ITER):
+    sync_float8_func = sync_float8_amax_and_scale_history
+    if compile:
+        sync_float8_func = torch.compile(
+            sync_float8_amax_and_scale_history, fullgraph=fullgraph
+        )
+
+    def forward_backward(model):
         optimizer.zero_grad()
         y_local = model(ref_input_local)
         y_local.sum().backward()
-        sync_float8_amax_and_scale_history(model)
+        sync_float8_func(model)
         optimizer.step()
+        return y_local
+
+    for iter in range(N_ITER):
+        # We first run one iteration without compile, as a workaround to compile the float8 layers.
+        # In the first iter, float8 layers go to the branches of "self.is_amax_initialized == False";
+        # after that, float8 layers go to the branches of "self.is_amax_initialized == True".
+        # TODO: Need to fix compile to run without this workaround.
+        if iter == 1 and compile:
+            model = torch.compile(model, fullgraph=fullgraph)
+        y_local = forward_backward(model)
 
     # get global y
    y_global = [
@@ -126,7 +148,7 @@ def fsdp_main(rank, world_size, args):
     cleanup()
 
 
-def run(mode: str, is_fp8: bool):
+def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = False):
     print(f"Mode: {mode}".center(100, "-"))
     base_dtype = torch.bfloat16
     if not os.path.exists(data_dir):
@@ -160,19 +182,24 @@ def run(mode: str, is_fp8: bool):
         model.load_state_dict(torch.load(sd_in_fname))
         optimizer = torch.optim.SGD(model.parameters(), lr=lr)
 
-        for _ in range(N_ITER):
+        def forward_backward():
             optimizer.zero_grad()
             y = model(ref_input)
             y.sum().backward()
             sync_float8_amax_and_scale_history(model)
             optimizer.step()
+            return y
+
+        for _ in range(N_ITER):
+            y = forward_backward()
 
         torch.save(y, output_single_gpu_fname)
         torch.save(model.state_dict(), sd_out_single_gpu_fname)
 
     elif mode == "fsdp":
         WORLD_SIZE = torch.cuda.device_count()
-        args = (is_fp8, emulate, base_dtype)
+        # We only compile for fsdp, and compare the numerics with single-gpu no-compile
+        args = (is_fp8, emulate, base_dtype, compile_fsdp, fullgraph)
         mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
 
     elif mode == "analyze":
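To make the flag plumbing above concrete, the new run() signature can also be driven directly from Python instead of through the CLI. The snippet below is only a hypothetical driver for the "fp8 + compile" setting; the import assumes the repository root is on sys.path and that test/ is importable, while the real tests call test/test_fsdp.py via the --mode/--is_fp8/--compile_fsdp/--fullgraph flags used in test/test_fsdp.sh below.

```python
# Hypothetical Python driver for the "fp8 + compile" setting; the real tests
# go through the CLI flags exercised by test/test_fsdp.sh.
from test.test_fsdp import run  # assumes the repo root is on sys.path

for mode in ("generate", "single_gpu", "fsdp", "analyze"):
    # compile_fsdp only affects the fsdp mode; the single_gpu baseline stays
    # eager, so analyze compares compiled-FSDP numerics against no-compile.
    run(mode, is_fp8=True, compile_fsdp=True, fullgraph=False)
```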

test/test_fsdp.sh

Lines changed: 9 additions & 6 deletions
@@ -4,14 +4,14 @@
 set -e
 
 launch() {
-  echo "launching IS_FP8 $IS_FP8"
+  echo "launching IS_FP8 $IS_FP8, compile_fsdp $COMPILE, fullgraph $FULLGRAPH"
 
   # generate the test data
-  python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8
+  python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
   echo "Success: ✅"
 
   # generate single GPU model output and updated state dict
-  python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8
+  python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
   echo "Success: ✅"
 
   # generate FSDP model output and updated state dict
@@ -20,16 +20,19 @@ launch() {
   # the NCCL_NET setting is to work around transient issues on a
   # specific host (`devgpu001.nha2`)
   NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 NCCL_NET=SOCKET python test/test_fsdp.py \
-    --mode fsdp --is_fp8 $IS_FP8
+    --mode fsdp --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
 
   # compare the outputs and state dicts and verify equivalence
-  python test/test_fsdp.py --mode analyze --is_fp8 $IS_FP8
+  python test/test_fsdp.py --mode analyze --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
   echo "Success: ✅"
 
   echo "✅ All Tests Passed ✅"
 }
 
-for IS_FP8 in False True
+# IS_FP8, COMPILE, FULLGRAPH
+for i in False,False,False True,False,False True,True,False
 do
+  IFS=","; set -- $i;
+  IS_FP8=$1; COMPILE=$2; FULLGRAPH=$3
   launch
 done
