sail-sg / envpool

C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments.
https://envpool.readthedocs.io
Apache License 2.0

Update xla to use mlir rather than backend-specific-translations #314

Open: pseudo-rnd-thoughts opened this pull request 3 months ago

pseudo-rnd-thoughts commented 3 months ago

Description

Fixes https://github.com/sail-sg/envpool/issues/313

Motivation and Context

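EnvPool's XLA interface doesn't work with JAX 0.4.29+, because it registers its custom calls through backend-specific translations rather than MLIR lowering rules.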
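While both code paths coexist, one option is to choose between them based on the installed JAX version. The following is a minimal sketch that assumes 0.4.29 is the first affected release, as the description states. parse_release and needs_mlir_lowering are hypothetical helpers and are not part of EnvPool.

```python
import re

# Assumption taken from the PR description: JAX 0.4.29+ breaks the old path.
MLIR_ONLY_SINCE = (0, 4, 29)


def parse_release(version: str) -> tuple[int, ...]:
    """Return the leading numeric components, e.g. '0.4.30.dev1' -> (0, 4, 30)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def needs_mlir_lowering(jax_version: str) -> bool:
    """True if this JAX version can no longer use backend-specific translations."""
    # Pad to three components so that '0.5' compares as (0, 5, 0).
    release = (parse_release(jax_version) + (0, 0, 0))[:3]
    return release >= MLIR_ONLY_SINCE
```

In practice this would be called as needs_mlir_lowering(jax.__version__). Pre-release suffixes such as ".dev" or "rc1" are ignored, so a 0.4.29 release candidate counts as 0.4.29.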

pseudo-rnd-thoughts commented 3 months ago

@Trinkle23897 The linting error is unrelated to the PR

pseudo-rnd-thoughts commented 3 months ago

Testing this throws errors; however, there appears to be little documentation on what the suggested change should be.

JagerHoHo commented 1 month ago

After implementing the suggested changes, I encountered an error when running the XLA example with JAX 0.4.34 (the latest release). The error reads:

TypeError: CustomCallWithLayout(): incompatible function arguments. The following argument types are supported:
    1. CustomCallWithLayout(builder: jaxlib.xla_extension.XlaBuilder, call_target_name: bytes, operands: Span[jaxlib.xla_extension.XlaOp], shape_with_layout: jaxlib.xla_extension.Shape, operand_shapes_with_layout: Span[jaxlib.xla_extension.Shape], opaque: bytes = b'', has_side_effect: bool = False, schedule: jaxlib.xla_extension.ops.CustomCallSchedule = CustomCallSchedule.SCHEDULE_NONE, api_version: jaxlib.xla_extension.ops.CustomCallApiVersion = CustomCallApiVersion.API_VERSION_ORIGINAL) -> jaxlib.xla_extension.XlaOp
Invoked with types: jax._src.interpreters.mlir.LoweringRuleContext, bytes, kwargs = { operands: tuple, operand_shapes_with_layout: tuple, shape_with_layout: jaxlib.xla_extension.Shape, opaque: bytes, has_side_effect: bool }

I installed EnvPool with pip install envpool and applied the changes manually as instructed. The problem appears to be in how CustomCallWithLayout is invoked. With the latest JAX, its first argument is a jax._src.interpreters.mlir.LoweringRuleContext, but the function only accepts a jaxlib.xla_extension.XlaBuilder.

https://github.com/sail-sg/envpool/blob/f411fc26c8999ba5b9c39974344903b164486d1a/envpool/python/xla_template.py#L74-L88

Would appreciate any guidance on resolving this.
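One possible direction, sketched under assumptions: on the MLIR path, a lowering rule receives a LoweringRuleContext rather than an XlaBuilder. Instead of building an XlaOp with CustomCallWithLayout, the rule can emit a custom call through jax.interpreters.mlir.custom_call. The primitive, the target name envpool_step_sketch, and the opaque handle below are hypothetical and do not come from EnvPool's code. The sketch also assumes that api_version=1 corresponds to API_VERSION_ORIGINAL, the default shown in the error above, and that backend_config carries what the old opaque argument did.

```python
import jax
from jax.core import ShapedArray
from jax.interpreters import mlir

try:  # Newer JAX releases expose Primitive under jax.extend.core.
    from jax.extend.core import Primitive
except ImportError:  # Older releases only provide jax.core.Primitive.
    from jax.core import Primitive

# Hypothetical names for illustration; EnvPool's real targets differ.
TARGET_NAME = "envpool_step_sketch"
OPAQUE_HANDLE = b"hypothetical-handle"  # stand-in for a packed env-pool pointer

step_p = Primitive("envpool_step_sketch")


@step_p.def_abstract_eval
def _step_abstract_eval(action):
    # For illustration, the output mirrors the input's shape and dtype.
    return ShapedArray(action.shape, action.dtype)


def _step_lowering(ctx, action):
    # ctx is an mlir.LoweringRuleContext, not an XlaBuilder, which is why
    # passing it to CustomCallWithLayout(builder, ...) raises a TypeError.
    (aval_out,) = ctx.avals_out
    op = mlir.custom_call(
        TARGET_NAME,
        result_types=[mlir.aval_to_ir_type(aval_out)],
        operands=[action],
        backend_config=OPAQUE_HANDLE,  # plays the role of the old opaque bytes
        has_side_effect=True,  # stepping environments mutates external state
        api_version=1,  # assumption: 1 == API_VERSION_ORIGINAL
    )
    return op.results


# Register for CPU only; a GPU target would need its own registration.
mlir.register_lowering(step_p, _step_lowering, platform="cpu")


def step(action):
    return step_p.bind(action)
```

The sketch stops at lowering. jax.jit(step).lower(x).as_text() should show the custom call, but executing jax.jit(step)(x) would only work after a C++ function with that target name has been registered with XLA.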