google / paxml

Pax is a Jax-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model flop utilization rates.
Apache License 2.0
458 stars, 69 forks

JAX + TPU with AQT int8: training loss is abnormal #71

Open Lisennlp opened 9 months ago

Lisennlp commented 9 months ago

I used the aqt_einsum function to quantize only the QK scores, and then trained the model. However, after a certain number of steps (around 200), the loss decreased very slowly, and the curve was quite different from the loss curve of the same model trained in bfloat16. Am I missing something? For example, does the backward pass need some additional handling? PS: I'm training on jax==0.4.23 with a TPU v5p-8.

In other words, is there a training example for AQT int8 in Pax?
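To make the backward-pass question concrete, here is a minimal sketch in plain JAX (not the AQT API) of int8 fake quantization applied to the operands of the QK einsum. It simulates int8 in floating point rather than running a real int8 matmul. In JAX, jnp.round has a zero gradient, so the sketch uses a straight-through estimator (STE): the forward pass sees the quantized values, while the backward pass treats the quantizer as the identity. The helpers fake_quant_int8 and int8_qk_scores are hypothetical, and the symmetric per-token, per-head scale along the contracted d axis is an assumption. AQT's own calibration and gradient handling may differ.

```python
import jax
import jax.numpy as jnp


def fake_quant_int8(x, axis):
    """Hypothetical helper: symmetric int8 fake quantization with an STE.

    Each slice along `axis` is scaled to [-127, 127], rounded, clipped and
    rescaled. The output keeps x's dtype but takes at most 255 levels per slice.
    """
    # Max magnitude per slice; fall back to scale 1 for all-zero slices.
    amax = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
    scale = jnp.where(amax > 0, amax / 127.0, 1.0)
    deq = jnp.clip(jnp.round(x / scale), -127, 127) * scale
    # Straight-through estimator: forward value is `deq`, but the gradient
    # w.r.t. x is the identity (jnp.round alone would give zero gradient).
    return x + jax.lax.stop_gradient(deq - x)


def int8_qk_scores(q, k):
    """Hypothetical helper: attention logits from fake-quantized q and k.

    q: [batch, q_len, heads, dim], k: [batch, k_len, heads, dim]
    returns: [batch, heads, q_len, k_len]
    """
    # Scales are constant along the contracted `d` axis, so they factor out
    # of the contraction as they would in a real int8 matmul.
    q_fq = fake_quant_int8(q, axis=-1)
    k_fq = fake_quant_int8(k, axis=-1)
    return jnp.einsum("bqhd,bkhd->bhqk", q_fq, k_fq)


# Usage: compare against the unquantized logits and check that gradients flow.
key_q, key_k = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(key_q, (2, 8, 4, 16))
k = jax.random.normal(key_k, (2, 8, 4, 16))
ref = jnp.einsum("bqhd,bkhd->bhqk", q, k)
out = int8_qk_scores(q, k)
err = float(jnp.max(jnp.abs(out - ref)))
print("max abs quantization error:", err)

grad_q = jax.grad(lambda q_: int8_qk_scores(q_, k).sum())(q)
```

Without the stop_gradient trick, the gradient through the rounded values would be zero, and q and k would receive no gradient from this path at all. Comparing the quantized logits against the bfloat16 ones in this way, layer by layer, is one way to narrow down whether the slower loss curve comes from forward quantization error or from how gradients are handled.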