patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/

Slow JIT compilation when doing teacher forcing #599

Open itsakk opened 11 months ago

itsakk commented 11 months ago

Hello Patrick,

I recently moved from PyTorch to JAX/Equinox (thanks a lot for your contributions to both Diffrax and Equinox) and I have been porting all of my PyTorch code over to JAX/Equinox.

I am having trouble implementing teacher forcing for my Neural ODE. I used to do this in PyTorch, as I noticed it helps keep the network from diverging from the ground-truth trajectory.

The issue is that my code takes a long time to compile, so training the network is much slower than it was in PyTorch. If my understanding of your comments on similar issues is correct, this is probably due to the for loop.

I have tried to use jax.lax.scan, but I can't find a proper way to apply it when doing teacher forcing.

Here is my code; you can assume the model is just a Neural ODE with an MLP inside:

import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx

@eqx.filter_value_and_grad
def compute_loss(diff_model, static_model, batch, epsilon, min_op, lambda_0):
    model = eqx.combine(diff_model, static_model)
    outputs = teacher_forcing(model, batch, epsilon)
    y = batch['states']
    mse_loss = jnp.mean((outputs - y) ** 2)
    return mse_loss

def teacher_forcing(model, batch, epsilon):
    t = batch['t']
    y = batch['states']
    env = batch['env']

    if epsilon < 1e-3:
        epsilon = 0

    if epsilon == 0:
        res = jax.vmap(model, in_axes = (0, 0, 0, None))(y, t, env, 0)
    else:
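        # Randomly pick the time indices at which integration restarts from the ground-truth state (teacher forcing).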
        eval_points = np.random.random(len(t[0])) < epsilon
        eval_points[-1] = False
        eval_points = eval_points[1:]
        start_i, end_i = 0, None
        res = []
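        # Integrate the trajectory segment by segment; this Python loop is unrolled at trace time when jitted.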
        for i, eval_point in enumerate(eval_points):
            if eval_point:
                end_i = i + 1
                t_seg = t[:, start_i:end_i + 1]
                res_seg = jax.vmap(model, in_axes = (0, 0, 0, None))(y, t_seg, env, start_i)
                if len(res) == 0:
                    res.append(res_seg)
                else:
                    res.append(res_seg[:, 1:, :])
                start_i = end_i
        t_seg = t[:, start_i:]
        res_seg = jax.vmap(model, in_axes = (0, 0, 0, None))(y, t_seg, env, start_i)
        if len(res) == 0:
            res.append(res_seg)
        else:
            res.append(res_seg[:, 1:, :])
        res = jnp.concatenate(res, axis=1)
    return jnp.moveaxis(res, 1, 2)

@eqx.filter_jit
def train_step(model, filter_spec, batch, epsilon, optim, opt_state, min_op, lambda_0):
    diff_model, static_model = eqx.partition(model, filter_spec)
    loss, grads = compute_loss(diff_model, static_model, batch, epsilon, min_op, lambda_0)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

Any suggestions to increase efficiency are greatly appreciated. Thanks a lot!

knyazer commented 11 months ago

I believe you are right about the reason: the slow compilation is most likely caused by JAX unrolling the for i, eval_point ... loop at trace time.

About your code: I think you were on the right track with lax.scan (an alternative would be lax.fori_loop, but these loop primitives are essentially the same). In addition to using it in place of the for loop, you should use lax.cond instead of the if/else conditionals, and jax.numpy.take to slice the arrays. I would also suggest replacing all the Python list operations with jax.numpy operations. The reason for all of this is so that JAX no longer has a static, trace-time dependency on the values, which is what allows you to use lax.scan instead of the for loop. There are a lot of changes to make, so good luck!
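For what it's worth, here is a rough sketch of what a scan-based version could look like. It steps the model one time interval at a time and, at each step, uses the precomputed boolean mask to decide whether to restart from the ground-truth state (the teacher forcing) or continue from the previous prediction. Note that step_model below is a hypothetical wrapper that integrates your NODE over a single interval; it is not part of your code, and you'd need to adapt the signatures to your model:

import jax
import jax.numpy as jnp

def teacher_forcing_scan(step_model, ts, ys, env, eval_points):
    # step_model(y0, t0, t1, env) -> state at t1, integrating from y0 at t0 (hypothetical).
    # ts: (T,) times, ys: (T, dim) ground-truth states,
    # eval_points: (T - 1,) boolean mask, True = restart from the ground truth.

    def step(y_prev, inp):
        t0, t1, y_true, use_truth = inp
        # jnp.where (or lax.cond) replaces the Python if/else on eval_point.
        y0 = jnp.where(use_truth, y_true, y_prev)
        y_next = step_model(y0, t0, t1, env)
        return y_next, y_next

    xs = (ts[:-1], ts[1:], ys[:-1], eval_points)
    _, preds = jax.lax.scan(step, ys[0], xs)
    # Prepend the initial ground-truth state so the output lines up with ys.
    return jnp.concatenate([ys[:1], preds], axis=0)

You can jax.vmap this over the batch dimension as before, and sample eval_points outside the jitted function (e.g. with jax.random.bernoulli) so that it is just another array input. Because the scan body is traced once, compile time no longer grows with the trajectory length.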

I strongly recommend reading through The Sharp Bits, which is one of the most useful resources when starting with JAX. If you have time, also work through the tutorials in the 'Getting Started' section of the JAX website: they contain a lot of crucial information that helps you debug this kind of problem faster. You might also find this answer interesting, since it gives a bit more explanation of why the for loop takes so long to compile.
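Just to illustrate the compile-time point with a toy example (hypothetical, unrelated to your model): a Python loop inside jit is unrolled at trace time, so XLA sees one copy of the body per iteration, whereas lax.scan traces the body once:

import time
import jax
import jax.numpy as jnp

N = 1000  # number of iterations

@jax.jit
def unrolled(x):
    # Python for-loop: traced N times, so the compiled program contains N copies of the body.
    for _ in range(N):
        x = jnp.sin(x) + 1.0
    return x

@jax.jit
def scanned(x):
    # lax.scan: the body is traced once; the looping happens at run time.
    def body(carry, _):
        return jnp.sin(carry) + 1.0, None
    out, _ = jax.lax.scan(body, x, xs=None, length=N)
    return out

x = jnp.ones(())

t0 = time.perf_counter()
unrolled(x).block_until_ready()
print("unrolled first call:", time.perf_counter() - t0)

t0 = time.perf_counter()
scanned(x).block_until_ready()
print("scan first call:", time.perf_counter() - t0)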