jax-ml / jax-triton

jax-triton contains integrations between JAX and OpenAI Triton
Apache License 2.0

Import error encountered in jax_triton #264

Open egg5154 opened 6 months ago

egg5154 commented 6 months ago

Hello, I am running jax_triton on an A100 with CUDA 12.2, but when I run the command python -c 'import jax_triton as jt', the following error occurs:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/lustre/grp/gyqlab/liyh/debugs/jax-triton/jax_triton/__init__.py", line 19, in <module>
    from jax_triton.triton_lib import triton_call
  File "/lustre/grp/gyqlab/liyh/debugs/jax-triton/jax_triton/triton_lib.py", line 50, in <module>
    from triton._C.libtriton import ir as tl_ir
ImportError: cannot import name 'ir' from 'triton._C.libtriton' (/lustre/grp/gyqlab/liyh/anaconda3/envs/jax_triton3/lib/python3.10/site-packages/triton/_C/libtriton.so)

I installed jax_triton following https://github.com/google/jax/issues/18603.

superbobry commented 6 months ago

Hi @egg5154, which version of jax_triton and triton do you have?

egg5154 commented 6 months ago

Hello, jax_triton is 0.1.4 and triton (triton-nightly) is 2.1.0.post20231216005823.
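Since finding a working combination depends on exact versions, a small script along these lines can collect them in one go. It is a sketch using only the standard library's importlib.metadata; installed_versions is a hypothetical helper, and the distribution names listed (in particular triton-nightly) are assumptions that may not match what your environment actually installed.

```python
from importlib import metadata


def installed_versions(dists):
    """Map each distribution name to its installed version, or None if it is absent."""
    versions = {}
    for name in dists:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed distribution names; adjust to whatever pip reports in your environment.
    names = ["jax", "jaxlib", "jax-triton", "triton", "triton-nightly"]
    for name, ver in installed_versions(names).items():
        print(f"{name}: {ver or 'not installed'}")
```

Pasting the printed lines into an issue answers the "which version do you have?" question for every relevant package at once.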

superbobry commented 6 months ago

I suspect you might need the main version of jax_triton. There have been a number of refactorings in the Triton Python APIs, and the main version should be up to date with them.

egg5154 commented 6 months ago

Hello @superbobry, I switched to the main version, but the error still occurs. The import that fails in jax_triton is from triton._C.libtriton import ir as tl_ir; using from triton._C.libtriton.triton import ir as tl_ir instead bypasses the error. I thought Triton's version might be the key, so I tried Triton 2.0.0, 2.1.0, and 2.2.0, but none of them helped.
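The workaround described here (falling back from triton._C.libtriton to triton._C.libtriton.triton) can be written as a helper that tries several module paths in order. This is a hypothetical sketch: import_first is not part of jax_triton or triton, and which path actually exposes ir depends on the Triton build you have installed.

```python
import importlib


def import_first(candidates, attr):
    """Return `attr` from the first module in `candidates` that provides it.

    Tries each dotted module path in order; raises ImportError listing every
    attempt if none of them both imports and has the attribute.
    """
    errors = []
    for module_name in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr)
        except (ImportError, AttributeError) as exc:
            errors.append(f"{module_name}: {exc}")
    raise ImportError(f"could not import {attr!r}; tried:\n  " + "\n  ".join(errors))


# Hypothetical use for the case in this thread (requires triton to be installed):
# tl_ir = import_first(["triton._C.libtriton", "triton._C.libtriton.triton"], "ir")
```

Collecting every failure message makes it easier to see which layout the installed Triton uses when neither path works.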

superbobry commented 6 months ago

Ouch, sorry you have to deal with this. It is indeed quite tricky to find a working jax, jax_triton and triton combination.

If you are open to using Pallas instead of Triton directly, google/jax#19890 changed how Pallas-produced Triton kernels are compiled: Pallas no longer needs jax_triton or triton, as long as you install the nightly jaxlib and jax (starting tomorrow).
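For reference, a minimal Pallas kernel looks roughly like the vector-add example from the Pallas quickstart. This sketch assumes a JAX version that ships jax.experimental.pallas; interpret=True is set so it also runs on CPU, and on a GPU you would drop it so the kernel is compiled for the device (via Triton, per the comment above) with no jax_triton or triton install.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
    # With no grid or BlockSpecs, each ref covers the whole input/output array.
    o_ref[...] = x_ref[...] + y_ref[...]


@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        # Interpret mode runs the kernel on CPU; remove it to compile for the GPU.
        interpret=True,
    )(x, y)


print(add_vectors(jnp.arange(8), jnp.arange(8)))
```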

egg5154 commented 6 months ago

Thanks! Actually, I want to use flash attention in JAX. Does that mean Pallas can be used instead of Triton?

superbobry commented 6 months ago

Yeah, you could use Pallas, which would lower to Triton on GPU without using Triton Python APIs.
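To see what a flash-attention kernel computes, the sketch below implements the blockwise "online softmax" recurrence in plain jax.numpy and checks it against ordinary softmax attention. It is an illustration only, not the Pallas implementation: the function names and the block_k parameter are hypothetical, and a real Pallas kernel would also tile the queries and keep each block in on-chip memory.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    """Plain softmax attention: softmax(q k^T / sqrt(d)) v."""
    scores = (q @ k.T) / (q.shape[-1] ** 0.5)
    return jax.nn.softmax(scores, axis=-1) @ v


def blockwise_attention(q, k, v, block_k=4):
    """Same result, computed over key blocks with a running max and running sum."""
    scale = 1.0 / (q.shape[-1] ** 0.5)
    n_q = q.shape[0]
    m = jnp.full((n_q,), -jnp.inf)         # running row-wise max of scores
    l = jnp.zeros((n_q,))                  # running softmax denominator
    acc = jnp.zeros((n_q, v.shape[-1]))    # running unnormalized output
    for start in range(0, k.shape[0], block_k):
        kb = k[start:start + block_k]
        vb = v[start:start + block_k]
        s = (q @ kb.T) * scale
        m_new = jnp.maximum(m, s.max(axis=-1))
        p = jnp.exp(s - m_new[:, None])
        # Rescale previous partial results to the new max (exp(-inf) = 0 on the first block).
        correction = jnp.exp(m - m_new)
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ vb
        m = m_new
    return acc / l[:, None]


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (6, 8))
k = jax.random.normal(kk, (10, 8))
v = jax.random.normal(kv, (10, 4))
print(jnp.max(jnp.abs(blockwise_attention(q, k, v) - reference_attention(q, k, v))))
```

Because only one key block is materialized at a time, memory use no longer grows with the full score matrix, which is the property a Pallas flash-attention kernel exploits on the GPU.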