Uses flash attention v2. Works for any sequence length.
Masked attention uses PyTorch's built-in flash attention v2.
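A minimal sketch of routing masked attention through PyTorch's FlashAttention backend (shapes and the causal mask are illustrative; the actual call sites in this repo may differ):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Force the FlashAttention backend; this errors instead of silently
# falling back if the inputs are unsupported.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```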
A custom Triton-based flash attention v2 kernel handles attention with bias. Block sizes are currently chosen by the autotuner, which might need improvements to avoid recompilation (see the sketch below).
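A sketch of how the Triton autotuner picks block sizes; restricting the config list and the `key` is one way to limit recompilation. Kernel and parameter names here are illustrative, not the repo's actual kernel:

```python
import triton
import triton.language as tl

# Illustrative only: a small config list and a key that excludes seq_len
# means the kernel is not re-autotuned/recompiled for every new length.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=8, num_stages=2),
    ],
    key=["HEAD_DIM"],  # recompile only when head_dim changes, not seq_len
)
@triton.jit
def _attn_bias_fwd(Q, K, V, Bias, Out, seq_len,
                   HEAD_DIM: tl.constexpr,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Kernel body omitted; this only illustrates the autotuning setup.
    pass
```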
TF32 acceleration is disabled by default; the code might need to check the compute architecture and enable it selectively.
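A minimal sketch of such a check, assuming the usual PyTorch TF32 switches (TF32 is only available on Ampere, compute capability 8.0, and newer):

```python
import torch

# Enable TF32 matmuls/convolutions only on GPUs that actually support it.
if torch.cuda.is_available():
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
```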
WARNING! Please install the custom fork of trident provided below as a zip file. You might need to install triton-nightly afterwards (and possibly nvtx) to get it working (see the triton repo for instructions). Check the install with ipython: `import trident as td`.
trident.zip
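A quick sanity check after installation (a sketch: the `__version__` attribute may not exist in the fork, so it is read defensively):

```python
# Run inside ipython/python to confirm the fork imports cleanly.
import triton
import trident as td

print("triton:", triton.__version__)
print("trident:", getattr(td, "__version__", "unknown"))
```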