Issue opened by egg5154 9 months ago (status: open).
Hi @egg5154, which version of jax_triton and triton do you have?
Hello, jax_triton is 0.1.4 and triton (triton-nightly) is 2.1.0.post20231216005823.
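Because a mismatched jax / jax_triton / triton combination is the main suspect here, it helps to report all of the relevant versions at once. The following is a minimal sketch using the standard library's `importlib.metadata`; the distribution names in `PACKAGES` are assumptions about how the packages were installed (e.g. `jax-triton` from PyPI, `triton-nightly` from the nightly index), so adjust them to your environment.

```python
from importlib import metadata

# Assumed distribution names; adjust to match how each package was installed.
PACKAGES = ["jax", "jaxlib", "jax-triton", "triton", "triton-nightly"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name}: {version or 'not installed'}")
```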
I suspect you might need the main version of jax_triton. There were a number of refactorings in the Triton Python APIs, and the main version should be up to date.
Hello @superbobry, I switched to the main version, but the error still occurs.
The import in jax_triton that causes the error is `from triton._C.libtriton import ir as tl_ir`; using `from triton._C.libtriton.triton import ir as tl_ir` instead bypasses the error.
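Rather than editing jax_triton's source for one specific Triton build, the workaround above can be expressed as an ordered fallback over the two import locations. This is a minimal sketch, not part of jax_triton: the helper `first_importable` and the list `TRITON_IR_CANDIDATES` are hypothetical names, and the two candidate paths are simply the ones reported in this thread. It only papers over where `ir` lives; other Triton APIs that jax_triton uses may still differ between versions.

```python
import importlib

# Hypothetical candidate locations for Triton's `ir` module, tried in order:
# the path jax_triton imports, then the older path reported to work here.
TRITON_IR_CANDIDATES = [
    ("triton._C.libtriton", "ir"),
    ("triton._C.libtriton.triton", "ir"),
]


def first_importable(candidates):
    """Return getattr(import_module(module), attr) for the first pair that works, else None."""
    for module_name, attr in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr)
        except (ImportError, AttributeError):
            # Module missing at this path, or it lacks the attribute: try the next one.
            continue
    return None


# None means neither layout matched the installed Triton (or Triton is not installed).
tl_ir = first_importable(TRITON_IR_CANDIDATES)
```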
I thought the triton version might be the key, so I tried triton 2.0.0, 2.1.0, and 2.2.0, but none of them helped.
Ouch, sorry you have to deal with this. It is indeed quite tricky to find a working jax, jax_triton and triton combination.
If you are open to using Pallas instead of Triton directly, google/jax#19890 changed how Pallas-produced Triton kernels are compiled. We no longer need either jax_triton or triton, as long as you install the nightly jaxlib and jax (starting tomorrow).
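For reference, a Pallas kernel is an ordinary Python function over refs, launched with `pl.pallas_call`; nothing in it imports jax_triton or triton. The sketch below follows the vector-addition pattern from the Pallas quickstart and assumes a jax version that ships `jax.experimental.pallas`. Whether the GPU path works without jax_triton depends on having a jaxlib that includes the change referenced above; `interpret=True` runs the kernel in Pallas' interpreter, which is handy for checking it on CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # Refs point at the kernel's input/output buffers; [...] reads or writes the whole block.
    o_ref[...] = x_ref[...] + y_ref[...]


def add_vectors(x, y, interpret=False):
    # With no grid or BlockSpecs, the whole array is a single block.
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
print(add_vectors(x, y, interpret=True))
```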
Thanks! Actually, I want to use flash attention in jax. It seems that Pallas can be used instead of Triton?
Yeah, you could use Pallas, which would lower to Triton on GPU without using Triton Python APIs.
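Whatever Pallas flash-attention kernel you end up using, the idea it implements is that attention can be computed one key/value block at a time with an "online" softmax (a running row max and normalizer), so the full attention matrix is never materialized. The sketch below illustrates that recurrence in plain `jax.numpy`, not as a Pallas kernel and not as any API bundled with jax; `reference_attention`, `blockwise_attention`, and `block_k` are hypothetical names. It is mainly useful as a correctness check for a real kernel.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    # Unfused scaled dot-product attention: softmax(q k^T / sqrt(d)) v.
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax((q @ k.T) * scale, axis=-1)
    return weights @ v


def blockwise_attention(q, k, v, block_k=16):
    # Flash-attention-style online softmax over key/value blocks (illustration only).
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    m = jnp.full((q.shape[0],), -jnp.inf)          # running max of logits per query row
    l = jnp.zeros((q.shape[0],))                   # running softmax denominator
    acc = jnp.zeros((q.shape[0], v.shape[-1]))     # running unnormalized output
    for start in range(0, k.shape[0], block_k):
        kb = k[start:start + block_k]
        vb = v[start:start + block_k]
        s = (q @ kb.T) * scale                     # logits for this key block
        m_new = jnp.maximum(m, s.max(axis=-1))
        p = jnp.exp(s - m_new[:, None])            # block probabilities, rescaled to m_new
        correction = jnp.exp(m - m_new)            # rescales earlier partial sums (0 on first block)
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ vb
        m = m_new
    return acc / l[:, None]
```

Because the partial sums are rescaled whenever the running max changes, the blockwise result matches the unfused one up to floating-point error, regardless of `block_k`.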
Hello, I am running jax_triton on an A100 with CUDA 12.2, but when I run the command `python -c 'import jax_triton as jt'`, an error occurs:
My jax_triton was installed following https://github.com/google/jax/issues/18603.