minimaxir opened this issue 3 years ago (status: Open)
Per https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#runtime-evaluation, JAX is roughly twice as fast as the corresponding PyTorch models when training on a TPU, so it may be worthwhile to add support for it.
However, this depends on when Hugging Face adds Trainer support for Flax, since manually setting up the training loops is not easy; the sketch below gives a sense of what that involves.
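For context, here is a minimal sketch of the kind of manual training loop Flax requires in the absence of a Trainer abstraction. The `ToyModel`, batch shapes, and hyperparameters are hypothetical placeholders, not anything from this project; a real language-model port would need the equivalent for a full causal LM plus data loading, checkpointing, and evaluation on top of this.

```python
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state


class ToyModel(nn.Module):
    """Placeholder single-layer model standing in for a real LM."""
    features: int = 1

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features)(x)


def create_state(rng, learning_rate=1e-3):
    # Initialize parameters and bundle them with an optax optimizer.
    model = ToyModel()
    params = model.init(rng, jnp.ones((1, 4)))["params"]
    tx = optax.adamw(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    )


@jax.jit
def train_step(state, batch):
    # One hand-written optimization step: forward pass, loss,
    # gradients, and parameter update, all managed explicitly.
    def loss_fn(params):
        preds = state.apply_fn({"params": params}, batch["x"])
        return jnp.mean((preds - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss


rng = jax.random.PRNGKey(0)
state = create_state(rng)
batch = {"x": jnp.ones((8, 4)), "y": jnp.ones((8, 1))}
for step in range(3):
    state, loss = train_step(state, batch)
    print(step, float(loss))
```

Even this toy version has to wire together model init, the optimizer, the jitted step function, and the loop itself by hand, which is the boilerplate a Flax Trainer would absorb.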