kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

Finetuning and training minimum requirements #156

Closed AnassKartit closed 2 years ago

AnassKartit commented 2 years ago

Hello, what are the minimum requirements for finetuning and inference? Is a TPU mandatory? What are the cost and time needed to train a small dataset like the one used in the finetuning article?

kingoflolz commented 2 years ago

This codebase is designed for inference and finetuning on TPU. The cost of TPU instances is available from Google, and finetuning runs at around 5,000 tokens per second on a TPU v3-8.
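As a rough back-of-envelope sketch, the quoted ~5,000 tokens/s throughput lets you estimate finetuning wall-clock time for a given dataset. The dataset size and epoch count below are hypothetical examples, not figures from the thread:

```python
# Rough finetuning time estimate (illustrative only).
# Assumes the ~5,000 tokens/s TPU v3-8 throughput quoted above;
# the dataset size and epoch count are made-up examples.
TOKENS_PER_SECOND = 5000  # quoted finetuning throughput on a TPU v3-8

def finetune_hours(dataset_tokens: int, epochs: int = 1) -> float:
    """Estimated wall-clock hours to finetune over a dataset."""
    return dataset_tokens * epochs / TOKENS_PER_SECOND / 3600

# Example: a hypothetical 100M-token dataset, one epoch
print(round(finetune_hours(100_000_000), 1))  # ≈ 5.6 hours
```

Multiply the resulting hours by the hourly TPU v3-8 price from Google Cloud's pricing page to get a cost estimate.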