Are any flags or JAX config options needed for int4 to actually use less memory than int8? Have you observed lower memory usage with int4 from this PR?
Getting a memory reduction from int4 will require layout tuning specific to the platform. We'll share numbers as we evaluate and optimize for llama70b.