Open pseudo-rnd-thoughts opened 3 months ago
Describe the bug
jax.interpreters.xla.backend_specific_translations is deprecated in jax v0.4.29 (see https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-29-june-10-2024).
This causes the following error when running in xla mode:

    AttributeError: jax.interpreters.xla.backend_specific_translations is deprecated. Register custom primitives via jax.interpreters.mlir instead.
To Reproduce

    import envpool
    env = envpool.make("Breakout-v5")
    env.xla()

        handle, recv, send, step = env.xla()
      File "/lib/python3.10/site-packages/envpool/python/lax.py", line 30, in xla
        _handle, _recv, _send = make_xla(self)
      File "/lib/python3.10/site-packages/envpool/python/xla_template.py", line 124, in make_xla
        methods.append(_make_xla_function(obj, handle, name, specs, capsules))
      File "/lib/python3.10/site-packages/envpool/python/xla_template.py", line 94, in _make_xla_function
        xla.backend_specific_translations["cpu"][prim] = partial(
      File "/lib/python3.10/site-packages/jax/_src/deprecations.py", line 52, in getattr
        raise AttributeError(message)
    AttributeError: jax.interpreters.xla.backend_specific_translations is deprecated. Register custom primitives via jax.interpreters.mlir instead.
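The last frame of the traceback sits in jax/_src/deprecations.py, which suggests the attribute is served by a module-level attribute hook that raises once a deprecation is finalized. The following is a minimal, stdlib-only sketch of that pattern, not JAX's actual code: the module name fake_xla, the table _REMOVED, and the helper _module_getattr are all hypothetical names chosen for illustration.

```python
import types

# Hypothetical table mapping removed attribute names to error messages.
# This mimics the pattern only; it is not JAX's real deprecation table.
_REMOVED = {
    "backend_specific_translations": (
        "jax.interpreters.xla.backend_specific_translations is deprecated. "
        "Register custom primitives via jax.interpreters.mlir instead."
    ),
}


def _module_getattr(name):
    """PEP 562 hook: runs only when normal module attribute lookup fails."""
    if name in _REMOVED:
        raise AttributeError(_REMOVED[name])
    raise AttributeError(f"module 'fake_xla' has no attribute {name!r}")


# Stand-in module (hypothetical) with the hook installed in its namespace.
fake_xla = types.ModuleType("fake_xla")
fake_xla.__getattr__ = _module_getattr

try:
    # Same access pattern as xla_template.py line 94 in the traceback.
    fake_xla.backend_specific_translations["cpu"]
except AttributeError as err:
    caught = str(err)

# Because the hook raises AttributeError, hasattr() reports False,
# so callers could feature-detect instead of comparing version strings.
available = hasattr(fake_xla, "backend_specific_translations")
```

Under this pattern the failure happens at attribute access, before the ["cpu"] lookup or the assignment runs, which matches where the traceback stops.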
Expected behavior
env.xla() creates the xla functions without raising an error.
System info

    import envpool, numpy, sys
    print(envpool.__version__, numpy.__version__, sys.version, sys.platform)
    > 0.8.4 1.26.4 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] linux
    import jax
    print(jax.__version__)
    > 0.4.29
Reason and Possible fixes
According to the error message, custom primitives should be registered via jax.interpreters.mlir instead.
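As a rough illustration of what registering through jax.interpreters.mlir can look like, here is a minimal sketch for a toy primitive. It is not the envpool fix: envpool's primitives wrap C++ capsules and would need a custom-call lowering, whereas this sketch assumes the primitive can be expressed with ordinary JAX operations and lowered with mlir.lower_fun. The primitive name hypothetical_double and the helpers _double_impl and double are made up for the example.

```python
import jax
import jax.numpy as jnp
from jax.core import ShapedArray
from jax.interpreters import mlir

try:  # newer JAX releases expose Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older releases, such as 0.4.29
    from jax.core import Primitive

# Hypothetical primitive used only for this illustration.
double_p = Primitive("hypothetical_double")


def _double_impl(x):
    """Reference implementation written with ordinary JAX operations."""
    return x * 2


# Eager evaluation and shape/dtype inference for the primitive.
double_p.def_impl(_double_impl)
double_p.def_abstract_eval(lambda x: ShapedArray(x.shape, x.dtype))

# Previously: xla.backend_specific_translations["cpu"][double_p] = ...
# Now: register a CPU-only lowering rule through jax.interpreters.mlir.
mlir.register_lowering(
    double_p,
    mlir.lower_fun(_double_impl, multiple_results=False),
    platform="cpu",
)


@jax.jit
def double(x):
    return double_p.bind(x)


# Pin to CPU, because the lowering above is registered for that platform only.
with jax.default_device(jax.devices("cpu")[0]):
    result = double(jnp.arange(3.0))
```

The platform="cpu" argument plays the role of the ["cpu"] key in the old dictionary, so the per-backend structure of xla_template.py could in principle be kept while swapping the registration call.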
Closed by accident
Sorry about that. I'll try fixing the CI if I have time.