unslothai / unsloth

Finetune Llama 3, Mistral, Phi & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

Possible CrossEntropy optimization #569

Open dragosconst opened 1 month ago

dragosconst commented 1 month ago

I have noticed that the CE bwd kernel loads the elements of the dloss tensor from HBM into on-chip SM memory. From some experiments, it seems that the dloss tensor is always filled with the scaling factor used to reduce the losses. For example, assuming no ignored tokens, it would be a tensor filled with 1/seq_len.
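For reference, here is a quick plain-PyTorch check (not the Triton path, and the shapes are just illustrative) of what the incoming dloss looks like under a plain mean reduction:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 32000, requires_grad=True)   # (seq_len, vocab)
labels = torch.randint(0, 32000, (8,))

losses = F.cross_entropy(logits, labels, reduction="none")  # per-token losses
loss = losses.mean()

# The gradient flowing back into each per-token loss is the constant 1/seq_len;
# this is exactly the dloss tensor the bwd kernel would otherwise load from HBM.
dloss = torch.autograd.grad(loss, losses)[0]
print(dloss)   # shape (8,), every element 0.125
```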

My team and I are using a custom version of the bwd kernel where we just pass the scaling constant as a kernel argument and avoid loading the dloss tensor elements. In our case we scale with something like 1/(non_ignored_tokens * acc_steps), but regardless it passes all of our numerical correctness tests.
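Roughly, the change looks like the sketch below. This is not Unsloth's actual kernel, just a simplified Triton kernel assuming the forward pass saved a per-row logsumexp, gradients are written in place over the logits, and one block covers the vocab; the point is only that the scaling arrives as a scalar `scale` argument instead of a per-row `tl.load` from a dloss tensor:

```python
import triton
import triton.language as tl

@triton.jit
def _ce_bwd_scalar_scale(
    logits_ptr, logits_stride,
    logsumexp_ptr, labels_ptr,
    scale,                        # scalar, e.g. 1.0 / (non_ignored_tokens * acc_steps)
    VOCAB_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,     # assumed >= VOCAB_SIZE (one block per row)
):
    # One program per token row; gradients overwrite the logits in place.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < VOCAB_SIZE

    logits_ptr += row * logits_stride
    label = tl.load(labels_ptr + row)
    lse = tl.load(logsumexp_ptr + row)

    x = tl.load(logits_ptr + cols, mask=mask, other=-float("inf")).to(tl.float32)
    # d(loss_row)/d(logits) = softmax(x) - onehot(label)
    grad = tl.exp(x - lse) - tl.where(cols == label, 1.0, 0.0)
    # Ignored tokens (label == -100) get zero gradient; everything else is
    # multiplied by the constant scale instead of a loaded dloss element.
    grad = tl.where(label == -100, 0.0, grad * scale)

    tl.store(logits_ptr + cols, grad, mask=mask)
```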

I was wondering if there is ever a situation where loading the dloss elements makes a difference. I suppose someone could use custom per-token weighting, although I'm not familiar with any technique that does so (a hypothetical example is sketched below). In the current repo code the default behavior seems to be to average over all the non-ignored tokens, so in that case dloss should be a tensor filled with the same value everywhere.
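To be concrete about that case, here is a hypothetical per-token-weighting setup (the weights are made up) where dloss would no longer be constant, so a scalar-only kernel would not apply:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 32000, requires_grad=True)
labels = torch.randint(0, 32000, (8,))
weights = torch.linspace(0.5, 1.5, 8)            # illustrative per-token weights

losses = F.cross_entropy(logits, labels, reduction="none")
loss = (weights * losses).sum() / weights.sum()

dloss = torch.autograd.grad(loss, losses)[0]     # varies per token
print(dloss)
```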

danielhanchen commented 1 month ago

Oh interesting! Let me check and get back to you! Thanks for the idea!!