dummy #751

Open
wants to merge 1 commit into base: main

3 changes: 0 additions & 3 deletions aqt/jax/v2/aqt_dot_general.py
@@ -1101,9 +1101,6 @@ def dg_core_vjp_fwd(
assert (
lhs.dtype == rhs.dtype
), f'Unmatched lhs and rhs dtype: {lhs.dtype} vs {rhs.dtype}'
cfg.fwd.dg_quantizer.init_calibration()
cfg.dlhs.dg_quantizer.init_calibration()
cfg.drhs.dg_quantizer.init_calibration()
ret, res = cfg.fwd(
lhs,
rhs,
42 changes: 42 additions & 0 deletions aqt/jax/v2/examples/flax_e2e_model_test.py
@@ -23,8 +23,10 @@
from aqt.jax.v2 import config
from aqt.jax.v2 import utils
from aqt.jax.v2.examples import flax_e2e_model
from aqt.jax.v2.flax import aqt_flax
from aqt.jax.v2.flax import aqt_flax_calibration
from aqt.jax.v2.flax import delayed_scaling_calibration
from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np
@@ -959,6 +961,46 @@ def forward(model, apply_fn):
logits_after_conversion, _ = forward(model_serving, serve_fn)
assert (logits_before_conversion == logits_after_conversion).all()

def test_simple(self):
aqt_cfg_dg = config.config_v4()
amax_history_length = 32
calibration_cls = functools.partial(
delayed_scaling_calibration.DelayedScalingCalibration,
amax_history_length=amax_history_length,
)
aqt_cfg_dg.fwd.dg_quantizer.lhs.calibration = calibration_cls
aqt_cfg_dg.fwd.dg_quantizer.rhs.calibration = calibration_cls
aqt_cfg_dg.dlhs.dg_quantizer.lhs.calibration = calibration_cls
aqt_cfg_dg.dlhs.dg_quantizer.rhs.calibration = calibration_cls
aqt_cfg_dg.drhs.dg_quantizer.lhs.calibration = calibration_cls
aqt_cfg_dg.drhs.dg_quantizer.rhs.calibration = calibration_cls

class MlpBlock(nn.Module):

@nn.compact
def __call__(self, inputs):
dot_general = aqt_flax.AqtDotGeneral(aqt_cfg_dg)
x = nn.Dense(dot_general=dot_general, features=1)(inputs)
return x

x = jnp.ones((10, 10))
y = jnp.ones((10, 1))

model = MlpBlock()
params = model.init({"params": jax.random.PRNGKey(0)}, x)

def loss_fn(params):
return jnp.mean(
jnp.abs(
y - model.apply(params, x, rngs={"params": jax.random.PRNGKey(0)})
)
)

grad_fn = jax.grad(loss_fn)
with jax.checking_leaks():
grads = grad_fn(params)
print(grads)

@parameterized.parameters(
(["e4m3"] * 2 + ["e5m2"] * 4,), # Higher precision fwd, larger range bwd
([8] * 6,),
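Note on the new test_simple above: the six calibration assignments install the same DelayedScalingCalibration factory on the lhs/rhs quantizers of the forward pass and of both backward dot_generals (fwd, dlhs, drhs). A minimal sketch of that wiring factored into a helper (hypothetical, not part of this PR) could look like:

import functools

from aqt.jax.v2 import config
from aqt.jax.v2.flax import delayed_scaling_calibration


def set_calibration_on_all(aqt_cfg, calibration_cls):
  # Same calibration factory on both operands of fwd, dlhs and drhs,
  # mirroring the six assignments in test_simple.
  for dg in (aqt_cfg.fwd, aqt_cfg.dlhs, aqt_cfg.drhs):
    dg.dg_quantizer.lhs.calibration = calibration_cls
    dg.dg_quantizer.rhs.calibration = calibration_cls


aqt_cfg_dg = config.config_v4()
set_calibration_on_all(
    aqt_cfg_dg,
    functools.partial(
        delayed_scaling_calibration.DelayedScalingCalibration,
        amax_history_length=32,
    ),
)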
3 changes: 3 additions & 0 deletions aqt/jax/v2/flax/aqt_flax.py
@@ -417,6 +417,9 @@ def ret_dg(
)

cfg.apply_custom_vjp_on_jax = False
cfg.fwd.dg_quantizer.init_calibration()
cfg.dlhs.dg_quantizer.init_calibration()
cfg.drhs.dg_quantizer.init_calibration()
out, (out_lhs_qt, out_rhs_qt) = aqt_flax_dg_core.dg_core_flax_lifted(
lhs, rhs, lhs_qt, rhs_qt, dimension_numbers, self, cfg
)
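The three init_calibration calls added here mirror the ones removed from dg_core_vjp_fwd in aqt_dot_general.py above: the one-time calibration setup (see the comment on init_calibration in delayed_scaling_calibration.py below) moves out of the jax-level custom-VJP forward and into the Flax code path, right before the lifted dg_core call. A compact equivalent (a hypothetical helper, not in this PR) would be:

def init_all_calibrations(cfg):
  # Initialize calibration on the forward and both backward dg_quantizers,
  # matching the three inline calls added above.
  for dg in (cfg.fwd, cfg.dlhs, cfg.drhs):
    dg.dg_quantizer.init_calibration()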
77 changes: 66 additions & 11 deletions aqt/jax/v2/flax/delayed_scaling_calibration.py
@@ -42,16 +42,14 @@ def setup(self) -> None:
CALIBRATION_STATS,
"amax_history",
# pylint: disable-next=protected-access
lambda: jax._src.core.mutable_array(
jnp.zeros((self.amax_history_length,))
),
lambda: jnp.zeros((self.amax_history_length,)),
)

self.bound = self.variable(
CALIBRATION_STATS,
"bound",
# pylint: disable-next=protected-access
lambda: jax._src.core.mutable_array(jnp.zeros((1,))),
lambda: jnp.zeros((1,)),
)

def get_bound(
@@ -76,19 +74,14 @@ def get_bound(
# be mutating the arrays in place and don't want to do so accidentally
quant_mode = context.quant_mode if context else utils.QuantMode.SERVE

bound_mutable_arr = self.bound.value
amax_history_mutable_arr = self.amax_history.value
prev_bound = self.bound.value
amax_history = self.amax_history.value

amax_history = amax_history_mutable_arr[:]
prev_bound = bound_mutable_arr[:]
amax_from_history = jnp.max(amax_history, axis=0)

new_bound = self.compute_bound(amax_from_history, prev_bound)
new_history = self.compute_history(x, amax_history)

if quant_mode in [utils.QuantMode.TRAIN, utils.QuantMode.CALIBRATE]:
bound_mutable_arr[:] = new_bound[:]
amax_history_mutable_arr[:] = new_history[:]
return new_bound.reshape((1,) * len(x.shape))

def get_scale_and_bias_and_sparsity(
@@ -122,3 +115,65 @@ def init_calibration(self):
# variables are initialized properly once. Else, they get "recreated"
# on each use.
self.amax_history # pylint: disable=pointless-statement


def ceil_to_po2(scale: jnp.ndarray) -> jnp.ndarray:
# With floor, the biggest value (we are using jnp.max) is in the range of
# clipping and therefore has a correct gradient.
scale = 2 ** jnp.floor(jnp.log2(jax.lax.reciprocal(scale)))
scale = jax.lax.reciprocal(scale)
return scale


@utils.flax_slots_kw_only_dataclass
class AbsMaxCalibration(calibration.Calibration):
"""Simple max(abs(x)) calibration.

Attributes:
clipping_scale: Set it to something like 0.3, 0.1, 0.03. If clipping_scale <
1.0, setting IntSymmetric.clip_gradient=True is likely to be important.
"""

clipping_scale: None | float = None

def get_scale_and_bias_and_sparsity(
self,
x: jnp.ndarray,
shared_axes: None | Sequence[utils.AxisIdx],
numerics_: numerics.AqtNumerics,
context: None | utils.Context = None,
) -> tuple[list[jnp.ndarray], list[jnp.ndarray], None | jnp.ndarray]:
"""Calibration.

Args:
x: The input tensor.
shared_axes: Axes that share a calibration bound. For AbsMaxCalibration,
it should not be None.
numerics_: An `AqtNumerics` object containing information regarding
quantization. Used to create the scale and bias arrays.
context: The quantization context.

Returns:
The scale tensor containing the scale values for each group (can
potentially be a subchannel). Its shape will be the same as `x.shape` but
with `shared_axes` collapsed to 1. Bias is not supported.
"""
del context
msg = (
"Perhaps you are using DequantMode.THIS_INPUT (fake_quant) and forgot"
" to set them."
)
assert shared_axes is not None, msg
dtype = self.dtype if self.dtype is not None else x.dtype

# NOTE: If you use a clipping_scale, consider using clip and clip_gradient
# in int_numerics.IntSymmetric.
abs_max = jnp.max(jnp.abs(x), axis=shared_axes, keepdims=True)
# TODO(yichizh): the zero filtering is not needed anymore because inf is
# filtered when calculating the reciprocal of scaling factor
abs_max = jnp.where(abs_max == 0.0, jnp.ones_like(abs_max), abs_max)
bound = abs_max * self.clipping_scale if self.clipping_scale else abs_max

scale = bound / numerics_.get_quant_bound()
scale = ceil_to_po2(scale) if self.po2_scale else scale
return [scale.astype(dtype)], [], None
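
For concreteness, here is a standalone sketch of the scale computation that the new AbsMaxCalibration performs, including the optional power-of-two rounding from ceil_to_po2. The quant bound of 127.0 below is an assumed int8-style value standing in for numerics_.get_quant_bound(); everything else follows the code above.

import jax
import jax.numpy as jnp


def ceil_to_po2(scale):
  # Round the scale up to a power of two by flooring the exponent of its
  # reciprocal (same trick as in the code above).
  scale = 2 ** jnp.floor(jnp.log2(jax.lax.reciprocal(scale)))
  return jax.lax.reciprocal(scale)


def abs_max_scale(x, shared_axes, quant_bound=127.0, clipping_scale=None,
                  po2_scale=False):
  # max(abs(x)) per group, with shared_axes collapsed to size 1.
  abs_max = jnp.max(jnp.abs(x), axis=shared_axes, keepdims=True)
  # Keep all-zero groups from producing a zero bound (and an inf reciprocal).
  abs_max = jnp.where(abs_max == 0.0, jnp.ones_like(abs_max), abs_max)
  bound = abs_max * clipping_scale if clipping_scale else abs_max
  scale = bound / quant_bound
  return ceil_to_po2(scale) if po2_scale else scale


x = jnp.array([[0.5, -3.0], [1.0, 2.0]])
print(abs_max_scale(x, shared_axes=(1,)))                  # raw per-row scales
print(abs_max_scale(x, shared_axes=(1,), po2_scale=True))  # power-of-two scales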