 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+
+import copy
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass, field
@@ -12,15 +14,30 @@
 import fire

 import torch
+from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     get_float8_linear,
     linear_requires_sync,
     LinearType,
+    swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
 from torch.profiler import profile, ProfilerActivity, record_function


+class LNLinear(torch.nn.Module):
+    def __init__(self, fc_dim1, fc_dim2):
+        super().__init__()
+        self.ln = torch.nn.LayerNorm(fc_dim1, elementwise_affine=False)
+        self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False)
+
+    def forward(self, x):
+        x = self.ln(x)
+        x = self.fc(x)
+        return x
+
+
 @dataclass
 class ProfileConfig:
     file_path: Optional[str] = None
@@ -77,65 +94,58 @@ def profile_function(


 @dataclass(frozen=True)
-class LinearParams:
+class ModelParams:
     M: int
     K: int
     N: int
-    input_bias: bool
     ref_dtype: torch.dtype
     layer_norm: bool = True
-    torch_compile: Optional[bool] = False


-def main(profile_path: Path, compile: bool, linear_type: str):
-    profile_path = Path(profile_path)
-    assert profile_path.is_dir(), f"Path {profile_path} must be a directory"
-    params = LinearParams(
+def main(
+    profile_path_prefix: Path,
+    compile: bool = True,
+    linear_type: str = "dynamic",
+    use_layer_norm: bool = False,
+):
+    params = ModelParams(
         M=4 * 4096,
         K=8192,
         N=7168,
-        input_bias=False,
         ref_dtype=torch.bfloat16,
-        layer_norm=True,
-        torch_compile=compile,
+        layer_norm=use_layer_norm,
     )
     print(f"Compile is set to | {compile}")
     print(f"Using Linear type: | {linear_type}")
     print(f"Use layer norm is set to | {params.layer_norm}")
-    linear_ref = torch.nn.Linear(
-        params.K,
-        params.N,
-        bias=params.input_bias,
-        device="cuda",
-        dtype=params.ref_dtype,
-    )
+
+    device = "cuda"
+    if params.layer_norm:
+        m_ref = LNLinear(params.K, params.N)
+    else:
+        m_ref = torch.nn.Sequential(
+            torch.nn.Linear(params.K, params.N, bias=False),
+        )
+    m_ref = m_ref.to(device).to(params.ref_dtype)
+
     linear_type = LinearType[linear_type.upper()]
-    linear_float8 = get_float8_linear(linear_type, linear_ref)
+    linear_cls = (
+        Float8Linear if linear_type is LinearType.DELAYED else Float8DynamicLinear
+    )
+
+    m_float8 = copy.deepcopy(m_ref)
+    swap_linear_with_float8_linear(m_float8, linear_cls)

     input_tensor = torch.randn(
         params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True
     )

-    if params.layer_norm:
-        ln = torch.nn.LayerNorm(
-            params.K, elementwise_affine=False, device="cuda", dtype=params.ref_dtype
-        )
-
     def ref_forw_backward(x):
-        if params.layer_norm:
-            with record_function("layer_norm"):
-                x = ln(x)
-        with record_function("forward"):
-            out = linear_ref(x)
-        with record_function("backward"):
-            out.sum().backward()
+        out = m_ref(x)
+        out.sum().backward()

-    def float8_forw_backward(x):
-        if params.layer_norm:
-            with record_function("layer_norm"):
-                x = ln(x)
-        with record_function("forward"):
-            out = linear_float8(x)
+    def float8_forw(x):
+        out = m_float8(x)
         return out

     def float8_forw_backward_wrapper(x):
@@ -146,34 +156,34 @@ def float8_forw_backward_wrapper(x):
         # TODO(future): make this better
         if linear_requires_sync(linear_type):
             with record_function("scale_amax_and_scales"):
-                sync_float8_amax_and_scale_history(linear_float8)
-        out = float8_forw_backward(x)
+                sync_float8_amax_and_scale_history(m_float8)
+        out = float8_forw(x)

         # out.sum().backward() is also not torch.compile fullgraph
         # friendly
         with record_function("backward"):
             out.sum().backward()

-    if params.torch_compile:
+    if compile:
         ref_forw_backward = torch.compile(ref_forw_backward)
-        float8_forw_backward = torch.compile(float8_forw_backward, fullgraph=True)
+        float8_forw = torch.compile(float8_forw, fullgraph=True)

     for _ in range(5):
         ref_forw_backward(input_tensor)
         float8_forw_backward_wrapper(input_tensor)

-    # Profile Reference Linear
-    ref_string = f"linear_ref_dtype_{params.ref_dtype}_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}.json"
+    # Profile Reference Model
+    ref_suffix = f"_ref_compile_{compile}.json"
     profile_config = ProfileConfig(
-        str(profile_path / ref_string), ref_string, iters=5, warmup_iters=5, sync=True
+        profile_path_prefix + ref_suffix, ref_suffix, iters=5, warmup_iters=5, sync=True
     )
     profile_function(profile_config, ref_forw_backward, input_tensor)

-    # # Profile Float8 Linear
-    float8_string = f"linear_float8_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}_{linear_type}.json"
+    # Profile Float8 Model
+    float8_suffix = f"_float8_compile_{compile}_{linear_type}.json"
     profile_config = ProfileConfig(
-        str(profile_path / float8_string),
-        float8_string,
+        profile_path_prefix + float8_suffix,
+        float8_suffix,
         iters=5,
         warmup_iters=5,
         sync=True,
@@ -182,7 +192,7 @@ def float8_forw_backward_wrapper(x):


 def invoke_main() -> None:
-    # Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles --compile=True --linear_type="dynamic"
+    # Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles/current_profile --compile=True --linear_type="dynamic"
     fire.Fire(main)

