jax-ml / jax-triton

jax-triton contains integrations between JAX and OpenAI Triton
Apache License 2.0

cannot import name 'ShapedArrayRef' due to deprecation in jax._src.state #271

Open: dtunai opened this issue 8 months ago


Hello. When importing the latest released version of jax-triton (jax-triton==0.1.3), the import fails with the following error:

/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/lowering.py in <module>
     33 from jax._src.state import primitives as sp
     34 from jax._src.state import discharge
---> 35 from jax._src.state import ShapedArrayRef
     36 from jax_triton.triton_lib import get_triton_type
     37 import jax.numpy as jnp

ImportError: cannot import name 'ShapedArrayRef' from 'jax._src.state' (/usr/local/lib/python3.10/dist-packages/jax/_src/state/__init__.py)

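As far as I can tell, 'ShapedArrayRef' is a class (not a method) that was deprecated and then removed from jax._src.state as of JAX 0.4.12. The import in jax_triton/pallas/lowering.py therefore fails, which breaks package initialization.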
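To confirm this diagnosis on a given environment, a small check along the following lines may help. It is a minimal sketch that uses only the Python standard library: it reports the installed jax and jax-triton versions, compares the jax version against 0.4.12 (the version suggested in this report, not a confirmed cutoff), and tests whether jax._src.state still exports ShapedArrayRef. The helper names installed_version, version_tuple and exports_name are hypothetical, written only for this illustration. Note that jax._src is a private module path, so its contents can change between JAX releases without notice.

```python
import importlib
import importlib.metadata
import re


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def version_tuple(version):
    """Parse the leading numeric release segment, e.g. '0.4.12' -> (0, 4, 12).

    Suffixes such as '.dev20230601' or 'rc1' are ignored.
    """
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def exports_name(module_name, name):
    """True if `module_name` imports cleanly and defines `name`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:  # also covers ModuleNotFoundError (e.g. jax not installed)
        return False
    return hasattr(module, name)


if __name__ == "__main__":
    jax_version = installed_version("jax")
    print("jax:", jax_version)
    print("jax-triton:", installed_version("jax-triton"))
    if jax_version is not None:
        # 0.4.12 is the version named in this report (an assumption, not verified).
        print("jax >= 0.4.12:", version_tuple(jax_version) >= (0, 4, 12))
    print("jax._src.state exports ShapedArrayRef:",
          exports_name("jax._src.state", "ShapedArrayRef"))
```

If the last line prints False while jax-triton 0.1.3 is installed, the ImportError above is expected; pinning jax to an earlier release or using a jax-triton build that no longer imports ShapedArrayRef would be the likely ways around it.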