     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale
+from float8_experimental.float8_utils import (
+    e4m3_dtype,
+    get_supported_granularity,
+    tensor_to_scale,
+)
+
+SUPPORTED_GRANULARITY = get_supported_granularity()


 class ActivationCasting(Enum):
@@ -75,7 +81,7 @@ def __init__(
         # FP8 specific arguments
         quant_config: QuantConfig,
         forward_config: ScaledMMConfig,
-        scaling_granularity: ScalingGranularity,
+        scaling_granularity: Optional[ScalingGranularity],
         # nn.Linear arguments
         in_features: int,
         out_features: int,
@@ -86,7 +92,26 @@ def __init__(
         # Construct the superclass this will create dummy weights and biases
         super().__init__(in_features, out_features, bias, device, dtype)
         self.forward_config = forward_config
-        self.scaling_granularity = scaling_granularity
+        if scaling_granularity is None:
+            self.scaling_granularity = (
+                ScalingGranularity.AxisWise
+                if dtype == torch.bfloat16
+                and quant_config.static_quantization_scale is None
+                else ScalingGranularity.TensorWise
+            )
+        else:
+            assert (
+                scaling_granularity in SUPPORTED_GRANULARITY
+            ), f"scaling_granularity must be in {SUPPORTED_GRANULARITY} but got {scaling_granularity}"
+            if (
+                scaling_granularity == ScalingGranularity.AxisWise
+                and dtype != torch.bfloat16
+            ):
+                raise ValueError(
+                    "AxisWise scaling granularity is only supported for bfloat16."
+                )
+            self.scaling_granularity = scaling_granularity
+
         self.activation_casting = quant_config.activation_casting
         if self.activation_casting == ActivationCasting.STATIC:
             self.register_buffer(
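For reference, a minimal sketch (not part of the diff) of how the None default above resolves: AxisWise only for bf16 modules doing dynamic (non-static) activation quantization, TensorWise otherwise. The ScalingGranularity import path is an assumption based on the names used in this file.

import torch
from float8_experimental.float8_tensor import ScalingGranularity  # assumed import path

def resolve_default_granularity(dtype: torch.dtype, static_scale) -> ScalingGranularity:
    # Mirrors the branch above: AxisWise requires bf16 and dynamic activation casting.
    if dtype == torch.bfloat16 and static_scale is None:
        return ScalingGranularity.AxisWise
    return ScalingGranularity.TensorWise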
@@ -101,13 +126,19 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                 input, self.weight.to_original_precision()
             )

+        # TODO: we aren't folding leading dims yet, but we need to in order to calculate the proper scale
+        original_m = input.shape[:-1]
+        input = input.view(-1, input.shape[-1])
+
         x_fp8 = cast_to_float8_e4m3_inference(
             input,
             self.forward_config,
             static_quantization_scale=self.static_quantization_scale,
             scaling_granularity=self.scaling_granularity,
         )
-        return torch.nn.functional.linear(x_fp8, self.weight, self.bias)
+        return torch.nn.functional.linear(x_fp8, self.weight, self.bias).view(
+            *original_m, -1
+        )

     # Builder functions for Float8LinearInference
     def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
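The reshape added to forward() flattens every leading dim before casting, so the per-row scale is computed over a 2D view, then restores them on the output. A standalone sketch of the same fold/unfold pattern (a plain matmul standing in for the fp8 path):

import torch

def folded_linear(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Collapse all leading dims into one batch dim, matching input.view(-1, input.shape[-1]) above.
    leading_dims = x.shape[:-1]
    x2d = x.reshape(-1, x.shape[-1])
    out = torch.nn.functional.linear(x2d, weight)
    # Restore the original leading dims on the output, matching .view(*original_m, -1).
    return out.view(*leading_dims, -1)

x = torch.randn(2, 3, 8)
w = torch.randn(4, 8)
assert folded_linear(x, w).shape == (2, 3, 4)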
@@ -124,7 +155,12 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
         assert not isinstance(
             self.weight, Float8Tensor
         ), "Weight has already been quantized, cannot quantize again."
-        scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity)
+
+        # For weight tensors + AxisWise we calculate scales along columns
+        dim = None
+        if self.scaling_granularity == ScalingGranularity.AxisWise:
+            dim = 1
+        scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity, dim=dim)
         quantized_weight = to_fp8_no_autograd(
             self.weight, scale, dtype, self.forward_config
         )
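A hedged sketch of what an AxisWise scale over dim=1 means for an nn.Linear weight of shape (out_features, in_features), assuming tensor_to_scale follows the usual amax-to-scale rule (scale = fp8_max / amax); the real helper may clamp or cast differently:

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def axiswise_weight_scale(weight: torch.Tensor) -> torch.Tensor:
    # Reducing over dim=1 (in_features) yields one scale per output row of the weight.
    amax = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    return E4M3_MAX / amax.float()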
@@ -143,19 +179,20 @@ def from_float(
         module: nn.Module,
         quant_config: QuantConfig,
         use_fast_accum: bool,
+        scaling_granularity: Optional[ScalingGranularity],
     ) -> "Float8InferenceLinear":
         """
         Create an nn.Linear with fp8 compute from another nn.Linear

         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             quant_config (QuantConfig): Configuration for the weight and activation casting
+            use_fast_accum (bool): Whether to enable fast accumulation for the Float8InferenceLinear.
+            scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
         """
         forward_config = ScaledMMConfig(
             False, use_fast_accum, pad_inner_dim=config.pad_inner_dim
         )
-        # TODO: For now hardcode TensorWise scaling
-        scaling_granularity = ScalingGranularity.TensorWise
         linear = cls(
             quant_config,
             forward_config,
@@ -164,6 +201,7 @@ def from_float(
             module.out_features,
             False,
             device=torch.device("meta"),
+            dtype=module.weight.dtype,
         )
         linear.set_weight_and_bias(module.weight, module.bias)
         linear.quantize_weight()
@@ -194,18 +232,29 @@ def cast_to_float8_e4m3_inference(
194232 """
195233 if tensor_already_casted_to_fp8 (inpt_tensor ):
196234 return inpt_tensor
235+
236+ # For input tensors + AxisWise we calculate scales along rows
237+ dim = None
238+ if scaling_granularity == ScalingGranularity .AxisWise :
239+ dim = 1
240+
197241 scale = (
198242 static_quantization_scale
199243 if static_quantization_scale is not None
200244 else tensor_to_scale (
201- inpt_tensor , e4m3_dtype , scaling_granularity , reduce_amax = reduce_amax
245+ inpt_tensor ,
246+ e4m3_dtype ,
247+ scaling_granularity ,
248+ dim = dim ,
249+ reduce_amax = reduce_amax ,
202250 )
203251 )
204252 return Float8Tensor .to_float8 (
205253 inpt_tensor ,
206254 scale ,
207255 e4m3_dtype ,
208256 mm_config = mm_config ,
257+ scaling_granularity = scaling_granularity ,
209258 )
210259
211260
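Both new comments use dim=1: for an input of shape (M, K) that is one scale per row, and for a weight of shape (N, K) one scale per output row, so the two per-axis scales broadcast over the (M, N) output of the matmul. A small shape check of that convention (an assumption about how the scales combine, not code from this diff):

import torch

M, K, N = 4, 8, 16
x = torch.randn(M, K)   # activation
w = torch.randn(N, K)   # nn.Linear weight layout: (out_features, in_features)

x_scale = x.abs().amax(dim=1, keepdim=True)   # (M, 1): one scale per input row
w_scale = w.abs().amax(dim=1, keepdim=True)   # (N, 1): one scale per weight row

# Their product broadcasts to the (M, N) output of x @ w.t().
assert (x_scale * w_scale.t()).shape == (M, N)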
@@ -215,6 +264,7 @@ def quantize_to_float8(
     *,
     skip_fqn_list: Optional[List[str]] = None,
     use_fast_accum: bool = True,
+    scaling_granularity: Optional[ScalingGranularity] = None,
 ) -> Optional[nn.Module]:
     """
     Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
@@ -228,6 +278,7 @@ def quantize_to_float8(
         quant_config (QuantConfig): Quantization configuration for Float8 conversion.
         skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
         use_fast_accum: Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
+        scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.

     Returns:
         nn.Module: The modified module with applicable Linear layers converted to Float8.
@@ -237,6 +288,8 @@ def quantize_to_float8(
237288 """
238289 return swap_linear_layers (
239290 module ,
240- lambda m : Float8InferenceLinear .from_float (m , quant_config , use_fast_accum ),
291+ lambda m : Float8InferenceLinear .from_float (
292+ m , quant_config , use_fast_accum , scaling_granularity
293+ ),
241294 skip_fqn_list = skip_fqn_list ,
242295 )
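Taken together, the new keyword threads from quantize_to_float8 down to the per-layer cast. A hedged usage sketch of the updated entry point (the float8_experimental.inference import path and the QuantConfig constructor call are assumptions based on the names appearing in this diff):

import torch
import torch.nn as nn
from float8_experimental.float8_tensor import ScalingGranularity   # assumed location
from float8_experimental.inference import (                        # assumed module path
    ActivationCasting,
    QuantConfig,
    quantize_to_float8,
)

model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))
model = model.to(device="cuda", dtype=torch.bfloat16)

# None picks the default (AxisWise for bf16 + dynamic casting); it can also be set explicitly.
quantize_to_float8(
    model,
    QuantConfig(ActivationCasting.DYNAMIC),
    scaling_granularity=ScalingGranularity.AxisWise,
)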