kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

Finetuning GPT Neo 20B Using TPU V3-8s #227

Open nikhilanayak opened 2 years ago

nikhilanayak commented 2 years ago

Would it be possible to finetune 20B using this repo and TPU v3-8s? If so, how many TPUs would be needed, and how would I have to change the code to make it work with more than one TPU?
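For context on what "making it work with more than one TPU" involves: model parallelism in JAX means laying devices out in a mesh and sharding each weight matrix across one mesh axis. Below is a minimal sketch using the modern `jax.sharding` API; it is an illustration only, not the repo's actual mechanism (mesh-transformer-jax predates this API), and the axis names `mp`/`dp` are made up for the example. On a TPU v3-8, `jax.devices()` would return 8 cores; the sketch uses whatever devices are available so it also runs on CPU.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec, NamedSharding

# Arrange all available devices into a 2-D mesh:
# one model-parallel ("mp") axis and one data-parallel ("dp") axis.
# On a v3-8 this could be reshaped to (8, 1) for 8-way model parallelism.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("mp", "dp"))

# Shard a weight matrix along its first dimension across the "mp" axis.
# Each device then holds only a slice of the parameter, which is how a
# model too large for one core's memory is split over many cores.
w = jnp.ones((8, 128))
sharded_w = jax.device_put(w, NamedSharding(mesh, PartitionSpec("mp", None)))

# The logical shape is unchanged; only the physical storage is split.
print(sharded_w.shape)
```

The practical consequence for finetuning: the number of model-parallel shards (and hence cores) has to be large enough that one shard of the parameters, optimizer state, and activations fits in a single core's HBM.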