Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.
Apache License 2.0
JAX + TPU and AQT int8: training loss is abnormal #71
I used the aqt_einsum function in the code to quantize only the QK score, and then trained the model. However, after a certain number of steps (around 200), the loss decreased very slowly, which is quite different from the loss curve of the bfloat16 run. Am I missing something? For example, does the backward pass need some additional processing?
P.S. I trained the model with jax==0.4.23 on a TPU v5p-8.
In other words, is there a training example for AQT int8 in Pax?
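For reference, here is a minimal sketch of int8 fake quantization for the QK score einsum. It is written in plain JAX, not the AQT library or any Pax API, and every name in it (fake_quant_int8, qk_scores_int8) is hypothetical. Rounding has zero gradient almost everywhere, so the sketch uses a straight-through estimator (STE): the forward pass sees the quantized values, and the backward pass treats quantization as the identity. It assumes symmetric per-tensor scaling and simulates int8 in floating point rather than running a true int8 matmul. AQT's actual scaling and gradient handling may differ.

```python
import jax
import jax.numpy as jnp


def fake_quant_int8(x):
    """Hypothetical helper: symmetric per-tensor int8 fake quantization with an STE."""
    # Map the largest magnitude to 127; guard against an all-zero tensor.
    scale = jnp.max(jnp.abs(x)) / 127.0
    scale = jnp.where(scale == 0, 1.0, scale)
    q = jnp.clip(jnp.round(x / scale), -127, 127)
    x_dq = q * scale  # dequantized values on the int8 grid
    # Straight-through estimator: forward value is x_dq, but the gradient
    # w.r.t. x is the identity because the difference is under stop_gradient.
    return x + jax.lax.stop_gradient(x_dq - x)


def qk_scores_int8(q, k):
    """Hypothetical QK score with int8-simulated inputs.

    q: [batch, q_len, heads, head_dim], k: [batch, k_len, heads, head_dim]
    returns: [batch, heads, q_len, k_len]
    """
    return jnp.einsum("bqhd,bkhd->bhqk", fake_quant_int8(q), fake_quant_int8(k))
```

With the STE, gradients reach q and k through the quantized einsum instead of being zeroed by jnp.round, which is one common way to handle the backward pass when training with quantized operands.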