kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0
6.26k stars, 890 forks

Quantization for training / finetuning #254

Open: torphix opened this issue 1 year ago

torphix commented 1 year ago

Hi! Thanks for the lib and the tutorial; both are very informative.

With respect to finetuning, would it be worth quantizing the model first (to fp16, or even int8) before training begins? Could that give better accuracy than quantizing the model after it has been finetuned?

Thanks
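A minimal sketch of what "quantize first, then finetune" usually looks like in practice, written in JAX to match the repository but not taken from mesh-transformer-jax. The technique is quantization-aware training (QAT). It keeps float32 master weights, runs the forward pass on fake-quantized (int8 round-tripped) weights, and passes gradients straight through the rounding step (the straight-through estimator). Updating stored int8 weights directly tends to stall, because an update smaller than one quantization step rounds away to nothing. The toy linear-regression data, the symmetric per-tensor int8 scheme, and every function name below are hypothetical illustrations. Casting to fp16 or bf16 (for example `w.astype(jnp.bfloat16)`) is reduced-precision floating point, not integer quantization, so it does not need this machinery.

```python
# Hypothetical sketch: post-training int8 quantization (PTQ) vs. quantization-aware
# training (QAT) on a toy problem. Not part of mesh-transformer-jax.
import jax
import jax.numpy as jnp


def quantize_int8(w):
    """Symmetric per-tensor int8 quantization: int8 codes plus one float scale."""
    scale = jnp.maximum(jnp.max(jnp.abs(w)), 1e-8) / 127.0
    q = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
    return q, scale


def dequantize_int8(q, scale):
    """Map int8 codes back to float32."""
    return q.astype(jnp.float32) * scale


def fake_quant(w):
    """Forward value is the int8 round trip; gradient w.r.t. w is the identity
    (straight-through estimator), so the float32 master weights keep learning."""
    q, scale = quantize_int8(w)
    w_q = dequantize_int8(q, scale)
    return w + jax.lax.stop_gradient(w_q - w)


def loss_fn(w, x, y, qat):
    """MSE of a linear model; with qat=True the forward pass sees quantized weights."""
    w_used = fake_quant(w) if qat else w
    return jnp.mean((x @ w_used - y) ** 2)


# `qat` (positional argument 3) is a Python bool, so it is marked static for jit.
grad_fn = jax.jit(jax.grad(loss_fn), static_argnums=3)


def train(w, x, y, qat, lr=0.05, steps=300):
    """Plain gradient descent on float32 master weights."""
    for _ in range(steps):
        w = w - lr * grad_fn(w, x, y, qat)
    return w


def int8_eval_loss(w, x, y):
    """Loss after the trained weights are actually stored as int8."""
    q, scale = quantize_int8(w)
    return jnp.mean((x @ dequantize_int8(q, scale) - y) ** 2)


def compare_ptq_qat(steps=300, seed=0):
    """Return (PTQ loss, QAT loss), both evaluated with int8 weights."""
    k_x, k_w = jax.random.split(jax.random.PRNGKey(seed))
    x = jax.random.normal(k_x, (256, 16))
    y = x @ jax.random.normal(k_w, (16,))  # hypothetical toy regression target
    w0 = jnp.zeros(16)
    w_ptq = train(w0, x, y, qat=False, steps=steps)  # train in float, quantize after
    w_qat = train(w0, x, y, qat=True, steps=steps)   # train through the quantizer
    return float(int8_eval_loss(w_ptq, x, y)), float(int8_eval_loss(w_qat, x, y))


if __name__ == "__main__":
    ptq_loss, qat_loss = compare_ptq_qat()
    print(f"int8 loss after PTQ: {ptq_loss:.6f}")
    print(f"int8 loss after QAT: {qat_loss:.6f}")
```

On a toy problem like this the gap between the two losses may be small or absent, so the numbers say little about how a large transformer would behave; the point of the sketch is the mechanism (float master weights plus a fake-quantized forward pass), not a benchmark.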