google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Adding option for int4 quantization to kvcache. #737

Closed: singh-mitali closed this pull request 3 weeks ago.

singh-mitali commented 3 weeks ago

LGTM.

Are any flags or JAX config settings needed for int4 to actually use less memory than int8? Have you observed lower memory usage with int4 from this PR?

It will require platform-specific layout tuning. We will share numbers as we evaluate and optimize for llama70b.
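As background for the question above, int4 only saves memory over int8 when two 4-bit values are actually packed into each byte. If each int4 value occupies its own byte, the footprint is the same as int8. The sketch below is a minimal NumPy illustration of that point, not the code from this PR. It assumes a symmetric per-channel scheme over the last axis and a hypothetical KV cache shape of `(batch, seq, kv_heads, head_dim)`. The helper names `quantize_int4`, `pack_int4`, and `unpack_int4` are hypothetical. The thread does not say how JAX/XLA lays out int4 on any given platform.

```python
import numpy as np

INT4_MAX = 7  # symmetric range [-7, 7]; signed 4-bit can also hold -8

def quantize_int4(x: np.ndarray, axis: int = -1):
    """Hypothetical symmetric per-channel int4 quantization.

    Returns int4 values stored one per int8 element (unpacked) and float32 scales.
    """
    amax = np.max(np.abs(x), axis=axis, keepdims=True)
    scale = np.where(amax == 0, 1.0, amax / INT4_MAX).astype(np.float32)
    q = np.clip(np.round(x / scale), -INT4_MAX, INT4_MAX).astype(np.int8)
    return q, scale

def dequantize_int4(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    return q.astype(np.float32) * scale

def pack_int4(q: np.ndarray) -> np.ndarray:
    """Pack pairs of int4 values (held in int8) into single bytes, low nibble first."""
    flat = q.reshape(-1)
    if flat.size % 2:
        flat = np.concatenate([flat, np.zeros(1, dtype=np.int8)])
    nibbles = flat.astype(np.uint8) & 0x0F  # two's-complement low 4 bits
    return (nibbles[0::2] | (nibbles[1::2] << 4)).astype(np.uint8)

def unpack_int4(packed: np.ndarray, size: int) -> np.ndarray:
    """Inverse of pack_int4: restore signed int4 values as int8."""
    out = np.empty(packed.size * 2, dtype=np.int8)
    out[0::2] = (packed & 0x0F).astype(np.int8)
    out[1::2] = (packed >> 4).astype(np.int8)
    out = np.where(out > 7, out - 16, out).astype(np.int8)  # sign-extend nibbles
    return out[:size]

# Hypothetical key cache: (batch, seq, kv_heads, head_dim)
shape = (1, 1024, 8, 128)
rng = np.random.default_rng(0)
k = rng.standard_normal(shape).astype(np.float32)

q, scale = quantize_int4(k)
packed = pack_int4(q)

# Unpacked int4 costs as much as int8; packing halves it (scales are extra).
print(q.nbytes, packed.nbytes, scale.nbytes)  # 1048576 524288 32768
```

In this sketch the unpacked `q` already holds only 4-bit values, but it takes the same 1 MiB as an int8 cache. Only the packed buffer is half that size, plus the per-channel scales. This mirrors the reply in the thread that the savings depend on platform-specific layout.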