HPPinata closed this 2 weeks ago
This might still need a bit more work; hold off for now.
@tazlin This should now be fine for a merge. Performance is as expected, memory usage is slightly better than with the present implementation, and conda performance is roughly in line with the docker version. Stability is also improved; I'm seeing a process recovery rate of <1%.
Once support is merged upstream, this will get another minor rework, but that might be months off. The upstream change in question is a newer (and less janky) version of flash_attn.
A bit more testing is required around changing the PyTorch version to 2.5.0 without breaking older setups.
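One way to test that without breaking older setups would be to gate the new behavior on the installed torch version rather than hard-pinning it; the sketch below shows that idea, not what this PR actually does:

```python
# Hedged sketch: keep the existing code path on PyTorch < 2.5.0 so older
# environments continue to work while 2.5.0 is being validated.
from packaging import version

import torch

def torch_is_at_least(minimum: str = "2.5.0") -> bool:
    # torch.__version__ may carry a local tag like "2.5.0+rocm6.2";
    # packaging compares it correctly against the plain release number.
    return version.parse(torch.__version__) >= version.parse(minimum)

if torch_is_at_least("2.5.0"):
    pass  # enable the newer integration
else:
    pass  # keep the current behavior for older setups
```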
Potential improvements:
- `FLASH_ATTENTION_USE_TRITON_ROCM=FALSE` (currently the Triton backend isn't built on old cards no matter what, but the variable has to be set to `TRUE` to avoid errors on compatible GPUs; one way to set it per-GPU is sketched below)
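A minimal sketch of how that per-GPU switch could look, assuming `gcnArchName` is exposed by the ROCm PyTorch build and that the listed gfx architectures are the compatible ones; both the detection method and the architecture list are assumptions, not something specified here:

```python
# Hedged sketch: choose FLASH_ATTENTION_USE_TRITON_ROCM based on the detected
# GPU architecture. The variable must be set before flash_attn is imported.
import os

import torch

# Assumed list of Triton-capable architectures; adjust for your hardware.
TRITON_CAPABLE_ARCHES = ("gfx90a", "gfx942")

def configure_flash_attn_backend() -> None:
    use_triton = False
    # torch.version.hip is None on CUDA builds and set on ROCm builds.
    if torch.cuda.is_available() and torch.version.hip is not None:
        arch = torch.cuda.get_device_properties(0).gcnArchName
        use_triton = any(arch.startswith(a) for a in TRITON_CAPABLE_ARCHES)
    # TRUE on compatible GPUs (required to avoid errors), FALSE everywhere else.
    os.environ["FLASH_ATTENTION_USE_TRITON_ROCM"] = "TRUE" if use_triton else "FALSE"

configure_flash_attn_backend()
```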