Open pseudo-rnd-thoughts opened 3 months ago
@Trinkle23897 The linting error is unrelated to the PR
Testing this throws errors; however, there appears to be limited documentation on what the suggested change should be.
After implementing the suggested changes, I encountered an error when running the XLA example with JAX 0.4.34 (latest). The error reads:
TypeError: CustomCallWithLayout(): incompatible function arguments. The following argument types are supported:
1. CustomCallWithLayout(builder: jaxlib.xla_extension.XlaBuilder, call_target_name: bytes, operands: Span[jaxlib.xla_extension.XlaOp], shape_with_layout: jaxlib.xla_extension.Shape, operand_shapes_with_layout: Span[jaxlib.xla_extension.Shape], opaque: bytes = b'', has_side_effect: bool = False, schedule: jaxlib.xla_extension.ops.CustomCallSchedule = CustomCallSchedule.SCHEDULE_NONE, api_version: jaxlib.xla_extension.ops.CustomCallApiVersion = CustomCallApiVersion.API_VERSION_ORIGINAL) -> jaxlib.xla_extension.XlaOp
Invoked with types: jax._src.interpreters.mlir.LoweringRuleContext, bytes, kwargs = { operands: tuple, operand_shapes_with_layout: tuple, shape_with_layout: jaxlib.xla_extension.Shape, opaque: bytes, has_side_effect: bool }
I installed envpool using pip install envpool and manually applied the changes as instructed. It seems there may be an issue with how CustomCallWithLayout is invoked in the current context with the latest JAX.
I would appreciate any guidance on resolving this.
Description
Fixes https://github.com/sail-sg/envpool/issues/313
Motivation and Context
EnvPool XLA doesn't work with JAX 0.4.29+.
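For anyone who wants to detect the affected versions before calling into the XLA interface, here is a minimal sketch of a version check. It only assumes that 0.4.29 is the first affected release, as stated above; the names `FIRST_AFFECTED`, `parse_version`, and `is_affected` are hypothetical helpers for illustration and are not part of EnvPool or JAX.

```python
import re

# Assumption: 0.4.29 is the first JAX release affected, per the description above.
FIRST_AFFECTED = (0, 4, 29)


def parse_version(version: str) -> tuple[int, ...]:
    """Return the leading major.minor.patch components of a version string."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def is_affected(version: str) -> bool:
    """True if the given JAX version is at or above the first affected release."""
    return parse_version(version) >= FIRST_AFFECTED
```

In practice the argument would be `jax.__version__`; the helper takes a plain string so it can be checked without JAX installed.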
Types of changes
What types of changes does your code introduce? Put an x in all the boxes that apply: