liamclarkza opened this issue 2 months ago (Open)
I'm facing the same issue. As a workaround, I'm currently injecting my own Dense layers that depend only on the scale and value parameters, but I'm not sure this is the way to go. This should definitely be part of the library.
@mar-muel, from what I can see, this is probably the best way to do it at the moment, given that the current Flax implementation requires a concrete JAX array for the original kernel params. I agree that it should be part of the library. For me, reducing the memory overhead and checkpoint size for model serving is one of the main benefits of quantization. Hopefully this is in the pipeline for AQT.
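For reference, a minimal sketch of the kind of scale-and-value-only Dense layer described above (this is not the AQT API; the module and parameter names `QuantDense`, `qkernel`, and `scale` are made up for illustration):

```python
import flax.linen as nn
import jax.numpy as jnp


class QuantDense(nn.Module):
    """Serving-only Dense layer whose parameters are an int8 kernel and a scale."""
    features: int

    @nn.compact
    def __call__(self, x):
        in_features = x.shape[-1]
        # Only the quantized values and per-output-channel scales live in the
        # variable tree, so no float32 kernel ends up in the checkpoint. In
        # practice these params would be overwritten with the values produced
        # by the conversion pass rather than used with their init values.
        qkernel = self.param(
            'qkernel', nn.initializers.zeros, (in_features, self.features), jnp.int8
        )
        scale = self.param(
            'scale', nn.initializers.ones, (1, self.features), jnp.float32
        )
        # Dequantize on the fly at call time.
        return jnp.dot(x, qkernel.astype(jnp.float32) * scale)
```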
Hi guys, thanks for the work on this library.
I am trying to reduce the checkpoint size and memory overhead of a model using AQT. Currently, when we quantize the parameters of a Flax model for serving using `QuantMode.CONVERT`, the original parameters remain. Is there any way to use AQT with Flax so that we don't need to keep the original unquantised weights in checkpoints when serving?
I have tried manually shrinking the parameter pytree by removing the original kernels or replacing them with placeholder values, but these approaches have been unsuccessful (see the example code below).
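As a rough illustration of that kind of pytree surgery (separate from the example code referenced above; the `'kernel'` key name and the scalar placeholder are assumptions):

```python
import jax.numpy as jnp
from flax import traverse_util


def strip_original_kernels(params):
    """Replace every leaf whose key ends in 'kernel' with a tiny placeholder."""
    flat = traverse_util.flatten_dict(params)
    stripped = {
        path: jnp.zeros((), jnp.float32) if path[-1] == 'kernel' else leaf
        for path, leaf in flat.items()
    }
    return traverse_util.unflatten_dict(stripped)
```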