google / aqt

Reduce checkpoint size for Flax models #720

Open liamclarkza opened 1 day ago

liamclarkza commented 1 day ago

Hi guys, thanks for the work on this library.

I am trying to reduce the checkpoint size and memory overhead of a model using AQT. Currently, when we quantize the parameters of a Flax model for serving with QuantMode.CONVERT, the original unquantised parameters remain alongside the quantized ones.

Is there any way to use AQT with Flax so that we don't need to keep the original unquantised weights in checkpoints when serving?

I have tried manually shrinking the parameter pytree by removing the original kernels or replacing them with placeholder values, but both approaches have been unsuccessful (see the example code below).

  1. Is it currently possible to reduce the size of the Flax params (and therefore checkpoints)?
  2. If so, are there any small idiomatic examples of how to do this?
  3. If not, are there plans to support this in the future?

import functools
from pprint import pprint

import aqt.jax.v2.config as aqt_config
import flax.linen as nn
import jax
import jax.numpy as jnp
from aqt.jax.v2.flax import aqt_flax
from jax.tree_util import DictKey

class MlpBlock(nn.Module):
    aqt_cfg: aqt_config.DotGeneral | None = None
    quant_mode: aqt_flax.QuantMode = aqt_flax.QuantMode.TRAIN

    @nn.compact
    def __call__(self, inputs):
        dense_dg = functools.partial(
            aqt_flax.AqtDotGeneral,
            self.aqt_cfg,
            # In nn.Dense, it is the RHS that holds the kernel.
            rhs_quant_mode=self.quant_mode,
            rhs_freeze_mode=aqt_flax.FreezerMode.CALIBRATION_AND_VALUE
        )
        x = nn.Dense(dot_general_cls=dense_dg, features=3)(inputs)
        x = nn.relu(x)
        x = nn.Dense(dot_general_cls=dense_dg, features=3)(x)
        return x

int8_config = aqt_config.fully_quantized(fwd_bits=8, bwd_bits=8)

def get_pytree_memory_size(pytree):
    # Rough estimate of the pytree size in bytes; object-dtype placeholder leaves are skipped.
    leaves, _ = jax.tree_util.tree_flatten(pytree)
    return sum(leaf.nbytes for leaf in leaves if leaf.dtype != jnp.dtype('O'))

# 1. Get params for the model
mlp = MlpBlock()
params = mlp.init(jax.random.key(0), jnp.ones((1, 10)))
print('Original params:')
pprint(jax.tree_util.tree_map(lambda x: x.shape, params))
print('Memory size of original params:', get_pytree_memory_size(params))
# -> Memory size of original params: 180

# 2. Convert the model to int8 - requires a dummy pass and mutable=True
mlp_convert = MlpBlock(
    aqt_cfg=int8_config,
    quant_mode=aqt_flax.QuantMode.CONVERT,
)
_, converted_params = mlp_convert.apply(
    params,
    jnp.ones((1, 10)),
    rngs={'params': jax.random.key(0)},
    mutable=True
)
print('Converted params:')
pprint(jax.tree_util.tree_map(lambda x: x.shape, converted_params))
print('Memory size of converted params:', get_pytree_memory_size(converted_params))
# -> Memory size of converted params: 243

# 3. Use the converted params to run the model
mlp_serve = MlpBlock(
    aqt_cfg=int8_config,
    quant_mode=aqt_flax.QuantMode.SERVE,
)
out = mlp_serve.apply(
    converted_params,
    jnp.ones((1, 10)),
    rngs={'params': jax.random.key(0)},
)
# This works :)

# 4. Try to remove the redundant kernel weights in the converted params by setting them to None
params_no_kernel = jax.tree_util.tree_map_with_path(
    lambda kp, x: None if DictKey('kernel') in kp else x,
    converted_params
)
print('Params without kernel:')
pprint(jax.tree_util.tree_map(lambda x: (x.dtype, x.shape), params_no_kernel))
print('Memory size of params without kernel:', get_pytree_memory_size(params_no_kernel))
# -> Memory size of params without kernel: 87

mlp_serve = MlpBlock(
    aqt_cfg=int8_config,
    quant_mode=aqt_flax.QuantMode.SERVE,
)
out = mlp_serve.apply(
    params_no_kernel,
    jnp.ones((1, 10)),
    rngs={'params': jax.random.key(0)},
)
# When running apply() we get the following:
# AttributeError: 'NoneType' object has no attribute 'shape' when calling make_aqt_dg()

# 5. Try to remove the redundant kernel weights in the converted params using jax.ShapeDtypeStruct
params_no_kernel = jax.tree_util.tree_map_with_path(
    lambda kp, x: jax.ShapeDtypeStruct(x.shape, type(x)) if DictKey('kernel') in kp else x,
    converted_params
)
print('Params without kernel:')
pprint(jax.tree_util.tree_map(lambda x: (x.dtype, x.shape), params_no_kernel))
print('Memory size of params without kernel:', get_pytree_memory_size(params_no_kernel))
# -> Memory size of params without kernel: 87

mlp_serve = MlpBlock(
    aqt_cfg=int8_config,
    quant_mode=aqt_flax.QuantMode.SERVE,
)
out = mlp_serve.apply(
    params_no_kernel,
    jnp.ones((1, 10)),
    rngs={'params': jax.random.key(0)},
)
# When running apply() we get the following:
# TypeError: Value 'ShapeDtypeStruct(shape=(10, 3), dtype=object)' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
mar-muel commented 1 day ago

I'm facing the same issue. As a workaround I'm currently injecting my own Dense layers which only depend on scale and value parameters, but I'm not sure if this is the way to go. This should definitely be part of the library.
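
Roughly, the replacement layer looks something like the sketch below (the qvalue/scale parameter names and the per-output-channel symmetric scheme are placeholders of my own, not AQT's frozen parameter layout, so the CONVERT output still has to be mapped into this structure by hand):

import flax.linen as nn
import jax.numpy as jnp

class Int8ServingDense(nn.Module):
    """Serving-only Dense layer that stores an int8 kernel plus a float
    scale instead of the original float32 kernel."""
    features: int

    @nn.compact
    def __call__(self, x):
        in_features = x.shape[-1]
        # int8 kernel values - in practice copied from the CONVERT output
        # rather than initialised from scratch.
        qvalue = self.param(
            'qvalue', nn.initializers.zeros, (in_features, self.features), jnp.int8)
        # One scale per output channel (symmetric quantisation assumed).
        scale = self.param(
            'scale', nn.initializers.ones, (1, self.features), jnp.float32)
        bias = self.param(
            'bias', nn.initializers.zeros, (self.features,), jnp.float32)
        # Dequantise on the fly and fall back to a plain float matmul, so this
        # only saves storage/memory, not compute.
        kernel = qvalue.astype(x.dtype) * scale
        return x @ kernel + bias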

liamclarkza commented 1 day ago

@mar-muel, from what I can see, that is probably the best approach at the moment, given that the current Flax integration requires a concrete JAX array for the original kernel params. I agree it should be part of the library: for me, reducing the memory overhead and checkpoint size for model serving is one of the main benefits of quantization. Hopefully this is in the pipeline for AQT.
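
In the meantime, once a custom serving layer like that is in place (so nothing reads the original float kernels at apply time), the kernels can at least be stripped from the pytree before checkpointing. A rough sketch with flax.traverse_util, assuming the converted variables keep the layout from my example above:

from flax import traverse_util

# Flatten the converted variables to '/'-joined paths, drop the original
# float32 kernels under the 'params' collection, and rebuild the tree.
flat = traverse_util.flatten_dict(converted_params, sep='/')
flat = {
    k: v for k, v in flat.items()
    if not (k.startswith('params/') and k.endswith('/kernel'))
}
slim_params = traverse_util.unflatten_dict(flat, sep='/')
# slim_params is what would go into the serving checkpoint.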