kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

on number of training tokens of gpt-j-6b and gpt-neox-20b #235

Open xiaoda99 opened 2 years ago

xiaoda99 commented 2 years ago

The blog reported that: "The model was trained on 400B tokens from The Pile dataset with 800GB text." But I'm confused about this figure, because the blog also says: "Throughput of the 6B GPT-J for training (151k tokens/s) is faster than the 2.7B GPT-Neo (148k tokens/s) on the same hardware (TPU v3-256 pod)" and "GPT-J training took roughly five weeks with TPU v3-256." So shouldn't it be 151k × 3600 × 24 × 7 × 5 ≈ 456B tokens?
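For reference, here is a minimal back-of-the-envelope sketch of that estimate, assuming the blog's quoted 151k tokens/s sustained throughput and roughly five weeks of wall-clock training:

```python
# Rough estimate of total training tokens for GPT-J-6B, based only on the
# throughput and training-time figures quoted in the blog post (assumptions,
# not measured numbers).
throughput_tokens_per_s = 151_000            # "151k tokens/s" on a TPU v3-256 pod
training_seconds = 5 * 7 * 24 * 3600         # "roughly five weeks" of wall-clock time

total_tokens = throughput_tokens_per_s * training_seconds
print(f"~{total_tokens / 1e9:.0f}B tokens")  # prints "~457B tokens"
```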

I also found that the reported number of training tokens for gpt-neox-20b seems very small: https://github.com/EleutherAI/gpt-neox/issues/654

Where am I wrong? If gpt-j-6b and gpt-neox-20b were trained with ~1000 times fewer tokens than the similar-sized gpt-3 curie (300B tokens), how could they achieve eval results comparable to or better than gpt-3 curie's?