kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

sequence_length=2049 or 2048? #175

Closed · leejason closed this 2 years ago

leejason commented 2 years ago

Should sequence_length be 2049 or 2048? In gpt-neo, the chunk_size passed to the split_list() function is 2048, but this repository uses 2049. Why the difference?

You can use this script (it expects exactly the same input format as gpt-neo): https://github.com/EleutherAI/gpt-neo/blob/master/data/create_tfrecords.py

Originally posted by @kingoflolz in https://github.com/kingoflolz/mesh-transformer-jax/issues/68#issuecomment-886067933

vfbd commented 2 years ago

2049 is likely correct here. The trainer uses the first 2048 tokens of each 2049-token chunk as the "context" and the last 2048 tokens as the "target", so that the first token in the chunk predicts the second, the second predicts the third, and so on up to the 2048th predicting the 2049th. That's 2048 (input, target) pairs in total, as in the sketch below.
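To make the shift concrete, here is a minimal illustration (not the repository's actual trainer code) of how one 2049-token chunk yields 2048 training pairs via two offset slices:

```python
import numpy as np

chunk = np.arange(2049)   # stand-in for a 2049-token chunk
context = chunk[:-1]      # first 2048 tokens: positions 0..2047
target = chunk[1:]        # last 2048 tokens:  positions 1..2048

assert context.shape == target.shape == (2048,)
# Each context token predicts the token one position ahead of it.
# (The check below holds here only because the fake chunk is 0..2048.)
assert np.all(target == context + 1)
```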

Also, if you look at the bottom of the gpt-neo create_tfrecords script, you can see that it increases whatever chunk size you set by 1, so the chunks it writes with the default setting are also 2049 tokens long (roughly as sketched below).
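A hypothetical reimplementation of that behavior (the function name and signature here are illustrative, not the exact gpt-neo code): the stream is cut into chunks of chunk_size + 1 tokens so each chunk carries one extra token for the shifted target.

```python
def split_list(tokens, chunk_size=2048):
    """Split a token list into full chunks of chunk_size + 1 tokens,
    dropping any incomplete trailing chunk."""
    step = chunk_size + 1
    return [tokens[i:i + step] for i in range(0, len(tokens) - step + 1, step)]

chunks = split_list(list(range(10_000)))
assert all(len(c) == 2049 for c in chunks)
```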

kingoflolz commented 2 years ago

That's right, thanks @vfbd for the answer.

leejason commented 2 years ago

That's illuminating and helpful. Thank you very much.