jon-chuang closed 2 weeks ago
Thank you for this change! This is really helpful! The main issue is that we mirror Triton internally at Google and haven't updated to the latest Triton yet. JAX-Triton is kept in sync with Google's internal Triton.
When we update Triton (which we should do soon), we can land your PR.
Currently, the Triton dependency is about 1.5 months old.
While one should not attempt to update regularly, I think it's about time to upgrade, both because of breaking API changes and to benefit from Triton improvements (see for instance https://github.com/google/jax/pull/17328#issuecomment-1705010065: about a 25% speedup in flash attention for the given tile size due to fixes on the Triton side).