jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Unexpected exception from jax.lax.fori_loop #17629

Closed cdagher closed 1 year ago

cdagher commented 1 year ago

Description

There appears to be an issue with jax.lax.fori_loop. When I try to use this function, I get the following exception:

"the input carry component loop_carry[1][3].positions has type float32[0] but the corresponding output carry component has type float32[10,3], so the shapes do not match"

The code producing this error is the following:

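from functools import partial

import jax
import jax.numpy as jp  # assumed alias; the post uses jp without showing its import

# controller, BouncingBall, and pipeline come from code not shown in the post.
@partial(jax.jit, static_argnames=('targetForce', 'timesteps'))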
def loss(model: controller, ball: BouncingBall, targetForce: float = 1.0, timesteps: int = 10):

    positions = jp.array([[0]*3]*timesteps, dtype=jp.float32)
    velocities = jp.array([[0]*6]*timesteps, dtype=jp.float32)
    constraints = jp.array([[0]*6]*timesteps, dtype=jp.float32)
    carry_i = (positions, velocities, constraints, ball, model)

    def step(i: int, carry: tuple):

        positions_s, velocities_s, constraints_s, ball_s, model_s = carry

        positions_s = positions_s.at[i,:].add(ball_s.state.x.pos[0])
        velocities_s = velocities_s.at[i,:].add(ball_s.state.qd)
        constraints_s = constraints_s.at[i,:].add(ball_s.state.qf_constraint)

        x = jp.array([ball_s.state.x.pos[0][2], ball_s.state.qd[2]])
        force = model_s(x.transpose())

        newstate = pipeline.step(ball_s.system, ball_s.state, force)
        ball_s = ball_s.create(ball_s.system, newstate, positions_s, velocities_s, ball_s.contacts, constraints_s, model_s)

        newStuff = (positions_s, velocities_s, constraints_s, ball_s, model_s)

        return newStuff

    positions, velocities, constraints, ball, model = jax.lax.fori_loop(0, timesteps, step, carry_i)

    states = (positions, velocities, constraints)

    loss_value = jp.linalg.norm(constraints[:,2] - jp.array([targetForce]*timesteps))

    return loss_value, states

Similar exceptions are raised for velocities and constraints.

In this function, controller extends equinox.Module, and BouncingBall is a flax.struct.dataclass that wraps a Brax System with some other arrays for state information at different timesteps.

When I disable jit compiling using

from jax.config import config
config.update('jax_disable_jit', True)

the function runs without issues, but when it is JIT compiled it throws these exceptions.

What jax/jaxlib version are you using?

jax v0.4.14, jaxlib 0.4.14

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.10.12, Ubuntu 22.04, Intel Xeon E3-1230 V2

NVIDIA GPU info

No response

jakevdp commented 1 year ago

When running fori_loop under jit, every component of the loop carry must have the same shape and dtype on output from the body function as it had on input. From the error message:

the input carry component loop_carry[1][3].positions has type float32[0] but the corresponding output carry component has type float32[10,3], so the shapes do not match

It looks like loop_carry[1][3] is the variable you call ball: on input, ball.positions has shape (0,), and on output it has shape (10, 3).

The way to fix this is to ensure that the input arrays have the same shape as the output arrays. I would look for where you're initializing ball in your code, and make sure it's initialized with the same shape arrays as you expect on output.
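For anyone who lands here with the same message, here is a minimal sketch of the mismatch and the fix. It does not use Brax or the code from this issue; the names step, run, and TIMESTEPS are made up for illustration, and the second carry slot simply plays the role of ball.positions.

```python
import jax
import jax.numpy as jnp

TIMESTEPS = 10  # illustrative value, matching the default in the issue


def step(i, carry):
    log, ball_positions = carry
    log = log.at[i, :].add(1.0)
    # Mirrors the issue: the second carry slot is overwritten with the
    # (TIMESTEPS, 3) log, whatever shape it had on input.
    return log, log


@jax.jit
def run(initial_ball_positions):
    log = jnp.zeros((TIMESTEPS, 3), dtype=jnp.float32)
    return jax.lax.fori_loop(0, TIMESTEPS, step, (log, initial_ball_positions))


# Input slot has shape (0,), output slot has shape (10, 3): tracing fails.
try:
    run(jnp.zeros((0,), dtype=jnp.float32))
    mismatch_error = None
except TypeError as err:
    mismatch_error = err

# Pre-allocating the slot with the output shape makes the carry consistent.
fixed_log, fixed_positions = run(jnp.zeros((TIMESTEPS, 3), dtype=jnp.float32))
```

In this sketch the first call fails while the loop is being traced, before any iteration runs, which is why the error names a carry component rather than a line in the loop body.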

cdagher commented 1 year ago

Thanks @jakevdp! I hadn't thought to look at ball.positions. I changed the array in BouncingBall to have a pre-allocated size and now it works.
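A rough sketch of what the pre-allocation change might look like, plus a small check that can be run before the loop. Ball is a hypothetical NamedTuple stand-in for BouncingBall (not the flax.struct.dataclass from the issue), and carry_mismatches is a helper invented here, not a JAX API. It relies on jax.eval_shape, which computes output shapes and dtypes without running the computation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

TIMESTEPS = 10  # illustrative value


class Ball(NamedTuple):  # hypothetical stand-in for the issue's BouncingBall
    height: jax.Array
    positions: jax.Array


def step(i, carry):
    log, ball = carry
    log = log.at[i, 2].set(ball.height)
    # As in the issue, the struct is rebuilt with the full log as a field.
    ball = Ball(height=ball.height * 0.5, positions=log)
    return log, ball


def carry_mismatches(body, carry):
    """List (leaf index, input shape, output shape) for carry leaves that change type."""
    in_types = jax.eval_shape(lambda c: c, carry)
    out_types = jax.eval_shape(lambda c: body(0, c), carry)
    in_leaves, in_tree = jax.tree_util.tree_flatten(in_types)
    out_leaves, out_tree = jax.tree_util.tree_flatten(out_types)
    if in_tree != out_tree:
        raise ValueError("carry pytree structure changes inside the loop body")
    return [
        (k, a.shape, b.shape)
        for k, (a, b) in enumerate(zip(in_leaves, out_leaves))
        if (a.shape, a.dtype) != (b.shape, b.dtype)
    ]


log = jnp.zeros((TIMESTEPS, 3), dtype=jnp.float32)
empty_ball = Ball(jnp.float32(1.0), jnp.zeros((0,), dtype=jnp.float32))
preallocated_ball = Ball(jnp.float32(1.0), jnp.zeros((TIMESTEPS, 3), dtype=jnp.float32))

bad = carry_mismatches(step, (log, empty_ball))          # positions leaf changes shape
good = carry_mismatches(step, (log, preallocated_ball))  # no leaf changes

final_log, final_ball = jax.lax.fori_loop(0, TIMESTEPS, step, (log, preallocated_ball))
```

The check only compares shapes and dtypes for a single call of the body, so it is a quick sanity test rather than a replacement for the error that fori_loop itself raises.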