axolotl-ai-cloud / axolotl


Loss Scaling by Unmasked Token Count #508

Closed. grimulkan closed this issue 11 months ago.

grimulkan commented 1 year ago

⚠️ Please check that this feature request hasn't been suggested before.

πŸ”– Feature description

I have a data set with 2 types of data items: Type 1 items contain only a few Q&A entries (and thus few unmasked tokens), while Type 2 items contain many Q&A entries (and thus many more unmasked tokens).

train_on_inputs = false, so the Qs don't contribute to training.
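For concreteness, here is a minimal sketch of what that masking means, assuming HF-style causal-LM labels where -100 marks tokens excluded from the loss (the token ids below are made up):

```python
# Hypothetical Q/A token ids, just for illustration.
prompt_ids = [101, 2054, 2003]   # Q tokens: masked out of the loss
answer_ids = [3437, 1012]        # A tokens: trained on

input_ids = prompt_ids + answer_ids
# With train_on_inputs = false, prompt positions get label -100 so the
# cross-entropy loss ignores them; only answer tokens contribute.
labels = [-100] * len(prompt_ids) + answer_ids
```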

I would like Type 2 items to contribute more to the loss/gradient computation, because they have more training tokens (or equivalently, more Q&A entries), irrespective of the batching.

This is somewhat random if I understand the default behavior correctly. E.g., I may get a batch with all Type 1 entries and another batch with all Type 2 entries, and both batches contribute equally to the gradient update (rather than the 2nd batch contributing more). On the other hand, if I get a batch with a mix of both types of entries, then the higher token count of Type 2 will correctly give it more contribution to the loss (I think).

If I had a giant batch, it would likely smooth this variation out statistically, but I don't have the VRAM for that. I don't think gradient accumulation normalizes the combined update once at the end of the accumulation window (each mini-batch is still normalized on its own, right?).
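To make the batching effect concrete, here is a small numeric sketch (not axolotl code; the token counts and per-token losses are made up) comparing the default per-mini-batch normalization against a single normalization over all unmasked tokens:

```python
# (unmasked_token_count, summed_token_loss) for two accumulated mini-batches.
batch_a = (10, 10 * 2.0)       # all Type 1 items: 10 answer tokens, mean loss 2.0
batch_b = (1000, 1000 * 1.0)   # all Type 2 items: 1000 answer tokens, mean loss 1.0

# Default behavior: each mini-batch is normalized by its own token count,
# then the per-batch means are averaged, so both batches count equally.
per_batch_means = [s / n for n, s in (batch_a, batch_b)]
default_loss = sum(per_batch_means) / len(per_batch_means)    # 1.5

# Token-weighted behavior: normalize once over all unmasked tokens,
# so the token-rich (Type 2) batch dominates.
total_tokens = batch_a[0] + batch_b[0]
weighted_loss = (batch_a[1] + batch_b[1]) / total_tokens      # ~1.01

print(default_loss, weighted_loss)
```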

I am not sure how this can be done or the best way to handle this, or if it should even be done, so any suggestions/discussion welcome!

I don't know if this is even the correct way to compute the loss for this situation, but it feels like it should be. Happy to be corrected.

βœ”οΈ Solution

I think the trainer_weighted_loss() function in trainer.py does the opposite of what I want, in that it would normalize the contributions from Type 1 and Type 2 entries by the number of Q&A entries (which is roughly equivalent to normalizing by unmasked token length in my case).

This would remove the randomness from the batching, but it corrects for it in the other direction: I'd like Type 2 to contribute more than Type 1, not the same, no matter what it is batched with. Also, the weighted loss feature seems to be currently unused (it was added in PR #, possibly related to sample_packing, but is not currently utilized?).

This function currently does (I think): loss scale ∝ 1 / (# of Q&A entries), i.e., it equalizes each item's contribution regardless of how many entries it contains.

If implemented this way instead: loss scale = (# of Q&A entries) / max_entries, for some pre-defined number max_entries, then Type 2 entries would contribute proportionally more to the loss than Type 1 entries, regardless of how they are batched.

I am not sure if this messes anything up, since the Type 1 entries will have a potentially much smaller loss, and there is this extra max_entries hyperparameter, which is annoying.
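A rough sketch of that scaling, assuming a per-example, per-token loss with HF-style -100 masking; scaled_example_loss, num_entries, and max_entries are hypothetical names, not existing axolotl options:

```python
import torch

def scaled_example_loss(per_token_loss: torch.Tensor,
                        labels: torch.Tensor,
                        num_entries: int,
                        max_entries: int = 16) -> torch.Tensor:
    """Scale one example's mean loss by (# of Q&A entries) / max_entries."""
    mask = labels != -100                    # only unmasked (answer) tokens count
    mean_loss = per_token_loss[mask].mean()  # the usual per-example loss
    # Items with more Q&A entries (Type 2) are scaled up; Type 1 is scaled down.
    return mean_loss * (num_entries / max_entries)
```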

A completely different approach would be to not normalize the summed loss of each mini-batch on its own, and only normalize once at the gradient accumulation stage. With enough gradient accumulation steps, statistics would do the work for us. That said, I could be wrong in my understanding of how gradient accumulation works...
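A sketch of that alternative in a plain PyTorch training loop (not axolotl internals; model, optimizer, and micro_batches are placeholders): each micro-batch contributes a summed token loss, and the single normalizer is the unmasked-token count of the whole accumulation window, which is known from the labels before any forward pass.

```python
import torch
import torch.nn.functional as F

def accumulation_step(model, optimizer, micro_batches):
    # Total number of answer (unmasked) tokens across the whole window.
    total_tokens = sum((labels[:, 1:] != -100).sum() for _, labels in micro_batches)
    for input_ids, labels in micro_batches:
        logits = model(input_ids).logits
        # Summed (not averaged) cross-entropy over this micro-batch, with the
        # causal shift applied; -100 (masked prompt) labels are ignored.
        loss = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),
            labels[:, 1:].reshape(-1),
            ignore_index=-100,
            reduction="sum",
        ) / total_tokens          # normalize once, over the whole window
        loss.backward()           # gradients accumulate across micro-batches
    optimizer.step()
    optimizer.zero_grad()
```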

❓ Alternatives

πŸ“ Additional Context

No response

Acknowledgements

grimulkan commented 1 year ago

I realized it can be done by modifying the loss function. Let me do some testing and see if any actual feature support is needed.

grimulkan commented 11 months ago

Forgot about this, so returning to close it.

Turns out I am dumb: the Llama implementation already scales its cross-entropy loss by the # of (unmasked) tokens, as opposed to some other implementations that average it out over each sequence's length.
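For anyone landing here later, the distinction sketched with plain PyTorch (an illustration, not the transformers source): a single mean over all unmasked tokens in the batch already weights token-rich examples more, whereas averaging each sequence first and then averaging those means would equalize them.

```python
import torch
import torch.nn.functional as F

def batchwide_token_mean(logits, labels):
    # One mean over every unmasked token in the batch: long answers weigh more.
    return F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        labels[:, 1:].reshape(-1),
        ignore_index=-100,
        reduction="mean",
    )

def per_sequence_mean(logits, labels):
    # Mean per sequence, then mean of those means: every example weighs the same.
    # (Assumes each sequence has at least one unmasked token.)
    per_tok = F.cross_entropy(
        logits[:, :-1].transpose(1, 2),  # (batch, vocab, seq) layout for reduction="none"
        labels[:, 1:],
        ignore_index=-100,
        reduction="none",
    )
    mask = (labels[:, 1:] != -100).float()
    return ((per_tok * mask).sum(dim=1) / mask.sum(dim=1)).mean()
```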

winglian commented 11 months ago

@Grimulkan thanks for updating!

grimulkan commented 10 months ago

Just wanted to point to further (past) discussion: https://github.com/huggingface/transformers/issues/24725. It's still an upstream issue, IMO.