patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`. #268

Status: Open. Opened by AshCher51 1 year ago.

AshCher51 commented 1 year ago

Hi Patrick!

I'm training a simple Neural CDE model for time series regression, and I keep running into the same error. I know this error (RuntimeError: The maximum number of solver steps was reached. Try increasing max_steps) has come up in several issues, but I was hoping to get some help resolving it, because I haven't been able to pinpoint why it arises. I'm using exactly the same Func and NeuralCDE classes as in the Neural CDE example in the Diffrax documentation; the only things I changed are the following loss and make_step functions:

import equinox as eqx
import jax
import jax.numpy as jnp

#@eqx.filter_jit
def loss(model, ti, label_i, coeff_i):
    pred = jax.vmap(model)(ti, coeff_i)
    # Huber loss with delta = 1
    errors = pred - label_i
    abs_errors = jnp.abs(errors)
    quadratic = jnp.minimum(abs_errors, 1)
    linear = abs_errors - quadratic
    huber = jnp.mean(0.5 * quadratic ** 2 + linear)
    rmse = jnp.sqrt(jnp.mean(errors ** 2))
    mae = jnp.mean(abs_errors)
    # has_aux=True requires a (value, aux) pair, so the extra metrics go into one tuple
    return huber, (rmse, mae)

grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

#@eqx.filter_jit
def make_step(model, batch, opt_state):
    # ts, coeff_i, ys = jnp.array(batch['ts']), jnp.array(batch['coeffs']), jnp.array(batch['ys'])
    # float64 only takes effect if jax.config.update("jax_enable_x64", True) has been called
    ts = jnp.array(batch['ts'], dtype=jnp.float64)
    coeff_i = jnp.array(batch['coeffs'], dtype=jnp.float64)
    ys = jnp.array(batch['ys'], dtype=jnp.float64)
    (huber, (rmse, mae)), grads = grad_loss(model, ts, ys, coeff_i)
    updates, opt_state = optim.update(grads, opt_state)  # optim: the optax optimizer defined elsewhere
    model = eqx.apply_updates(model, updates)
    return huber, rmse, mae, model, opt_state

I'm also using a PyTorch Dataset class with a custom NumPy collate function to handle the dictionaries and NumPy data. Any idea where this error might be coming from?
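The collate function itself isn't shown in the issue. As a hedged sketch only, a dict-aware collate (hypothetically named `numpy_collate`, modelled on the `numpy_collate` helper from JAX's PyTorch data-loading tutorial) could look like this. It assumes each sample is a dict whose `coeffs` entry is a tuple of four arrays, as produced by `diffrax.backward_hermite_coefficients` in the Neural CDE example.

```python
import numpy as np

def numpy_collate(batch):
    """Hypothetical collate: turn a list of samples into NumPy arrays.

    Recurses into dicts, tuples and lists, and gives every leaf a new leading
    batch axis, which is the axis jax.vmap(model) maps over.
    """
    first = batch[0]
    if isinstance(first, dict):
        # Collate each key separately, keeping the dict structure
        return {key: numpy_collate([sample[key] for sample in batch]) for key in first}
    if isinstance(first, (tuple, list)):
        # e.g. the four Hermite coefficient arrays: keep them as a tuple, batch each one
        return tuple(numpy_collate(list(items)) for items in zip(*batch))
    # Leaf: stack along a new leading batch axis
    return np.stack([np.asarray(x) for x in batch])
```

This would be passed as `collate_fn=numpy_collate` to `torch.utils.data.DataLoader`. If `batch['coeffs']` arrives as such a tuple, note that `jnp.array(batch['coeffs'])` stacks it into one array whose leading axis has size 4 rather than the batch size, so `jax.vmap(model)` would no longer map over samples (or would fail with a size mismatch); `tuple(jnp.asarray(c) for c in batch['coeffs'])` keeps the batch axis first.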

Here is a link to my full code, in case this isn't enough to figure out what's wrong: https://pastebin.com/L4BM0yT8

Any help with this would be greatly appreciated!

patrick-kidger commented 1 year ago

This may well be because your dataset is too oscillatory to handle within the default number of steps. You could try increasing diffeqsolve(..., max_steps=...) to some larger value. Alternatively, try running with sol = diffeqsolve(..., saveat=SaveAt(steps=True), throw=False) and then inspect sol.ts to see where steps are being made, and sol.stats to see how many steps are being rejected -- maybe you're encountering some numerically difficult behaviour.

AshCher51 commented 1 year ago

Thank you very much for the help!

I tried increasing max_steps, but I still ended up getting the same error: RuntimeError: The maximum number of solver steps was reached. Try increasing max_steps.

When I try printing out sol.ts and sol.stats, I get a really large output I'm not sure how to interpret; any help with this would be appreciated.
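Printed in full, `sol.ts` from `SaveAt(steps=True)` has one slot per possible step (`max_steps` of them, padded with `inf` after the last accepted step), and under `jax.vmap` there is one such row per sample, which is why the output is so long. A hedged sketch of condensing it, using a hypothetical helper `summarize_steps` and assuming the model has been modified to return the Diffrax `Solution` from a solve with `SaveAt(steps=True)` and `throw=False`:

```python
import numpy as np

def summarize_steps(ts, num_accepted, num_rejected):
    """Hypothetical helper: condense SaveAt(steps=True) output into one dict per sample.

    ts has shape (max_steps,) or (batch, max_steps); unused slots hold inf.
    num_accepted / num_rejected are scalars or have shape (batch,).
    """
    ts = np.atleast_2d(np.asarray(ts))
    num_accepted = np.atleast_1d(np.asarray(num_accepted))
    num_rejected = np.atleast_1d(np.asarray(num_rejected))
    rows = []
    for i, row in enumerate(ts):
        t = row[np.isfinite(row)]  # times at the end of each accepted step
        dt = np.diff(t)            # step sizes between consecutive saved times
        rows.append({
            "sample": i,
            "accepted": int(num_accepted[i]),
            "rejected": int(num_rejected[i]),
            "t_last": float(t[-1]) if t.size else float("nan"),
            "min_dt": float(dt.min()) if dt.size else float("nan"),
            # start time of the smallest step: a hint at where the solver struggled
            "t_at_min_dt": float(t[np.argmin(dt)]) if dt.size else float("nan"),
        })
    return rows
```

Usage would be along the lines of `for row in summarize_steps(sol.ts, sol.stats["num_accepted_steps"], sol.stats["num_rejected_steps"]): print(row)`. A sample whose `t_last` is below the final observation time ran out of steps, and a very small `min_dt` or a high `rejected` count points at the part of that sample's time series worth inspecting.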

Here is the link to the output: https://pastebin.com/BwrE2cha.

I used this same dataset with torchcde and I was able to get the Neural CDE to work there (here is my code for that: https://pastebin.com/Mwm21zgk).

I would really appreciate your help with this. Looking forward to hearing from you!

patrick-kidger commented 1 year ago

I'm afraid that's much too large for me to parse! Try to create an MWE (minimal working example).