minimaxir / aitextgen

A robust Python tool for text-based AI training and generation using GPT-2.
https://docs.aitextgen.io
MIT License

Consider supporting JAX for faster TPU training #148

Open · minimaxir opened this issue 3 years ago

minimaxir commented 3 years ago

Per https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#runtime-evaluation , JAX is about twice as fast to train on a TPU as the corresponding PyTorch models, so it may be worthwhile to add support for it.

However, this depends on when Hugging Face adds Trainer support for Flax, since manually setting up the training loops is not easy.
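
To illustrate what "manually setting up the loops" involves, here is a minimal sketch of a single Flax training step for causal LM, assuming transformers' `FlaxGPT2LMHeadModel` and `optax` (this is not aitextgen or Hugging Face Trainer API; function names like `train_step` are illustrative). A real TPU setup would additionally need `jax.pmap`-based data parallelism, per-device RNG handling, and a learning-rate schedule, as in Hugging Face's `run_clm_flax.py` example.

```python
# Minimal sketch of a manual Flax/JAX training step for GPT-2 (illustrative only).
import jax
import jax.numpy as jnp
import optax
from transformers import FlaxGPT2LMHeadModel

model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
optimizer = optax.adamw(learning_rate=5e-5)
params = model.params
opt_state = optimizer.init(params)


def loss_fn(params, batch):
    # Causal LM loss: predict token t+1 from tokens up to t.
    logits = model(batch["input_ids"], params=params).logits
    shift_logits = logits[:, :-1, :]
    shift_labels = batch["input_ids"][:, 1:]
    labels_onehot = jax.nn.one_hot(shift_labels, shift_logits.shape[-1])
    loss = optax.softmax_cross_entropy(shift_logits, labels_onehot)
    return loss.mean()


@jax.jit
def train_step(params, opt_state, batch):
    # Compute loss and gradients, then apply the optimizer update.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```

Even this simplified version has to manage parameters and optimizer state explicitly, which is the boilerplate that Trainer support would otherwise hide.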