diff --git a/aqt/jax/v2/aqt_dot_general.py b/aqt/jax/v2/aqt_dot_general.py
index 404991db..8cbc4381 100644
--- a/aqt/jax/v2/aqt_dot_general.py
+++ b/aqt/jax/v2/aqt_dot_general.py
@@ -1101,9 +1101,6 @@ def dg_core_vjp_fwd(
   assert (
       lhs.dtype == rhs.dtype
   ), f'Unmatched lhs and rhs dtype: {lhs.dtype} vs {rhs.dtype}'
-  cfg.fwd.dg_quantizer.init_calibration()
-  cfg.dlhs.dg_quantizer.init_calibration()
-  cfg.drhs.dg_quantizer.init_calibration()
   ret, res = cfg.fwd(
       lhs,
       rhs,
diff --git a/aqt/jax/v2/examples/flax_e2e_model_test.py b/aqt/jax/v2/examples/flax_e2e_model_test.py
index 977b3cdc..26c8dcb7 100644
--- a/aqt/jax/v2/examples/flax_e2e_model_test.py
+++ b/aqt/jax/v2/examples/flax_e2e_model_test.py
@@ -23,8 +23,10 @@
 from aqt.jax.v2 import config
 from aqt.jax.v2 import utils
 from aqt.jax.v2.examples import flax_e2e_model
+from aqt.jax.v2.flax import aqt_flax
 from aqt.jax.v2.flax import aqt_flax_calibration
 from aqt.jax.v2.flax import delayed_scaling_calibration
+from flax import linen as nn
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -959,6 +961,46 @@ def forward(model, apply_fn):
     logits_after_conversion, _ = forward(model_serving, serve_fn)
     assert (logits_before_conversion == logits_after_conversion).all()
 
+  def test_simple(self):
+    aqt_cfg_dg = config.config_v4()
+    amax_history_length = 32
+    calibration_cls = functools.partial(
+        delayed_scaling_calibration.DelayedScalingCalibration,
+        amax_history_length=amax_history_length,
+    )
+    aqt_cfg_dg.fwd.dg_quantizer.lhs.calibration = calibration_cls
+    aqt_cfg_dg.fwd.dg_quantizer.rhs.calibration = calibration_cls
+    aqt_cfg_dg.dlhs.dg_quantizer.lhs.calibration = calibration_cls
+    aqt_cfg_dg.dlhs.dg_quantizer.rhs.calibration = calibration_cls
+    aqt_cfg_dg.drhs.dg_quantizer.lhs.calibration = calibration_cls
+    aqt_cfg_dg.drhs.dg_quantizer.rhs.calibration = calibration_cls
+
+    class MlpBlock(nn.Module):
+
+      @nn.compact
+      def __call__(self, inputs):
+        dot_general = aqt_flax.AqtDotGeneral(aqt_cfg_dg)
+        x = nn.Dense(dot_general=dot_general, features=1)(inputs)
+        return x
+
+    x = jnp.ones((10, 10))
+    y = jnp.ones((10, 1))
+
+    model = MlpBlock()
+    params = model.init({"params": jax.random.PRNGKey(0)}, x)
+
+    def loss_fn(params):
+      return jnp.mean(
+          jnp.abs(
+              y - model.apply(params, x, rngs={"params": jax.random.PRNGKey(0)})
+          )
+      )
+
+    grad_fn = jax.grad(loss_fn)
+    with jax.checking_leaks():
+      grads = grad_fn(params)
+    print(grads)
+
   @parameterized.parameters(
       (["e4m3"] * 2 + ["e5m2"] * 4,),  # Higher precision fwd, larger range bwd
       ([8] * 6,),
diff --git a/aqt/jax/v2/flax/aqt_flax.py b/aqt/jax/v2/flax/aqt_flax.py
index d9738bba..f5a35e69 100644
--- a/aqt/jax/v2/flax/aqt_flax.py
+++ b/aqt/jax/v2/flax/aqt_flax.py
@@ -417,6 +417,9 @@ def ret_dg(
       )
       cfg.apply_custom_vjp_on_jax = False
 
+      cfg.fwd.dg_quantizer.init_calibration()
+      cfg.dlhs.dg_quantizer.init_calibration()
+      cfg.drhs.dg_quantizer.init_calibration()
       out, (out_lhs_qt, out_rhs_qt) = aqt_flax_dg_core.dg_core_flax_lifted(
           lhs, rhs, lhs_qt, rhs_qt, dimension_numbers, self, cfg
       )
diff --git a/aqt/jax/v2/flax/delayed_scaling_calibration.py b/aqt/jax/v2/flax/delayed_scaling_calibration.py
index 1de675c3..3285b1fe 100644
--- a/aqt/jax/v2/flax/delayed_scaling_calibration.py
+++ b/aqt/jax/v2/flax/delayed_scaling_calibration.py
@@ -42,16 +42,14 @@ def setup(self) -> None:
         CALIBRATION_STATS,
         "amax_history",
         # pylint: disable-next=protected-access
-        lambda: jax._src.core.mutable_array(
-            jnp.zeros((self.amax_history_length,))
-        ),
+        lambda: jnp.zeros((self.amax_history_length,)),
     )
 
     self.bound = self.variable(
         CALIBRATION_STATS,
         "bound",
         # pylint: disable-next=protected-access
-        lambda: jax._src.core.mutable_array(jnp.zeros((1,))),
+        lambda: jnp.zeros((1,)),
     )
 
   def get_bound(
@@ -76,19 +74,14 @@ def get_bound(
     # be mutating the arrays in place and don't want to do so accidentally
     quant_mode = context.quant_mode if context else utils.QuantMode.SERVE
 
-    bound_mutable_arr = self.bound.value
-    amax_history_mutable_arr = self.amax_history.value
+    prev_bound = self.bound.value
+    amax_history = self.amax_history.value
 
-    amax_history = amax_history_mutable_arr[:]
-    prev_bound = bound_mutable_arr[:]
     amax_from_history = jnp.max(amax_history, axis=0)
     new_bound = self.compute_bound(amax_from_history, prev_bound)
     new_history = self.compute_history(x, amax_history)
 
-    if quant_mode in [utils.QuantMode.TRAIN, utils.QuantMode.CALIBRATE]:
-      bound_mutable_arr[:] = new_bound[:]
-      amax_history_mutable_arr[:] = new_history[:]
 
     return new_bound.reshape((1,) * len(x.shape))
 
   def get_scale_and_bias_and_sparsity(
@@ -122,3 +115,65 @@ def init_calibration(self):
     # variables are initialized properly once. Else, they get "recreated"
     # on each use.
     self.amax_history  # pylint: disable=pointless-statement
+
+
+def ceil_to_po2(scale: jnp.ndarray) -> jnp.ndarray:
+  # With floor the biggest value (we are using jnp.max) is in the range of
+  # clipping and therefore have a correct gradient.
+  scale = 2 ** jnp.floor(jnp.log2(jax.lax.reciprocal(scale)))
+  scale = jax.lax.reciprocal(scale)
+  return scale
+
+
+@utils.flax_slots_kw_only_dataclass
+class AbsMaxCalibration(calibration.Calibration):
+  """Simple max(abs(x)) calibration.
+
+  Attributes:
+    clipping_scale: Set it to something like 0.3, 0.1, 0.03. If clipping_scale <
+      1.0, setting IntSymmetric.clip_gradient=True is likely to be important.
+  """
+
+  clipping_scale: None | float = None
+
+  def get_scale_and_bias_and_sparsity(
+      self,
+      x: jnp.ndarray,
+      shared_axes: None | Sequence[utils.AxisIdx],
+      numerics_: numerics.AqtNumerics,
+      context: None | utils.Context = None,
+  ) -> tuple[list[jnp.ndarray], list[jnp.ndarray], None | jnp.ndarray]:
+    """Calibration.
+
+    Args:
+      x: The input tensor.
+      shared_axes: Axes that share a calibration bound. For AbsMaxCalibration,
+        it should not be None.
+      numerics_: An `AqtNumerics` object containing information regarding
+        quantization. Used to create the scale and bias arrays.
+      context: The quantization context.
+
+    Returns:
+      The scale tensor containing the scale values for each group (can
+      potentially be a subchannel). Its shape will be the same as `x.shape` but
+      with `shared_axes` collapsed to 1. Bias is not supported.
+    """
+    del context
+    msg = (
+        "Perhaps you are using DequantMode.THIS_INPUT (fake_quant) and forgot"
+        " to set them."
+    )
+    assert shared_axes is not None, msg
+    dtype = self.dtype if self.dtype is not None else x.dtype
+
+    # NOTE: If you use a clipping_scale, consider using clip and clip_gradient
+    # in int_numerics.IntSymmetric.
+    abs_max = jnp.max(jnp.abs(x), axis=shared_axes, keepdims=True)
+    # TODO(yichizh): the zero filtering is not needed anymore because inf is
+    # filtered when calculating the reciprocal of scaling factor
+    abs_max = jnp.where(abs_max == 0.0, jnp.ones_like(abs_max), abs_max)
+    bound = abs_max * self.clipping_scale if self.clipping_scale else abs_max
+
+    scale = bound / numerics_.get_quant_bound()
+    scale = ceil_to_po2(scale) if self.po2_scale else scale
+    return [scale.astype(dtype)], [], None
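
Reviewer note: below is a minimal, self-contained sketch (not part of the patch) of the
scale computation that the added AbsMaxCalibration performs, written against plain JAX so
it can be run outside the AQT class machinery. The standalone function names, the
signature, and the quant_bound default of 127.0 (symmetric int8) are assumptions for
illustration only; in the patch the bound comes from numerics_.get_quant_bound().

import jax
import jax.numpy as jnp


def ceil_to_po2(scale: jnp.ndarray) -> jnp.ndarray:
  # Floor the exponent of 1/scale and invert again: the result is the smallest
  # power of two >= the original scale, so the largest value in each group
  # still falls inside the clipping range.
  scale = 2 ** jnp.floor(jnp.log2(jax.lax.reciprocal(scale)))
  return jax.lax.reciprocal(scale)


def abs_max_scale(
    x: jnp.ndarray,
    shared_axes: tuple[int, ...],
    quant_bound: float = 127.0,  # assumed stand-in for numerics_.get_quant_bound()
    clipping_scale: None | float = None,
    po2_scale: bool = False,
) -> jnp.ndarray:
  # One abs-max per group; shared_axes are collapsed to size 1.
  abs_max = jnp.max(jnp.abs(x), axis=shared_axes, keepdims=True)
  # Keep all-zero groups from producing a zero scale.
  abs_max = jnp.where(abs_max == 0.0, jnp.ones_like(abs_max), abs_max)
  bound = abs_max * clipping_scale if clipping_scale else abs_max
  scale = bound / quant_bound
  return ceil_to_po2(scale) if po2_scale else scale


x = jax.random.normal(jax.random.PRNGKey(0), (10, 10))
print(abs_max_scale(x, shared_axes=(1,), po2_scale=True).shape)  # (10, 1)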