cdagher closed this issue 1 year ago.
When running `fori_loop` under `jit`, every array in the loop carry must have the same shape (and dtype) on output as it had on input. From the error message:
the input carry component loop_carry[1][3].positions has type float32[0] but the corresponding output carry component has type float32[10,3], so the shapes do not match
It looks like `loop_carry[1][3]` is the variable you call `ball`: on input, `ball.positions` has shape `(0,)`, and on output it has shape `(10, 3)`.
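A minimal sketch that reproduces this class of error, using a bare array in place of `ball.positions` (the function names here are hypothetical, not from the original code). The body ignores its `(0,)`-shaped input carry and returns a `(10, 3)` array, which `fori_loop` rejects when it traces the body:

```python
import jax
import jax.numpy as jnp
from jax import lax


def body(i, positions):
    # Ignores the incoming (0,)-shaped carry and returns a (10, 3) array,
    # so the carry's output type differs from its input type.
    return jnp.zeros((10, 3), dtype=jnp.float32)


@jax.jit
def run(init):
    return lax.fori_loop(0, 10, body, init)


empty_positions = jnp.zeros((0,), dtype=jnp.float32)  # plays the role of ball.positions

try:
    run(empty_positions)
    raised = False
except TypeError as err:
    raised = True
    # The message should describe the float32[0] vs float32[10,3] carry mismatch.
    print(err)
```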
The way to fix this is to ensure that the input arrays have the same shapes as the output arrays. I would look for where you're initializing `ball` in your code and make sure it's initialized with arrays of the same shapes you expect on output.
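One way to find the offending fields before calling `fori_loop` is to trace the body once with `jax.eval_shape` and compare each carry leaf's input and output type, labeling leaves by their pytree path. This is a hedged sketch: `Ball` is a hypothetical `NamedTuple` stand-in for the issue's `BouncingBall` (a `flax.struct.dataclass` is also a pytree, so the same approach should apply), and `body` is a hypothetical loop body.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Ball(NamedTuple):
    # Hypothetical stand-in for BouncingBall.
    positions: jax.Array
    velocities: jax.Array


def body(i, ball):
    # Hypothetical body: always produces a (10, 3) positions array.
    return Ball(positions=jnp.ones((10, 3)), velocities=ball.velocities + 1.0)


def carry_mismatches(body_fun, init):
    """Return (path, input (shape, dtype), output (shape, dtype)) for each changed leaf."""
    # Trace the body abstractly: no computation runs, only shapes/dtypes are inferred.
    out = jax.eval_shape(body_fun, 0, init)
    if jax.tree_util.tree_structure(out) != jax.tree_util.tree_structure(init):
        raise ValueError("carry pytree structure changes between input and output")
    in_leaves, _ = jax.tree_util.tree_flatten_with_path(init)
    out_leaves = jax.tree_util.tree_leaves(out)
    diffs = []
    for (path, x), y in zip(in_leaves, out_leaves):
        if (x.shape, x.dtype) != (y.shape, y.dtype):
            diffs.append((jax.tree_util.keystr(path), (x.shape, x.dtype), (y.shape, y.dtype)))
    return diffs


init = Ball(positions=jnp.zeros((0,)), velocities=jnp.zeros((10, 3)))
diffs = carry_mismatches(body, init)
for d in diffs:
    print(d)
```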
Thanks @jakevdp! I hadn't thought to look at `ball.positions`. I changed the array in `BouncingBall` to have a pre-allocated size, and now it works.
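The pre-allocation fix described above can be sketched as follows, under stated assumptions: `BallState`, `NUM_STEPS`, `dt`, and the constant-gravity dynamics are all hypothetical and stand in for the real `BouncingBall` and Brax `System`. The history buffers are created at their final size up front, and each iteration writes one row with `.at[i].set(...)`, which returns an array of the same shape, so the carry type never changes:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import lax

NUM_STEPS = 10  # hypothetical number of recorded timesteps
dt = 0.01       # hypothetical timestep
gravity = jnp.array([0.0, 0.0, -9.81])


class BallState(NamedTuple):
    # Hypothetical stand-in for BouncingBall's per-timestep history.
    positions: jax.Array   # (NUM_STEPS, 3)
    velocities: jax.Array  # (NUM_STEPS, 3)


def init_state():
    # Pre-allocate full-size buffers instead of starting from empty (0,) arrays.
    return BallState(positions=jnp.zeros((NUM_STEPS, 3)),
                     velocities=jnp.zeros((NUM_STEPS, 3)))


def step(i, state):
    prev = jnp.maximum(i - 1, 0)
    vel = state.velocities[prev] + dt * gravity
    pos = state.positions[prev] + dt * vel
    # Write row i; the buffers keep shape (NUM_STEPS, 3).
    return BallState(positions=state.positions.at[i].set(pos),
                     velocities=state.velocities.at[i].set(vel))


@jax.jit
def simulate(state):
    return lax.fori_loop(0, NUM_STEPS, step, state)


final = simulate(init_state())
print(final.positions.shape)
```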
Description
There appears to be an issue with `jax.lax.fori_loop`. When I try to use this function, I get the following exception: "the input carry component loop_carry[1][3].positions has type float32[0] but the corresponding output carry component has type float32[10,3], so the shapes do not match"
The code producing this error is the following:
A similar exception is being thrown for velocities and constraints.
In this function, `controller` extends `equinox.Module`, and `BouncingBall` is a `flax.struct.dataclass` that wraps a Brax `System` with some other arrays for state information at different timesteps. When I disable JIT compilation using
the function runs without issues, but when it is JIT compiled it throws these exceptions.
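A likely explanation for why the code runs with JIT disabled: under `jax.disable_jit()`, `fori_loop` falls back to an ordinary Python loop that calls the body repeatedly and does not compare carry types between iterations, so a shape-changing carry goes unnoticed. The sketch below (with a hypothetical shape-changing `body`) contrasts the two modes:

```python
import jax
import jax.numpy as jnp
from jax import lax


def body(i, positions):
    # Shape-changing body: (0,) in, (10, 3) out.
    return jnp.zeros((10, 3), dtype=jnp.float32)


empty = jnp.zeros((0,), dtype=jnp.float32)

# JIT disabled: the loop runs eagerly in Python and the mismatch is not checked.
with jax.disable_jit():
    eager_result = lax.fori_loop(0, 3, body, empty)

# JIT enabled: the body is traced once and the carry types are compared.
try:
    jax.jit(lambda x: lax.fori_loop(0, 3, body, x))(empty)
    jit_raised = False
except TypeError:
    jit_raised = True

print(eager_result.shape, jit_raised)
```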
What jax/jaxlib version are you using?
jax v0.4.14, jaxlib 0.4.14
Which accelerator(s) are you using?
CPU
Additional system info
Python 3.10.12, Ubuntu 22.04, Intel Xeon E3-1230 V2
NVIDIA GPU info
No response