The attention monkey patch we have for llama is pretty old at this point, and maintaining it is a pain. This swaps to the updated unpad patch for flash attention, with a slight refactor to keep supporting the cross entropy loss and RMSNorm patches.
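For reference, the general shape of this kind of patch is module-level reassignment at import time. The sketch below is illustrative only: `flash_attn_forward` is a hypothetical stand-in for the unpad-based forward (not the code in this change), and it assumes flash-attn's fused RMSNorm and cross entropy ops are installed.

```python
import transformers.models.llama.modeling_llama as llama_modeling

def flash_attn_forward(self, hidden_states, attention_mask=None, **kwargs):
    # Hypothetical stand-in for the unpad-based forward: unpad the batch
    # (flash_attn.bert_padding.unpad_input), run flash-attn's varlen kernel
    # over the packed tokens, then re-pad the output (pad_input).
    raise NotImplementedError

def apply_llama_patches(flash_attention=True, fused_rms_norm=True, fused_cross_entropy=True):
    """Monkey patch the HF llama module in place; call before building the model."""
    if flash_attention:
        llama_modeling.LlamaAttention.forward = flash_attn_forward
    if fused_rms_norm:
        # Assumes flash-attn's fused CUDA RMSNorm extension is available.
        from flash_attn.ops.rms_norm import RMSNorm
        llama_modeling.LlamaRMSNorm = RMSNorm
    if fused_cross_entropy:
        # modeling_llama imports CrossEntropyLoss at module level, so
        # rebinding that name swaps the loss used in LlamaForCausalLM.
        from flash_attn.losses.cross_entropy import CrossEntropyLoss
        llama_modeling.CrossEntropyLoss = CrossEntropyLoss
```

Keeping the RMSNorm and cross entropy swaps behind separate flags is what lets those patches survive the attention refactor independently.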
As we can see, it's slightly faster, uses about 1.4% less VRAM, and has very similar loss and grad norm characteristics.
I also tried the updated Triton RMSNorm in place of flash-attn's CUDA RMSNorm implementation, but it made things slightly worse.
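For context, both kernels compute the same function, so swapping backends is a drop-in module replacement either way. This pure-PyTorch reference (matching HF's `LlamaRMSNorm` semantics) is what either fused kernel replaces:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Reference RMSNorm matching LlamaRMSNorm semantics:
    y = x / sqrt(mean(x^2) + eps) * weight, with stats computed in fp32."""
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)
```

Since the numerics are equivalent in principle, sticking with the CUDA kernel here is just the empirically better choice.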