@@ -61,15 +61,22 @@ def cleanup():
     dist.destroy_process_group()


-def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
+def get_model(
+    K, N, is_fp8, emulate, base_dtype=torch.float32, recompute_weight_cast: bool = False
+):
     m = nn.Sequential(
         nn.Linear(K, N, dtype=base_dtype),
         nn.ReLU(),
         nn.Linear(N, N, dtype=base_dtype),
         nn.ReLU(),
     )
     if is_fp8:
-        swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
+        swap_linear_with_float8_linear(
+            m,
+            Float8Linear,
+            emulate=emulate,
+            recompute_weight_cast=recompute_weight_cast,
+        )
     return m


@@ -81,10 +88,15 @@ def fsdp_main(rank, world_size, args):

     # TODO: We set fullgraph as an option. However, it currently doesn't work for fullgraph compile.
     # We can investigate and fix it later.
-    is_fp8, emulate, base_dtype, compile, fullgraph = args
-    model = get_model(K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype).to(
-        rank
-    )
+    is_fp8, emulate, base_dtype, compile, fullgraph, recompute_weight_cast = args
+    model = get_model(
+        K,
+        N,
+        is_fp8=is_fp8,
+        emulate=emulate,
+        base_dtype=base_dtype,
+        recompute_weight_cast=recompute_weight_cast,
+    ).to(rank)
     model.load_state_dict(torch.load(sd_in_fname))
     # To compile FSDP, we need to set use_orig_params to True
     model = FSDP(model, use_orig_params=True)
@@ -148,7 +160,13 @@ def forward_backward(model):
     cleanup()


-def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = False):
+def run(
+    mode: str,
+    is_fp8: bool,
+    compile_fsdp: bool = False,
+    fullgraph: bool = False,
+    recompute_weight_cast: bool = False,
+):
     print(f"Mode: {mode}".center(100, "-"))
     base_dtype = torch.bfloat16
     if not os.path.exists(data_dir):
@@ -169,15 +187,25 @@ def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = F
         # generate reference input
         ref_input = torch.randn(B, M, K).cuda().to(base_dtype)
         model = get_model(
-            K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype
+            K,
+            N,
+            is_fp8=is_fp8,
+            emulate=emulate,
+            base_dtype=base_dtype,
+            recompute_weight_cast=recompute_weight_cast,
         ).cuda()
         torch.save(ref_input, input_fname)
         torch.save(model.state_dict(), sd_in_fname)

     elif mode == "single_gpu":
         ref_input = torch.load(input_fname).to(base_dtype)
         model = get_model(
-            K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype
+            K,
+            N,
+            is_fp8=is_fp8,
+            emulate=emulate,
+            base_dtype=base_dtype,
+            recompute_weight_cast=recompute_weight_cast,
         ).cuda()
         model.load_state_dict(torch.load(sd_in_fname))
         optimizer = torch.optim.SGD(model.parameters(), lr=lr)
@@ -199,7 +227,14 @@ def forward_backward():
     elif mode == "fsdp":
         WORLD_SIZE = torch.cuda.device_count()
         # We only compile for fsdp, and compare the numerics with single-gpu no-compile
-        args = (is_fp8, emulate, base_dtype, compile_fsdp, fullgraph)
+        args = (
+            is_fp8,
+            emulate,
+            base_dtype,
+            compile_fsdp,
+            fullgraph,
+            recompute_weight_cast,
+        )
         mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)

     elif mode == "analyze":
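
A minimal usage sketch for the new flag, assuming the test is driven by calling run() directly (the run() signature and the "single_gpu"/"fsdp" mode names are taken from the diff above; the concrete argument values below are illustrative assumptions, not part of the change):

    # Hypothetical driver; argument values are assumptions for illustration.
    # Produce the single-GPU reference with weight-cast recomputation enabled,
    # then run the FSDP path with the same flag to compare numerics.
    run("single_gpu", is_fp8=True, recompute_weight_cast=True)
    run("fsdp", is_fp8=True, compile_fsdp=True, recompute_weight_cast=True)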