
support for llama multipack using updated code/patches #1754

Closed by winglian 1 month ago

winglian commented 1 month ago

The attention monkey patch we have for llama is pretty old at this point, and maintaining it is a pain. This swaps to the updated unpad patch for flash attention, with a slight refactor so we can keep supporting the cross entropy loss and RMSNorm patches.
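For context, here's a minimal sketch of what that combination can look like (not the PR's actual code): flash-attn's `bert_padding` helpers strip the padding, the varlen kernel runs over the concatenated tokens, and the fused cross entropy and RMSNorm kernels get rebound into transformers' llama module. The `flash_attn` import paths are from flash-attn 2.x, and `unpadded_attention` / `patch_llama_fused_kernels` are made-up names for illustration; with multipack, the real patch also has to derive `cu_seqlens` from the packed sample boundaries rather than a plain padding mask.

```python
# Hedged sketch of the approach described above, NOT the PR's actual code.
import transformers.models.llama.modeling_llama as modeling_llama
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.losses.cross_entropy import CrossEntropyLoss  # fused CE loss
from flash_attn.ops.rms_norm import RMSNorm  # fused CUDA RMSNorm


def unpadded_attention(q, k, v, attention_mask, causal=True):
    """q, k, v: (batch, seqlen, nheads, headdim) in fp16/bf16 on CUDA;
    attention_mask: (batch, seqlen), nonzero marks real tokens.

    Shown for the plain padded case; for multipack, cu_seqlens must instead
    mark packed-sample boundaries so attention can't cross between samples.
    """
    batch, seqlen = q.shape[:2]
    # unpad_input returns 4 values in older flash-attn releases and 5 in
    # newer ones (an extra seqused); slicing keeps this version-agnostic.
    q_u, indices, cu_seqlens, max_s = unpad_input(q, attention_mask)[:4]
    # Same mask -> same indices, so only the unpadded tensors are needed here.
    k_u = unpad_input(k, attention_mask)[0]
    v_u = unpad_input(v, attention_mask)[0]
    out_u = flash_attn_varlen_func(
        q_u, k_u, v_u,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_s, max_seqlen_k=max_s,
        causal=causal,
    )
    # Scatter the unpadded outputs back to (batch, seqlen, nheads, headdim).
    return pad_input(out_u, indices, batch, seqlen)


def patch_llama_fused_kernels():
    # flash-attn's RMSNorm shares LlamaRMSNorm's (hidden_size, eps) signature,
    # so rebinding the module attribute is enough; the CUDA RMSNorm also
    # requires flash-attn's optional dropout_layer_norm extension.
    modeling_llama.LlamaRMSNorm = RMSNorm
    # Older transformers look CrossEntropyLoss up in the module at forward
    # time, so this routes the LM loss through the fused kernel.
    modeling_llama.CrossEntropyLoss = CrossEntropyLoss
```

`patch_llama_fused_kernels()` would have to run before the model is instantiated so the rebound classes are actually picked up.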

[Screenshot 2024-07-15 8:47 AM: throughput, VRAM, loss, and grad norm comparison between the old patch and the updated unpad patch]

As we can see, the updated patch is slightly faster, uses about 1.4% less VRAM, and has pretty similar loss and grad norm characteristics.

I also tried the updated Triton RMSNorm in place of the CUDA RMSNorm implementation from flash-attn, but it made things slightly worse.

[Screenshot 2024-07-15 8:57 AM: the same comparison with the Triton RMSNorm swapped in]
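For reference, that Triton-for-CUDA swap amounts to something like the sketch below. The `flash_attn.ops.triton.layer_norm` path is from recent flash-attn 2.x releases and has moved between versions, so treat the import as an assumption.

```python
# Hedged sketch of swapping in the Triton RMSNorm; the import path varies
# across flash-attn versions.
import transformers.models.llama.modeling_llama as modeling_llama
from flash_attn.ops.triton.layer_norm import RMSNorm as TritonRMSNorm

# Same module-attribute rebinding as the CUDA version, just a different kernel.
modeling_llama.LlamaRMSNorm = TritonRMSNorm
```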