Hi guys, thanks for the work on this library.

I am trying to reduce the checkpoint size and memory overhead of a model using AQT. Currently, when we quantize the parameters of a Flax model for serving using QuantMode.CONVERT, the original unquantised parameters remain in the params pytree alongside the newly added quantized values.

Is there any way to use AQT with Flax so that we don't need to keep the original unquantised weights in checkpoints when serving?

I have tried manually reducing the size of the parameter pytree by removing the original kernels or replacing them with placeholder values, but both approaches have been unsuccessful (see the example code below).

Is it currently possible to reduce the size of the Flax params (and therefore the checkpoints)?
If so, are there any small idiomatic examples of how to do this?
If not, are there plans to support this in the future?
import functools
from pprint import pprint

import aqt.jax.v2.config as aqt_config
import flax.linen as nn
import jax
import jax.numpy as jnp
from aqt.jax.v2.flax import aqt_flax
from jax._src.tree_util import DictKey


class MlpBlock(nn.Module):
    aqt_cfg: aqt_config.DotGeneral | None = None
    quant_mode: aqt_flax.QuantMode = aqt_flax.QuantMode.TRAIN

    @nn.compact
    def __call__(self, inputs):
        dense_dg = functools.partial(
            aqt_flax.AqtDotGeneral,
            self.aqt_cfg,
            # In nn.Dense, it is the RHS that has the kernel.
            rhs_quant_mode=self.quant_mode,
            rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE,
        )
        x = nn.Dense(dot_general_cls=dense_dg, features=3)(inputs)
        x = nn.relu(x)
        x = nn.Dense(dot_general_cls=dense_dg, features=3)(x)
        return x


int8_config = aqt_config.fully_quantized(fwd_bits=8, bwd_bits=8)


def get_pytree_memory_size(pytree):
    # Estimate of the pytree size - does not account for the size of placeholder values.
    leaves, _ = jax.tree_util.tree_flatten(pytree)
    return sum(leaf.nbytes for leaf in leaves if leaf.dtype != jnp.dtype('O'))


# 1. Get params for the model
mlp = MlpBlock()
params = mlp.init(jax.random.key(0), jnp.ones((1, 10)))
print('Original params:')
pprint(jax.tree_util.tree_map(lambda x: x.shape, params))
print('Memory size of original params:', get_pytree_memory_size(params))
# -> Memory size of original params: 180

# 2. Convert the model to int8 - requires a dummy forward pass with mutable=True
mlp_convert = MlpBlock(
    aqt_cfg=int8_config,
    quant_mode=aqt_flax.QuantMode.CONVERT,
)
_, converted_params = mlp_convert.apply(
    params,
    jnp.ones((1, 10)),
    rngs={'params': jax.random.key(0)},
    mutable=True,
)
print('Converted params:')
pprint(jax.tree_util.tree_map(lambda x: x.shape, converted_params))
print('Memory size of converted params:', get_pytree_memory_size(converted_params))
# -> Memory size of converted params: 243

# 3. Use the converted params to run the model
mlp_serve = MlpBlock(
    aqt_cfg=int8_config,
    quant_mode=aqt_flax.QuantMode.SERVE,
)
out = mlp_serve.apply(
    converted_params,
    jnp.ones((1, 10)),
    rngs={'params': jax.random.key(0)},
)
# This works :)

# 4. Try to remove the redundant kernel weights in converted params by setting them to None
params_no_kernel = jax.tree_util.tree_map_with_path(
    lambda kp, x: None if DictKey('kernel') in kp else x,
    converted_params,
)
print('Params without kernel:')
pprint(jax.tree_util.tree_map(lambda x: (x.dtype, x.shape), params_no_kernel))
print('Memory size of params without kernel:', get_pytree_memory_size(params_no_kernel))
# -> Memory size of params without kernel: 87

mlp_serve = MlpBlock(
    aqt_cfg=int8_config,
    quant_mode=aqt_flax.QuantMode.SERVE,
)
out = mlp_serve.apply(
    params_no_kernel,
    jnp.ones((1, 10)),
    rngs={'params': jax.random.key(0)},
)
# When running apply() we get the following:
# AttributeError: 'NoneType' object has no attribute 'shape' when calling make_aqt_dg()

# 5. Try to remove the redundant kernel weights in converted params using jax.ShapeDtypeStruct
params_no_kernel = jax.tree_util.tree_map_with_path(
    lambda kp, x: jax.ShapeDtypeStruct(x.shape, type(x)) if DictKey('kernel') in kp else x,
    converted_params,
)
print('Params without kernel:')
pprint(jax.tree_util.tree_map(lambda x: (x.dtype, x.shape), params_no_kernel))
print('Memory size of params without kernel:', get_pytree_memory_size(params_no_kernel))
# -> Memory size of params without kernel: 87

mlp_serve = MlpBlock(
    aqt_cfg=int8_config,
    quant_mode=aqt_flax.QuantMode.SERVE,
)
out = mlp_serve.apply(
    params_no_kernel,
    jnp.ones((1, 10)),
    rngs={'params': jax.random.key(0)},
)
# When running apply() we get the following:
# TypeError: Value 'ShapeDtypeStruct(shape=(10, 3), dtype=object)' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
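One partial workaround I am considering, based on the errors above (a rough sketch, not verified end to end): strip the kernels from the tree before writing the serving checkpoint, and re-create cheap zero-filled placeholder kernels of the right shape and dtype at load time, since SERVE mode appears to only read the kernel's shape rather than its values. The helper names strip_kernels and restore_placeholder_kernels below are my own, not AQT or Flax APIs, and this only shrinks the checkpoint - the placeholder arrays still take up memory at serving time.

def strip_kernels(converted_params):
    # Replace the original float kernels with shape/dtype metadata so that a
    # checkpoint written from this tree no longer stores the float weights.
    return jax.tree_util.tree_map_with_path(
        lambda kp, x: jax.ShapeDtypeStruct(x.shape, x.dtype)
        if DictKey('kernel') in kp else x,
        converted_params,
    )


def restore_placeholder_kernels(stripped_params):
    # SERVE mode seems to need a concrete array for the kernel (it reads its
    # shape), so materialise zero-filled placeholders of the recorded
    # shape/dtype after restoring the checkpoint.
    return jax.tree_util.tree_map_with_path(
        lambda kp, x: jnp.zeros(x.shape, x.dtype) if DictKey('kernel') in kp else x,
        stripped_params,
    )


# Save strip_kernels(converted_params) with your checkpointing library of choice,
# then rebuild concrete placeholders before calling apply():
serve_params = restore_placeholder_kernels(strip_kernels(converted_params))
out = mlp_serve.apply(
    serve_params,
    jnp.ones((1, 10)),
    rngs={'params': jax.random.key(0)},
)

If SERVE really only needs the kernel's shape, this keeps the checkpoint at roughly the quantized size while still giving apply() a concrete array for every leaf.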
I'm facing the same issue. As a workaround I'm currently injecting my own Dense layers which only depend on scale and value parameters, but I'm not sure if this is the way to go. This should definitely be part of the library.
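Roughly, the injected layer looks like the sketch below. This is simplified and unverified: the parameter names (w_q, scale, bias) are my own rather than AQT's, and the code that copies the frozen values out of the AQT-converted tree into these parameters is omitted, since that mapping depends on the AQT version.

class QuantizedDense(nn.Module):
    # Minimal stand-in for nn.Dense that stores only int8 weights plus a
    # per-output-channel scale, and dequantises on the fly.
    features: int

    @nn.compact
    def __call__(self, x):
        w_q = self.param(
            'w_q',
            lambda key, shape: jnp.zeros(shape, jnp.int8),  # overwritten with the frozen int8 values
            (x.shape[-1], self.features),
        )
        scale = self.param(
            'scale', nn.initializers.ones, (1, self.features), jnp.float32
        )
        bias = self.param(
            'bias', nn.initializers.zeros, (self.features,), jnp.float32
        )
        # Dequantise the weights and fall back to a regular float matmul.
        return x @ (w_q.astype(jnp.float32) * scale) + bias

Compared with keeping nn.Dense + AqtDotGeneral in SERVE mode, this gives up the int8 matmul (the compute here runs in float after dequantising), but it drops the float kernels from both the checkpoint and serving memory.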
@mar-muel, from what I can see, I think this is probably the best way to do it at the moment, given that the current Flax integration requires a concrete JAX array for the original kernel params. I agree that it should be part of the library. For me, being able to reduce the memory overhead and checkpoint size for model serving is one of the main benefits of quantization, so hopefully this is in the pipeline for AQT.