sail-sg / envpool

C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments.
https://envpool.readthedocs.io
Apache License 2.0
1.09k stars, 100 forks

[BUG] XLA is incompatible with jax 0.4.29 #313

Open pseudo-rnd-thoughts opened 3 months ago

pseudo-rnd-thoughts commented 3 months ago

Describe the bug

jax.interpreters.xla.backend_specific_translations is deprecated in jax v0.4.29; see the changelog: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-29-june-10-2024

As a result, calling env.xla() fails with the following error:

AttributeError: jax.interpreters.xla.backend_specific_translations is deprecated. Register custom primitives via jax.interpreters.mlir instead.

To Reproduce

import envpool

env = envpool.make("Breakout-v5", env_type="gym")
handle, recv, send, step = env.xla()

This fails with:

    handle, recv, send, step = env.xla()
  File "/lib/python3.10/site-packages/envpool/python/lax.py", line 30, in xla
    _handle, _recv, _send = make_xla(self)
  File "/lib/python3.10/site-packages/envpool/python/xla_template.py", line 124, in make_xla
    methods.append(_make_xla_function(obj, handle, name, specs, capsules))
  File "/lib/python3.10/site-packages/envpool/python/xla_template.py", line 94, in _make_xla_function
    xla.backend_specific_translations["cpu"][prim] = partial(
  File "/lib/python3.10/site-packages/jax/_src/deprecations.py", line 52, in getattr
    raise AttributeError(message)
AttributeError: jax.interpreters.xla.backend_specific_translations is deprecated. Register custom primitives via jax.interpreters.mlir instead.

Expected behavior

env.xla() returns the XLA functions (handle, recv, send, step) without raising an error.

System info

import envpool, jax, numpy, sys
print(envpool.__version__, numpy.__version__, sys.version, sys.platform)
> 0.8.4 1.26.4 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] linux
print(jax.__version__)
> 0.4.29

Reason and Possible fixes

According to the error message, the custom primitives in envpool/python/xla_template.py should be registered via jax.interpreters.mlir instead of xla.backend_specific_translations.
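To illustrate what registering through jax.interpreters.mlir looks like, here is a minimal sketch using a toy primitive. It is not envpool's code and not a drop-in fix: the primitive name demo_double and the helper functions are hypothetical, and envpool's real primitives wrap C++ capsules, which this sketch does not attempt to reproduce. It only shows mlir.register_lowering taking the place of an assignment into xla.backend_specific_translations["cpu"].

```python
import jax
import jax.numpy as jnp
from jax.interpreters import mlir

# Primitive lives in different modules depending on the jax release;
# try the newer location first and fall back to the older one.
try:
    from jax.extend.core import Primitive
except ImportError:
    from jax.core import Primitive

# Hypothetical toy primitive, standing in for envpool's recv/send primitives.
double_p = Primitive("demo_double")


def double(x):
    return double_p.bind(x)


def _double_impl(x):
    # Plain implementation, used for eager (non-jit) calls.
    return x * 2


def _double_abstract_eval(x):
    # The output has the same shape and dtype as the input.
    return jax.core.ShapedArray(x.shape, x.dtype)


double_p.def_impl(_double_impl)
double_p.def_abstract_eval(_double_abstract_eval)

# Instead of: xla.backend_specific_translations["cpu"][double_p] = ...
# register a CPU lowering rule through the MLIR interface.
mlir.register_lowering(
    double_p,
    mlir.lower_fun(_double_impl, multiple_results=False),
    platform="cpu",
)

# Run under jit on the CPU device so the registered lowering is used.
with jax.default_device(jax.devices("cpu")[0]):
    result = jax.jit(double)(jnp.arange(3.0))
print(result)
```

Here mlir.lower_fun builds the lowering rule from an ordinary JAX function, which keeps the sketch short. A primitive backed by a native custom call, as in envpool, would need a hand-written lowering rule instead.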

pseudo-rnd-thoughts commented 3 months ago

Closed by accident

Trinkle23897 commented 2 months ago

Sorry about that. I'll try fixing the CI if I have time.