jzhang38 / TinyLlama

The TinyLlama project is an open endeavor to pretrain a 1.1B Llama model on 3 trillion tokens.
Apache License 2.0

Why do we not set the `ignore_index` of `FusedCrossEntropy` to `bos_id`? #83

Closed · larrylawl closed this 1 year ago

larrylawl commented 1 year ago

Hi,

Can I check why you did not ignore the loss for the bos token (<s>)?

https://github.com/jzhang38/TinyLlama/blob/c53075b679c9a97f96562052689e0043120a5fd5/pretrain/tinyllama.py#L197

I noticed that the preprocessing causes the remainder of the binary file to be filled with the bos token (<s>).

https://github.com/jzhang38/TinyLlama/blob/c53075b679c9a97f96562052689e0043120a5fd5/scripts/prepare_slimpajama.py#L69

Consequently, my model checkpoint (not TinyLlama's) outputs poor qualitative results:

Prompt: "(very long text of 1027 tokens). How many yards longer was the longest passing touchdown than the shortest?"
Output: "<s>ending<s>end<s>ent<s>ended"

Interestingly, my model from the previous checkpoint (100B tokens earlier) performed okay.

I'm trying to fix this by telling the loss function to ignore the <s> index (i.e. 1). I think this is a correct fix, but I'm not sure it addresses the underlying issue (the issue should have plagued our model from the start, so why did it only surface at this iteration step?).
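For concreteness, here is a minimal sketch of the change I have in mind, using plain PyTorch cross-entropy with made-up shapes (I'm assuming `FusedCrossEntropy` exposes a similar `ignore_index` argument):

```python
import torch
import torch.nn.functional as F

BOS_ID = 1  # id of <s> in the Llama tokenizer

# Dummy logits/targets just to illustrate the call; shapes follow the usual
# (batch, seq_len, vocab) / (batch, seq_len) layout.
logits = torch.randn(2, 8, 32000)
targets = torch.randint(0, 32000, (2, 8))
targets[:, -3:] = BOS_ID  # pretend the tail of the chunk is bos/sep padding

# Positions whose target is <s> are excluded from the loss entirely.
loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    targets.view(-1),
    ignore_index=BOS_ID,
)
print(loss)
```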

jzhang38 commented 1 year ago

I noticed that the preprocessing causes the remainder of the binary file to be filled with the bos token (<s>).

Yes, I think you are right here. https://github.com/jzhang38/TinyLlama/blob/c53075b679c9a97f96562052689e0043120a5fd5/scripts/prepare_slimpajama.py#L85

If you have 64 CPU cores, prepare_slimpajama.py will initiate 64 processes, each with its own PackedDatasetBuilder:

https://github.com/jzhang38/TinyLlama/blob/c53075b679c9a97f96562052689e0043120a5fd5/scripts/prepare_slimpajama.py#L26
https://github.com/jzhang38/TinyLlama/blob/c53075b679c9a97f96562052689e0043120a5fd5/scripts/prepare_slimpajama.py#L69

That means each process will leave behind a final chunk file that is not fully filled with text tokens but is instead padded with sep tokens:

https://github.com/jzhang38/TinyLlama/blob/c53075b679c9a97f96562052689e0043120a5fd5/lit_gpt/packed_dataset.py#L77
https://github.com/jzhang38/TinyLlama/blob/c53075b679c9a97f96562052689e0043120a5fd5/lit_gpt/packed_dataset.py#L91

So we may end up with 64 files that still contain some sep tokens.
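Roughly, the builder behaves like this. This is a simplified sketch (the class name, in-memory chunk list, and exact buffer handling are my paraphrase, not the real code in packed_dataset.py):

```python
import numpy as np

class PackedDatasetBuilderSketch:
    """Simplified stand-in for the PackedDatasetBuilder in lit_gpt/packed_dataset.py."""

    def __init__(self, chunk_size, sep_token):
        self.chunk_size = chunk_size
        self.sep_token = sep_token
        # The chunk buffer starts out pre-filled with the sep (bos) token.
        self._arr = np.full(chunk_size, sep_token, dtype=np.uint16)
        self._idx = 0
        self.chunks = []  # stands in for the .bin files written to disk

    def _write_chunk(self):
        self.chunks.append(self._arr.copy())
        self._arr.fill(self.sep_token)
        self._idx = 0

    def add_array(self, tokens):
        tokens = np.asarray(tokens, dtype=np.uint16)
        while self._idx + len(tokens) > self.chunk_size:
            part = self.chunk_size - self._idx
            self._arr[self._idx:] = tokens[:part]
            self._write_chunk()
            tokens = tokens[part:]
        self._arr[self._idx:self._idx + len(tokens)] = tokens
        self._idx += len(tokens)

    def write_reminder(self):
        # Flush the partially filled buffer: everything past self._idx is
        # still the sep token, so each worker's last chunk ends in a run of <s>.
        self._write_chunk()

builder = PackedDatasetBuilderSketch(chunk_size=8, sep_token=1)
builder.add_array([5, 6, 7, 8, 9])
builder.write_reminder()
print(builder.chunks[0])  # [5 6 7 8 9 1 1 1] -- the tail is all bos/sep tokens
```

With 64 worker processes each flushing its own partially filled buffer at the end, that is where the 64 sep-padded chunks come from.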

but I'm not sure it addresses the underlying issue (the issue should have plagued our model from the start, so why did it only surface at this iteration step?).

I think it is because 64 files with some sep tokens remaining are a relatively small portion of the entire pretraining corpus (450k small bin files after processing), especially when you consider the small size of each chunk. So I do not know why your second checkpoint became really bad. Maybe it is just this specific prompt? Does the benchmark performance degrade significantly?

These are some of my preliminary thoughts; I haven't looked very deeply into it yet. Thanks for spotting this. We will fix it soon. For example, we could opt not to call builder.write_reminder() at all.

larrylawl commented 1 year ago

Thanks for your reply, Peiyuan!

I think it is because 64 files with some sep tokens remaining are a relatively small portion of the entire pretraining corpus (450k small bin files after processing), especially when you consider the small size of each chunk.

It's true that the percentage of files with some bos tokens remaining is relatively small. The chunk size, however, is actually quite big (2049 * 1028 tokens). This means that once a process loads a "problematic" binary chunk, it will keep drawing from that file for roughly the next 1024 iterations (rough arithmetic below).

https://github.com/jzhang38/TinyLlama/blob/072536c460293387531eb08a3a0275c6b2e1032c/scripts/prepare_slimpajama.py#L76
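Back-of-the-envelope numbers (block_size = 2049 is an assumption based on the 2048-token context plus one shifted label):

```python
# Rough arithmetic for how long one sep-padded chunk keeps feeding a worker.
block_size = 2049           # assumed: 2048-token context + 1 for the label shift
blocks_per_chunk = 1028     # chunk size is 2049 * 1028 tokens per chunk file
chunk_tokens = block_size * blocks_per_chunk

# Each training sample consumes one block, so a single "problematic" chunk
# supplies on the order of a thousand consecutive samples to that process
# (the ~1024 iterations mentioned above).
samples_per_chunk = chunk_tokens // block_size
print(chunk_tokens, samples_per_chunk)   # 2106372 tokens, 1028 samples
```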

But you are right that since the percentage of such files is relatively small, it shouldn't have much of an effect overall. I'll let you know if I manage to fix the bug. Thanks for your help anyway!

Maybe it is just this specific prompt? Does the benchmark performance degrade significantly?

It degraded significantly across the instruct-eval benchmark.

jzhang38 commented 1 year ago

https://github.com/jzhang38/TinyLlama/pull/85