Hello. Importing the latest packaged version of Jax-Triton (jax-triton==0.1.3) fails with the following error:
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/lowering.py in <module>
33 from jax._src.state import primitives as sp
34 from jax._src.state import discharge
---> 35 from jax._src.state import ShapedArrayRef
36 from jax_triton.triton_lib import get_triton_type
37 import jax.numpy as jnp
ImportError: cannot import name 'ShapedArrayRef' from 'jax._src.state' (/usr/local/lib/python3.10/dist-packages/jax/_src/state/__init__.py)
It appears that ShapedArrayRef is no longer importable from jax._src.state as of Jax version 0.4.12, and that this breaks the package initialization, as far as I can tell from testing.
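For anyone who wants to confirm the same thing in their own environment, here is a minimal diagnostic sketch. The helper names `can_import_name` and `installed_version` are hypothetical, written only for this check. They are not part of Jax or Jax-Triton, and the sketch uses only the standard library.

```python
import importlib
from importlib import metadata


def can_import_name(module_name: str, attr: str) -> bool:
    """Return True if `module_name` imports and exposes `attr` as an attribute.

    Hypothetical helper for diagnosis only. It does not try to import
    `attr` as a submodule, so it is a slightly stricter check than
    `from module_name import attr`.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(module, attr)


def installed_version(dist_name: str):
    """Return the installed version string of a distribution, or None."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Report the versions involved and whether the failing name is present.
    print("jax:", installed_version("jax"))
    print("jax-triton:", installed_version("jax-triton"))
    print(
        "ShapedArrayRef available:",
        can_import_name("jax._src.state", "ShapedArrayRef"),
    )
```

If the last line prints False while jax is installed, the import at line 35 of lowering.py will fail in that environment, which matches the traceback above.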