meta-llama / llama3

The official Meta Llama 3 GitHub site

Question about tokenizer #42

Open · odegeasslbc opened this issue 2 months ago

odegeasslbc commented 2 months ago

Hi guys, thanks for open-sourcing this great work! It seems Llama 3 is using "right" padding and the "eos_token" as the "padding_token". Could you help verify what the padding side and padding token should be if I want to train this model? This differs from many other LLMs: Gemma, for example, uses "left" padding and has a dedicated padding_token, as do many other models. So I just want to double-check that I'm using the correct config when training/fine-tuning Llama 3.

ruanslv commented 2 months ago

It seems Llama 3 is using "right" padding and the "eos_token" as the "padding_token".

Please do not mix up this line of inference code with any kind of training setup: https://github.com/meta-llama/llama3/blob/main/llama/generation.py#L155

That padding is just a trick so the inference code can more easily handle variable sequence lengths in the input batch.

We are using pad_id = -1 there, not "eos". And given all the filtering we do before actually calling model.forward (e.g. tokens[:, prev_pos:cur_pos]), the model never actually sees any token = -1 (if it did, the embedding lookup would fail, because -1 is not a valid token id).
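For intuition, here is a minimal sketch of that trick (simplified names and toy values, not the actual generation.py code): variable-length prompts are copied into a rectangular buffer pre-filled with pad_id = -1, and only already-filled slices are ever handed to the model.

```python
import torch

pad_id = -1                                  # placeholder only, never embedded
prompts = [[10, 11, 12], [20, 21]]           # toy variable-length prompt token ids
max_gen_len = 4
total_len = max(len(p) for p in prompts) + max_gen_len

# One rectangular buffer for the whole batch, pre-filled with the pad id
tokens = torch.full((len(prompts), total_len), pad_id, dtype=torch.long)
for i, p in enumerate(prompts):
    tokens[i, : len(p)] = torch.tensor(p, dtype=torch.long)

# Generation starts at the shortest prompt length, so the slice passed to the
# model (tokens[:, prev_pos:cur_pos]) never contains a -1 entry.
prev_pos, cur_pos = 0, min(len(p) for p in prompts)
model_input = tokens[:, prev_pos:cur_pos]
assert (model_input != pad_id).all()
```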

I just want to double-check that I'm using the correct config when training/fine-tuning Llama 3

This depends on the fine-tuning library you are using. In our Llama 3 fine-tuning, the pad token and the padding direction (right or left) are irrelevant, because we use masks to compute the loss. You cannot use -1 (it would cause an embedding-lookup error), but any valid token id such as pad_id = 0 works fine. We then only consider loss values for the tokens we care about: you exclude pad tokens, and if you are doing supervised fine-tuning you may also want to exclude the prompt tokens and train only on the answer tokens.

So say you have 3 prompt tokens, 4 answer tokens, and 2 pad tokens (right padding). Then you can create a mask:

mask = [False, False, False, True, True, True, True, False, False]

So for example, if doing cross-entropy loss, you can use the mask to filter out tokens you don't care about (not sure if the shapes all match but just to illustrate):

import torch.nn.functional as F

token_loss = F.cross_entropy(prediction, target, reduction="none")  # per-token loss
token_loss = token_loss * mask  # zero out prompt/pad positions
# and then do the loss reduction you care about
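For completeness, here is a self-contained version of the same idea with toy lengths and random stand-in logits (just to show the mechanics end to end, not the actual fine-tuning code):

```python
import torch
import torch.nn.functional as F

vocab_size = 32                                    # toy value
n_prompt, n_answer, n_pad = 3, 4, 2
seq_len = n_prompt + n_answer + n_pad

# True only on the answer tokens we want to train on
mask = torch.tensor([False] * n_prompt + [True] * n_answer + [False] * n_pad)

prediction = torch.randn(seq_len, vocab_size)      # stand-in for model logits
target = torch.randint(0, vocab_size, (seq_len,))  # stand-in for label token ids

token_loss = F.cross_entropy(prediction, target, reduction="none")
token_loss = token_loss * mask                     # zero out prompt/pad positions
loss = token_loss.sum() / mask.sum()               # mean over answer tokens only
```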

The torchtune library has similar logic: https://github.com/pytorch/torchtune/blob/a9180b537186b7484cd05969f35f05f28ae2c622/torchtune/data/_common.py#L7
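That torchtune constant is an ignore index for cross-entropy. An equivalent way to get the same masking, without the explicit multiply, is to overwrite the labels at excluded positions with the ignore value (PyTorch's F.cross_entropy uses ignore_index=-100 by default). This is a sketch of that pattern with toy values, not torchtune's actual code:

```python
import torch
import torch.nn.functional as F

IGNORE_IDX = -100                                   # PyTorch's default ignore_index

target = torch.tensor([5, 9, 2, 7, 1, 4, 8, 0, 0])  # toy label ids
mask = torch.tensor([False, False, False, True, True, True, True, False, False])

# Labels at masked-out positions become the ignore index, so cross_entropy
# skips them when averaging the loss.
labels = torch.where(mask, target, torch.full_like(target, IGNORE_IDX))
prediction = torch.randn(len(labels), 32)           # stand-in logits, toy vocab of 32
loss = F.cross_entropy(prediction, labels, ignore_index=IGNORE_IDX)
```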

Hope this helps!

odegeasslbc commented 2 months ago

Thanks Ruan for your answer!