kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

how to speed up the inference time #196

Closed whoislimshady closed 2 years ago

whoislimshady commented 2 years ago

I am trying to deploy GPT-J with JAX on a GCP TPU v2-8. Generation takes around 35 s, which is way too slow; I found a few blog posts stating that a v2-8 can achieve a latency of about 5 s. Can someone suggest how to speed up the process? And yes, I am using the slim weights.

safeeazeem commented 2 years ago

The first generation always takes time, and after that it's quicker. See the table below, where I measured how long it takes to generate outputs of different lengths. This was tested on a v2-8 TPU instance with the slim weights.

| Tokens | First Gen | Second Gen | Third Gen |
| --- | --- | --- | --- |
| 100 | 68s | 2.228s | 2.229s |
| 150 | 38.78s | 3.3s | 3.3s |
| 300 | 42s | 6.3s | 6.3s |
| 512 | 45.9s | 10.7s | 10.7s |
| 768 | 52.7s | 15.9s | 15.9s |
| 1024 | 57.51s | 21.13s | 21.13s |
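The pattern in the table is the usual JAX behavior: the first call for a given input shape triggers XLA compilation, and later calls with the same shape reuse the compiled program. Here is a minimal, self-contained sketch of that effect; `decode_step` is just a stand-in toy function, not this repo's API:

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def decode_step(x):
    # Stand-in for one forward pass of a real model. The first call
    # compiles for this input shape; later calls reuse the cached program.
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((1024, 1024))

for i in range(3):
    start = time.time()
    decode_step(x).block_until_ready()  # block until the device finishes
    print(f"call {i}: {time.time() - start:.3f}s")
```

In practice that suggests warming up before serving: run one throwaway generation for each (prompt length, generation length) shape you plan to serve, and pad prompts to a fixed set of lengths so the compiled function is reused instead of recompiled. That way the "First Gen" cost is paid once at startup rather than on a user request.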
whoislimshady commented 2 years ago

Thanks @safeeazeem for the information, I really appreciate the effort.