     Float8Tensor,
     merge_mm_configs,
     ScaledMMConfig,
+    ScalingStrategy,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
@@ -36,21 +37,27 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
-        tensor,
+        tensor: torch.Tensor,
         mm_config: ScaledMMConfig,
+        scaling_strategy: ScalingStrategy,
     ):
         ctx.mm_config = mm_config
+        ctx.scaling_strategy = scaling_strategy
         return tensor
 
     @staticmethod
-    def backward(ctx, gradY):
+    def backward(ctx, gradY: torch.Tensor):
         if tensor_already_casted_to_fp8(gradY):
-            return gradY, None
+            return gradY, None, None
         gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
         fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
+            gradY,
+            gradY_scale,
+            e5m2_dtype,
+            mm_config=ctx.mm_config,
+            scaling_strategy=ctx.scaling_strategy,
         )
-        return fp8_tensor, None
+        return fp8_tensor, None, None
 
 
 class Float8DynamicLinear(torch.nn.Linear):
@@ -63,13 +70,15 @@ def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config)
+        x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config, self.scaling_strategy)
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
-            w_fp8 = cast_to_float8_e4m3fn(self.weight, self.forward_config)
+            w_fp8 = cast_to_float8_e4m3fn(
+                self.weight, self.forward_config, self.scaling_strategy
+            )
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
-        y = cast_to_float8_e5m2_bw(y, self.backward_config)
+        y = cast_to_float8_e5m2_bw(y, self.backward_config, self.scaling_strategy)
         return y
 
     @classmethod
@@ -101,9 +110,14 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             fp8_output=False,
             pad_inner_dim=config.pad_inner_dim,
         )
+        # TODO: For now hardcode TensorWise scaling
+        new_mod.scaling_strategy = ScalingStrategy.TensorWise
+
         if config.enable_fsdp_fp8_all_gather:
             new_mod.weight = nn.Parameter(
-                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
+                WeightWithDynamicFloat8CastTensor(
+                    mod.weight, new_mod.forward_config, new_mod.scaling_strategy
+                )
             )
         else:
             new_mod.weight = mod.weight
@@ -112,18 +126,27 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
 
 
 def cast_to_float8_e4m3fn(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    mm_config: ScaledMMConfig,
+    scaling_strategy: ScalingStrategy,
+    reduce_amax: bool = False,
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
     scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
+    return Float8Tensor.to_float8(
+        inpt_tensor,
+        scale,
+        e4m3_dtype,
+        mm_config=mm_config,
+        scaling_strategy=scaling_strategy,
+    )
 
 
 def cast_to_float8_e5m2_bw(
-    gradY: torch.Tensor, mm_config: ScaledMMConfig
+    gradY: torch.Tensor, mm_config: ScaledMMConfig, scaling_strategy: ScalingStrategy
 ) -> torch.Tensor:
-    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
+    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config, scaling_strategy)
 
 
 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
@@ -143,7 +166,12 @@ def cast_to_float8_e5m2_bw(
 
 class WeightWithDynamicFloat8CastTensor(torch.Tensor):
     @staticmethod
-    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        scaling_strategy: ScalingStrategy,
+    ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -157,24 +185,38 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
             requires_grad=tensor.requires_grad,
         )
 
-    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        scaling_strategy: ScalingStrategy,
+    ):
         self._tensor = tensor
         self._mm_config = mm_config
+        self._scaling_strategy = scaling_strategy
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         if func == torch.ops.aten.detach.default:
             return WeightWithDynamicFloat8CastTensor(
-                args[0]._tensor, args[0]._mm_config
+                args[0]._tensor, args[0]._mm_config, args[0]._scaling_strategy
             )
         mm_config: Optional[ScaledMMConfig] = None
+        scaling_strategy: Optional[ScalingStrategy] = None
 
         def unwrap(t):
             nonlocal mm_config
+            nonlocal scaling_strategy
             if mm_config is None:
                 mm_config = t._mm_config
             else:
                 mm_config = merge_mm_configs(mm_config, t._mm_config)
+
+            if scaling_strategy is None:
+                scaling_strategy = t._scaling_strategy
+            else:
+                # TODO For now we assume that the scaling strategy is same across all tensors
+                assert scaling_strategy == t._scaling_strategy
             return t._tensor
 
         args, kwargs = pytree.tree_map_only(
@@ -184,23 +226,31 @@ def unwrap(t):
         if func not in _ops_to_preserve_subclass:
             return out
         return pytree.tree_map_only(
-            torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
+            torch.Tensor,
+            lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config, scaling_strategy),
+            out,
         )
 
     def __tensor_flatten__(self):
-        return ["_tensor"], self._mm_config
+        return ["_tensor"], {
+            "_mm_config": self._mm_config,
+            "_scaling_strategy": self._scaling_strategy,
+        }
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
-        mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
+        mm_config = flatten_spec["_mm_config"]
+        scaling_strategy = flatten_spec["_scaling_strategy"]
+        return WeightWithDynamicFloat8CastTensor(
+            inner_tensors["_tensor"], mm_config, scaling_strategy
+        )
 
     def __repr__(self):
-        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
+        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config}, scaling_strategy={self._scaling_strategy})"
 
     def fsdp_pre_all_gather(self, mesh):
         float8_tensor = cast_to_float8_e4m3fn(
-            self._tensor, self._mm_config, reduce_amax=True
+            self._tensor, self._mm_config, self._scaling_strategy, reduce_amax=True
         )
         return (float8_tensor._data,), (float8_tensor._scale,)
 
@@ -218,4 +268,6 @@ def fsdp_post_all_gather(
             assert isinstance(out, Float8Tensor), f"{type(out)}"
             out._scale = scale
             return
-        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
+        return Float8Tensor(
+            data, scale, param_dtype, self._mm_config, self._scaling_strategy
+        ), (data,)
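
Below the diff, a minimal usage sketch of the new argument: it shows how a caller would thread a ScalingStrategy through cast_to_float8_e4m3fn after this change. The import paths and the ScaledMMConfig defaults are assumptions on my part, not something the diff confirms.

# Hypothetical sketch -- module paths and ScaledMMConfig defaults are assumed.
import torch

from float8_experimental.float8_dynamic_linear import cast_to_float8_e4m3fn  # assumed path
from float8_experimental.float8_tensor import ScaledMMConfig, ScalingStrategy  # assumed path

# emulate=True keeps the sketch independent of _scaled_mm hardware support
# (assumes the remaining ScaledMMConfig fields have defaults).
mm_config = ScaledMMConfig(emulate=True)

x = torch.randn(16, 32)

# The diff hardcodes TensorWise in Float8DynamicLinear.from_float; passing it
# explicitly here mirrors what the module does for both the input and weight casts.
x_fp8 = cast_to_float8_e4m3fn(x, mm_config, ScalingStrategy.TensorWise)

print(x_fp8._data.dtype, x_fp8._scale.shape)  # fp8 payload plus a single tensor-wise scale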