pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Mask eos token for packed dataset #1177

Open iankur opened 4 months ago

iankur commented 4 months ago

Hi, my current understanding of training with a packed dataset is that, if two sequences are packed together, we will compute a loss at the position where the EOS token of the first sequence is the input and the BOS token of the second sequence is the target. In the recipes, I can only find masking for the padding tokens at the end of the pack, not for these in-between tokens. Is this correct?

RdoubleA commented 4 months ago

@iankur Great question. PackedDataset is a wrapper class around the underlying dataset, so the labels it packs are entirely determined by the dataset class you are using. All of our dataset classes determine labels for loss based on the mask, which is created by the tokenizer's tokenize_messages method.

If we look at Phi3 as an example, both BOS and EOS are masked out of the loss depending on whether the Message/sequence as a whole is masked out. Generally, user messages are masked out and don't contribute to the loss, because you only want the model to learn to predict the assistant messages. So BOS tokens from a user message are masked, but the EOS after an assistant message is not.

Packed datasets just take these same rules and concatenate all the tokens/labels together in a single pack. They additionally add masking for the padding tokens, as you pointed out.
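
For concreteness, here is a minimal sketch of that labeling rule (the token ids and the build_labels helper are made up for illustration, not torchtune's actual internals, though the standard -100 cross-entropy ignore index is used):

```python
# Illustrative sketch only: deriving labels from a per-token mask.
CROSS_ENTROPY_IGNORE_IDX = -100  # positions with this label are skipped by the loss

def build_labels(tokens, mask):
    """tokens: list[int]; mask: list[bool], True means 'exclude from loss'."""
    return [
        CROSS_ENTROPY_IGNORE_IDX if masked else tok
        for tok, masked in zip(tokens, mask)
    ]

# A masked user turn followed by an unmasked assistant turn.
tokens = [1, 220, 330, 440, 2]              # <s>, user ..., assistant ..., </s>
mask   = [True, True, True, False, False]   # BOS + user masked, assistant + EOS kept
print(build_labels(tokens, mask))           # [-100, -100, -100, 440, 2]
```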

iankur commented 4 months ago

@RdoubleA I was looking into the text completion dataset, which does not have the mask you mention here:

All of our dataset classes determine labels for loss based on the mask

Otherwise, I agree with what you said and was hoping that would be the case for all datasets.

RdoubleA commented 4 months ago

For text completion / continued pre-training, there's no concept of "user" and "assistant" or "input" and "output"; there is just a document whose semantics you want the model to learn. So the labels will be a direct copy of the input tokens, offset by one, so that the model learns to predict the next token.

input: <s> This is  a   cool  fact  about a   cat from   a  textbook about cats. </s>
        |   |    |  |    |     |     |    |    |   |     |      |      |    |
pred:  This is   a cool fact  about  a    cat from a textbook about  cats  </s>

None of these tokens should be masked out.
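
In code terms, that amounts to the usual shift-by-one cross entropy with no ignore index anywhere (a generic sketch, not torchtune-specific; the ids and vocab size are arbitrary):

```python
import torch
import torch.nn.functional as F

# Text completion / continued pre-training: labels are the inputs shifted by one.
tokens = torch.tensor([[1, 910, 338, 263, 4266, 2]])  # <s> ... </s>, arbitrary ids
logits = torch.randn(1, tokens.size(1), 32000)        # stand-in for model output

# Position i predicts token i + 1; none of the positions are masked out.
shift_logits = logits[:, :-1, :]
shift_labels = tokens[:, 1:]
loss = F.cross_entropy(
    shift_logits.reshape(-1, shift_logits.size(-1)),
    shift_labels.reshape(-1),
)
```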

iankur commented 4 months ago

My question was about the </s> token at the end of the first sequence, right before the start of the next sequence, when we use packing with the text completion dataset. In this case, we will have something like

input: <s> This is  a   cool  fact  about a   cat from   a  textbook about cats. </s>  <s> This is ...
        |   |    |  |    |     |     |    |    |   |     |      |      |    |      |    |  
pred:  This is   a cool fact  about  a    cat from a textbook about  cats  </s>   <s> This  ...

and there is no way to mask this token or the loss on it afterwards with the current code. I am not sure what the desired behavior should be, but it can affect other things such as activation statistics. Please feel free to close the issue if this is not relevant.

RdoubleA commented 4 months ago

Ah I see, so if I understand correctly, you want to mask out the in-between </s> tokens for packed sequences, but not the <s> tokens?

You're right that it can impact activations and model output, but I would still say you want to keep the loss for those tokens. The model needs to learn when a document ends, even within a packed sequence, so the data should keep those boundaries. But let me know if you have a slightly different use case or if I'm still missing something.
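
So for a packed text-completion sample the behavior would look roughly like this (illustrative ids only, not torchtune internals; the shift-by-one happens later at loss time):

```python
# Rough sketch: keep loss on the in-pack <s>/</s> boundary tokens, mask only
# the trailing padding. Token ids here are made up for illustration.
CROSS_ENTROPY_IGNORE_IDX = -100
PAD_ID = 0

# Two short documents packed together, then padded to a fixed length of 12.
packed_tokens = [1, 910, 338, 2, 1, 263, 7826, 2, 0, 0, 0, 0]
packed_labels = [
    CROSS_ENTROPY_IGNORE_IDX if tok == PAD_ID else tok for tok in packed_tokens
]
# -> [1, 910, 338, 2, 1, 263, 7826, 2, -100, -100, -100, -100]
# The boundary </s> and the following <s> keep their labels, so the model still
# learns where documents end inside a pack.
```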