     Float8Tensor,
     merge_mm_configs,
     ScaledMMConfig,
+    ScalingGranularity,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
@@ -36,21 +37,26 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
-        tensor,
+        tensor: torch.Tensor,
         mm_config: ScaledMMConfig,
+        scaling_granularity: ScalingGranularity,
     ):
         ctx.mm_config = mm_config
+        ctx.scaling_granularity = scaling_granularity
         return tensor

     @staticmethod
-    def backward(ctx, gradY):
+    def backward(ctx, gradY: torch.Tensor):
         if tensor_already_casted_to_fp8(gradY):
-            return gradY, None
-        gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
+            return gradY, None, None
+        gradY_scale = tensor_to_scale(gradY, e5m2_dtype, ctx.scaling_granularity)
         fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
+            gradY,
+            gradY_scale,
+            e5m2_dtype,
+            mm_config=ctx.mm_config,
         )
-        return fp8_tensor, None
+        return fp8_tensor, None, None


 class Float8DynamicLinear(torch.nn.Linear):
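
The two extra None values returned from backward follow the standard torch.autograd.Function contract: backward must return one value per forward input, and non-tensor inputs (here mm_config and scaling_granularity) receive None. A minimal standalone sketch of that pattern, independent of this repo:

import torch

class ScaleByConstant(torch.autograd.Function):
    # forward takes a tensor plus a non-tensor configuration value
    @staticmethod
    def forward(ctx, x: torch.Tensor, factor: float):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        # one return value per forward input: a gradient for x, None for factor
        return grad_out * ctx.factor, None
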
@@ -63,13 +69,19 @@ def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        x_fp8 = cast_to_float8_e4m3_dynamic(input, self.forward_config)
+        x_fp8 = cast_to_float8_e4m3_dynamic(
+            input, self.forward_config, self.scaling_granularity
+        )
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
-            w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
+            w_fp8 = cast_to_float8_e4m3_dynamic(
+                self.weight, self.forward_config, self.scaling_granularity
+            )
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
-        y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
+        y = cast_to_float8_e5m2_dynamic_bw(
+            y, self.backward_config, self.scaling_granularity
+        )
         return y

     @classmethod
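
A hedged usage sketch of the forward path above, assuming the classes are importable from float8_experimental.float8_dynamic_linear (module path assumed); emulate=True keeps the matmuls in high precision so the example does not require fp8-capable hardware:

import torch
import torch.nn as nn
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

m = nn.Linear(32, 16)
fp8_m = Float8DynamicLinear.from_float(m, emulate=True)

x = torch.randn(4, 32, requires_grad=True)
y = fp8_m(x)        # input and weight are dynamically cast to e4m3 on the way in
y.sum().backward()  # the incoming gradient is cast to e5m2 on the way back
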
@@ -101,9 +113,14 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             fp8_output=False,
             pad_inner_dim=config.pad_inner_dim,
         )
+        # TODO: For now hardcode TensorWise scaling
+        new_mod.scaling_granularity = ScalingGranularity.TensorWise
+
         if config.enable_fsdp_fp8_all_gather:
             new_mod.weight = nn.Parameter(
-                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
+                WeightWithDynamicFloat8CastTensor(
+                    mod.weight, new_mod.forward_config, new_mod.scaling_granularity
+                )
             )
         else:
             new_mod.weight = mod.weight
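
TensorWise scaling, hardcoded above, means a single scale is derived from the whole tensor's maximum absolute value. An illustration of the idea (a sketch only, not necessarily identical to tensor_to_scale in this repo):

import torch

def tensorwise_scale(t: torch.Tensor, float8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # one amax for the entire tensor -> one scale for the entire tensor
    amax = t.abs().max().to(torch.float64)
    fp8_max = torch.finfo(float8_dtype).max  # 448.0 for e4m3fn
    return (fp8_max / torch.clamp(amax, min=1e-12)).to(torch.float32)
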
@@ -112,18 +129,31 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":


 def cast_to_float8_e4m3_dynamic(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    mm_config: ScaledMMConfig,
+    scaling_granularity: ScalingGranularity,
+    reduce_amax: bool = False,
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
+    scale = tensor_to_scale(
+        inpt_tensor, e4m3_dtype, scaling_granularity, reduce_amax=reduce_amax
+    )
+    return Float8Tensor.to_float8(
+        inpt_tensor,
+        scale,
+        e4m3_dtype,
+        mm_config=mm_config,
+        scaling_granularity=scaling_granularity,
+    )


 def cast_to_float8_e5m2_dynamic_bw(
-    gradY: torch.Tensor, mm_config: ScaledMMConfig
+    gradY: torch.Tensor,
+    mm_config: ScaledMMConfig,
+    scaling_granularity: ScalingGranularity,
 ) -> torch.Tensor:
-    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
+    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config, scaling_granularity)


 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
@@ -143,7 +173,12 @@ def cast_to_float8_e5m2_dynamic_bw(

 class WeightWithDynamicFloat8CastTensor(torch.Tensor):
     @staticmethod
-    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        scaling_granularity: ScalingGranularity,
+    ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -157,24 +192,38 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
             requires_grad=tensor.requires_grad,
         )

-    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        scaling_granularity: ScalingGranularity,
+    ):
         self._tensor = tensor
         self._mm_config = mm_config
+        self._scaling_granularity = scaling_granularity

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         if func == torch.ops.aten.detach.default:
             return WeightWithDynamicFloat8CastTensor(
-                args[0]._tensor, args[0]._mm_config
+                args[0]._tensor, args[0]._mm_config, args[0]._scaling_granularity
             )
         mm_config: Optional[ScaledMMConfig] = None
+        scaling_granularity: Optional[ScalingGranularity] = None

         def unwrap(t):
             nonlocal mm_config
+            nonlocal scaling_granularity
             if mm_config is None:
                 mm_config = t._mm_config
             else:
                 mm_config = merge_mm_configs(mm_config, t._mm_config)
+
+            if scaling_granularity is None:
+                scaling_granularity = t._scaling_granularity
+            else:
+                # TODO: for now we assume the scaling granularity is the same across all tensors
+                assert scaling_granularity == t._scaling_granularity
             return t._tensor

         args, kwargs = pytree.tree_map_only(
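
The unwrap/rewrap logic around this hunk relies on pytree.tree_map_only, which applies a function only to leaves of a given type while leaving the container structure intact; a standalone illustration (not tied to this diff):

import torch
import torch.utils._pytree as pytree

tree = ({"w": torch.ones(2)}, [3, torch.zeros(2)])
# double only the tensor leaves; the int and the container structure are untouched
doubled = pytree.tree_map_only(torch.Tensor, lambda t: t * 2, tree)
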
@@ -184,23 +233,33 @@ def unwrap(t):
         if func not in _ops_to_preserve_subclass:
             return out
         return pytree.tree_map_only(
-            torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
+            torch.Tensor,
+            lambda x: WeightWithDynamicFloat8CastTensor(
+                x, mm_config, scaling_granularity
+            ),
+            out,
         )

     def __tensor_flatten__(self):
-        return ["_tensor"], self._mm_config
+        return ["_tensor"], {
+            "_mm_config": self._mm_config,
+            "_scaling_granularity": self._scaling_granularity,
+        }

     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
-        mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
+        mm_config = flatten_spec["_mm_config"]
+        scaling_granularity = flatten_spec["_scaling_granularity"]
+        return WeightWithDynamicFloat8CastTensor(
+            inner_tensors["_tensor"], mm_config, scaling_granularity
+        )

     def __repr__(self):
-        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
+        return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config}, scaling_granularity={self._scaling_granularity})"

     def fsdp_pre_all_gather(self, mesh):
         float8_tensor = cast_to_float8_e4m3_dynamic(
-            self._tensor, self._mm_config, reduce_amax=True
+            self._tensor, self._mm_config, self._scaling_granularity, reduce_amax=True
         )
         return (float8_tensor._data,), (float8_tensor._scale,)

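
For reference, ScalingGranularity is the enum added to the imports at the top of this diff; a plausible shape for it, hedged since the real definition lives elsewhere in the repo and may differ:

import enum

class ScalingGranularity(enum.Enum):
    # one scale computed from the whole tensor's amax (the value from_float hardcodes above)
    TensorWise = "tensor_wise"
    # hypothetical finer-grained option, e.g. one scale per row or column along an axis
    AxisWise = "axis_wise"

In fsdp_pre_all_gather, passing reduce_amax=True to cast_to_float8_e4m3_dynamic lets the amax be reduced across ranks before the scale is computed, so every shard of the weight is cast with the same scale.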