Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle the jax.promote_dtype failure for [int4, fp8_e4m3, fp8_e5m2] when using aqt_einsum #759

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 56 additions & 9 deletions aqt/jax/v2/flax/aqt_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,61 @@ def _get_singleton_axes(x: jnp.ndarray) -> list[utils.AxisIdx]:
return qt


def aqt_promote_dtype(
lhs_in: jnp.ndarray, rhs_in: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Promotes the dtype of lhs_in and rhs_in.

Args:
lhs_in: Left-hand-side array.
rhs_in: Right-hand-side array.

Returns:
A tuple of the promoted lhs_in and rhs_in.

We create a list of dtypes and hand-hold them because promote_dtype fails for
these dtypes.
"""

manual_promotion_dtypes = [jnp.int4, jnp.float8_e4m3fn, jnp.float8_e5m2]
if (
lhs_in.dtype in manual_promotion_dtypes
and rhs_in.dtype in manual_promotion_dtypes
):
if lhs_in.dtype == rhs_in.dtype:
pass
else:
lhs_in = (
jnp.float32(lhs_in)
if lhs_in.dtype == jnp.int4
else jnp.bfloat16(lhs_in)
)
rhs_in = (
jnp.float32(rhs_in)
if rhs_in.dtype == jnp.int4
else jnp.bfloat16(rhs_in)
)
elif lhs_in.dtype in manual_promotion_dtypes:
lhs_in = (
jnp.float32(lhs_in)
if lhs_in.dtype == jnp.int4
else jnp.bfloat16(lhs_in)
)
elif rhs_in.dtype in manual_promotion_dtypes:
rhs_in = (
jnp.float32(rhs_in)
if rhs_in.dtype == jnp.int4
else jnp.bfloat16(rhs_in)
)

if (
lhs_in.dtype not in manual_promotion_dtypes
and rhs_in.dtype not in manual_promotion_dtypes
):
lhs_in, rhs_in = nn.dtypes.promote_dtype(lhs_in, rhs_in)
return lhs_in, rhs_in


class FreezerMode(enum.Enum):
NONE = 1
CALIBRATION = 2
Expand Down Expand Up @@ -603,15 +658,7 @@ def __call__(
# from being rejected by assertions in aqt_dot_general.py, line 522-526 and
# 414.
# TODO: b/322111904 - Handle this in more proper way.
# We hand-hold int4 because promote_dtype(int4, x) fails.
# (To avoid unintended promotion, 4-bit integers do not support
# implicit promotion.)
if lhs_in.dtype == jnp.int4:
lhs_in = jnp.float32(lhs_in)
if rhs_in.dtype == jnp.int4:
rhs_in = jnp.float32(rhs_in)
if lhs_in.dtype != jnp.int4 and rhs_in.dtype != jnp.int4:
lhs_in, rhs_in = nn.dtypes.promote_dtype(lhs_in, rhs_in)
lhs_in, rhs_in = aqt_promote_dtype(lhs_in, rhs_in)

# yes_swap = whether einsum swaps [lhs,rhs] when passing them to dot_general
einsum = functools.partial(aqt_dot_general.einsum, eqn=eqn)
Expand Down