patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0
1.32k stars, 121 forks

Solver for very stiff Neural Ordinary Differential Equations #319

Open: mayor-slash opened this issue 9 months ago

mayor-slash commented 9 months ago

I am currently implementing a physically constrained neural ordinary differential equation. The application is reaction kinetics, so you already know it's going to be a very stiff system. The governing differential equation is a function of the state $Y$, the outputs of two neural networks $M_{1}$ and $M_{2}$ (Equinox models) that take $Y$ as input, and some constant parameters $C$:

$$ \frac{dY}{dt} = f \left(Y, M_{1}(Y), M_{2}(Y), C\right) $$
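One rough way to quantify how stiff a right-hand side like $f$ is: look at the eigenvalues of its Jacobian $\partial f / \partial Y$ at a given state. A large ratio between the fastest and slowest nonzero decay rates signals stiffness. The sketch below assumes the classic Robertson reaction system as a hypothetical stand-in for $f$ (the networks $M_1$, $M_2$ are not available here); `jax.jacfwd` computes the Jacobian and NumPy computes the eigenvalues.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)

# Hypothetical stand-in for f(Y, M1(Y), M2(Y), C): the Robertson reaction system,
# a classic stiff chemical-kinetics benchmark with rate constants C = (k1, k2, k3).
C = (0.04, 3e7, 1e4)

def f(y, c):
    k1, k2, k3 = c
    return jnp.stack([
        -k1 * y[0] + k3 * y[1] * y[2],
        k1 * y[0] - k2 * y[1] ** 2 - k3 * y[1] * y[2],
        k2 * y[1] ** 2,
    ])

def stiffness_ratio(y, c):
    # Jacobian df/dY at the state y (forward mode is cheap for small states).
    jac = jax.jacfwd(f)(y, c)
    rates = np.abs(np.linalg.eigvals(np.asarray(jac)).real)
    # Drop (numerically) zero modes, e.g. the one caused by mass conservation.
    rates = rates[rates > 1e-10 * rates.max()]
    return rates.max() / rates.min()

y = jnp.array([0.9, 3e-5, 0.1])
print(f"stiffness ratio at y={y}: {stiffness_ratio(y, C):.3g}")
```

For this state the fastest and slowest nonzero rates differ by roughly four orders of magnitude, which is the regime where explicit solvers need tiny steps.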

The two neural networks are pre-trained to produce reasonable outputs. After playing around with atol, rtol, and max_steps of the nonlinear_solver I can solve the equation with the ImplicitEuler method in Diffrax. I can also compare the result y(t) with my target values, compute a loss value, and get the gradient of the loss with respect to the network parameters.

After a few iterations, the Diffrax solver throws an error saying that the implicit method diverges. I checked whether the updates to the neural networks produced any NaNs/Infs, which was not the case. With the same ode_fun and networks, the BDF method of scipy.integrate.solve_ivp works just fine.

With throw=False in diffeqsolve, I can see that the solver fails at time steps where the gradients (dY/dt) are either very large or almost zero. Due to the nature of the problem, these gradients can span many orders of magnitude.

I have seen that a BDF solver implementation is planned (#8). Maybe that solver is more suitable, but then again, I don't have a math background. My question is: should I wait for that implementation, adjust the parameters of diffeqsolve further, or is there another way around this?

64-bit precision is enabled via: config.update("jax_enable_x64", True)

I am thankful for any recommendations!

patrick-kidger commented 9 months ago

I'd recommend trying Kvaerno{3,4,5} instead of ImplicitEuler. The latter isn't a very good numerical method. Also try adjusting the step size controller. Something like stepsize_controller=PIDController(pcoeff=0.3, icoeff=0.4, rtol=1e-8, atol=1e-8) would be typical.

mayor-slash commented 9 months ago

Thank you for the quick response. I had been using the ImplicitEuler method because Kvaerno5 didn't work for me at all. I have now tested Kvaerno3 and Kvaerno4, and for some reason 3 works while 4 doesn't. Still, I need about 1e5 to 1e6 steps to solve, compared to only 1000-5000 for the SciPy BDF implementation.

But since I am still new to JAX, this might be caused by some unclean implementation on my part.

patrick-kidger commented 9 months ago

Gotcha. FWIW, the Kvaerno family of solvers are usually more robust than BDF, and computationally cheaper to work with. Generally speaking multistep solvers like BDF are best for problems with expensive vector fields, e.g. semidiscretised PDEs.
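One way to look at this trade-off is to count vector field evaluations rather than steps: a multistep method like BDF reuses values from previous steps, whereas an implicit Runge-Kutta method like Radau evaluates several stages per step. A sketch using SciPy's `solve_ivp` (whose result exposes `nfev`, `njev` and `nlu`) on the Robertson system as a hypothetical stand-in; it makes no claim about which method is cheaper for the issue's actual problem.

```python
import numpy as np
from scipy.integrate import solve_ivp

def robertson(t, y, k1=0.04, k2=3e7, k3=1e4):
    # Hypothetical stiff stand-in; rate constants given as defaults so solve_ivp can call fun(t, y).
    return np.array([
        -k1 * y[0] + k3 * y[1] * y[2],
        k1 * y[0] - k2 * y[1] ** 2 - k3 * y[1] * y[2],
        k2 * y[1] ** 2,
    ])

stats = {}
for method in ["BDF", "Radau"]:
    sol = solve_ivp(robertson, (0.0, 100.0), [1.0, 0.0, 0.0], method=method, rtol=1e-6, atol=1e-10)
    # Without t_eval, sol.t holds t0 plus every accepted step.
    stats[method] = (sol.success, len(sol.t) - 1, sol.nfev, sol.njev, sol.nlu)
    print(f"{method}: success={sol.success} steps={len(sol.t) - 1} "
          f"nfev={sol.nfev} njev={sol.njev} nlu={sol.nlu}")
```

When each vector field evaluation is expensive (for example, several neural network forward passes), `nfev` is often the number that matters most.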

If you're able to construct a small (<100 lines) MWE comparing the two, then we can check whether Diffrax is doing everything it can do here.

mayor-slash commented 8 months ago

I would guess, then, that my problem falls into that category of expensive vector fields. I have built an MWE that is as small as I could make it. In my real implementation, the neural networks are pretrained to produce certain outputs; however, I can't squeeze this process into the MWE. I hope loading the network parameters from the eqx files is fine:

```python
import jax.numpy as jnp
import jax
import diffrax
import equinox as eqx
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
from jax import config
config.update("jax_enable_x64", True)

# Adapter for scipy.integrate.solve_ivp, which calls fun(t, y, *args).
def Dynamic_Wrapper(t, state, k_model, t_model, reactants, sm):
    return Dynamic(t, state, (k_model, t_model, reactants, sm))

@jax.jit
def Dynamic(t: jnp.float64,
            state: jnp.array,
            params: tuple) -> jnp.array:
    # state = [N_1, ..., N_7, T]: species amounts followed by the temperature.
    k_model, t_model, reactants, sm = params
    T = state[-1]
    rezi_T = jnp.expand_dims(jnp.divide(1000, T), 0)
    N = jnp.clip(state[:-1], a_min=1e-40)  # floor so that log10 stays finite
    N_log = jnp.divide(jnp.log10(N), -100.)
    k_input = jnp.concatenate([N_log, rezi_T], 0)

    # Rate constants, predicted as log10 values by the first network.
    ks = jnp.power(10., k_model(k_input))
    # Mass-action rates: product over species of N_i ** (reaction order), one per reaction.
    N = jnp.expand_dims(N, 0)
    eff_N = jnp.power(N, reactants)
    q = jnp.prod(eff_N, axis=-1)
    qk = jnp.multiply(q, ks)
    qk = jnp.expand_dims(qk, -1)
    # Species rates of change via the stoichiometric matrix.
    dN1 = jnp.multiply(sm, qk)
    dN = jnp.sum(dN1, 0)

    # Internal energies and heat capacities from the second network; energy balance gives dT/dt.
    us, cp = jnp.reshape(t_model(rezi_T), (2, -1))
    cv = jnp.multiply(cp, 1e3)  # J/kmol/K
    us = jnp.multiply(us, 1e7)  # J/kmol
    cv = jnp.sum(jnp.multiply(cv, N))  # J/K
    dU = jnp.sum(jnp.multiply(us, dN))  # J
    dT = jnp.array([-dU / cv])
    grad = jnp.concatenate([dN, dT])
    return grad

@eqx.filter_jit
def solve_ODE(func,
              start: jnp.array,
              ts: jnp.array,
              params: tuple,
              throw=True, steps=False):
    newton = diffrax.NewtonNonlinearSolver(max_steps=10)
    solver = diffrax.Kvaerno3(nonlinear_solver=newton)
    # Save every step (to inspect the step count) or only at the requested times.
    if steps:
        saveat = diffrax.SaveAt(steps=True)
    else:
        saveat = diffrax.SaveAt(ts=ts)
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(func),
        solver,
        t0=ts[0],
        t1=ts[-1],
        dt0=None,
        y0=start,
        args=params,
        max_steps=int(1e6),
        throw=throw,
        adjoint=diffrax.RecursiveCheckpointAdjoint(int(1e6)),
        stepsize_controller=diffrax.PIDController(pcoeff=0.3, icoeff=0.4, dcoeff=0, rtol=1e-12, atol=1e-10),
        saveat=saveat,
    )
    return solution.ys, solution.ts

class NN_k(eqx.Module):
    layers: list

    def __init__(self, inp, out, hidden, key):
        # len(hidden) keys for len(hidden) + 1 layers: JAX clamps the out-of-range
        # index, so the last layer reuses the final key.
        keys = jax.random.split(key, len(hidden))
        hidden = [inp, *hidden, out]
        layers = [(hidden[i], hidden[i + 1]) for i in range(len(hidden) - 1)]
        self.layers = [eqx.nn.Linear(h[0], h[1], key=keys[i]) for i, h in enumerate(layers)]

    def __call__(self, x):
        for l in self.layers:
            x = l(x)
            x = jnp.tanh(x)
        return x * 50

if __name__ == "__main__":
    y0 = jnp.array([1e-40, 0.1976306303518972, 1e-40, 1e-40, 1e-40, 0.0988153151759486, 0.3717338047095211, 900.0])
    ts = jnp.linspace(0.0, 0.08351431971177574, 100)
    # Stoichiometric matrix: one row per reaction, one column per species.
    sm = jnp.array([[2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    [-1.0, 0.0, 1.0, 1.0, 0.0, -1.0, 0.0],
                    [1.0, -1.0, 0.0, -1.0, 1.0, 0.0, 0.0],
                    [1.0, -1.0, -1.0, 1.0, 0.0, 0.0, 0.0],
                    [-2.0, 1.0, -0.0, -0.0, -0.0, -0.0, -0.0],
                    [1.0, -0.0, -1.0, -1.0, -0.0, 1.0, -0.0],
                    [-1.0, 1.0, -0.0, 1.0, -1.0, -0.0, -0.0],
                    [-1.0, 1.0, 1.0, -1.0, -0.0, -0.0, -0.0]])
    # Reaction orders: magnitudes of the negative (consumed) stoichiometric coefficients.
    reactants = -jnp.clip(sm, a_max=0.0)
    key = jax.random.PRNGKey(42)
    k_model = NN_k(8, 8, [40], key)
    t_model = NN_k(1, 14, [30, 60], key)
    params = (k_model, t_model, reactants, sm)

    # Diffrax solve (ys has shape (len(ts), 8), so transpose to plot one line per component):
    # ys, ts = solve_ODE(Dynamic, y0, ts, params=params)
    # for y in ys.T:
    #     plt.plot(ts, y)
    # plt.yscale("log")
    # plt.show()

    # SciPy BDF solve:
    sol = solve_ivp(Dynamic_Wrapper, (ts[0], ts[-1]), y0, method="BDF", args=params, rtol=1e-12, atol=1e-10)
    for y in sol.y:
        plt.plot(sol.t, y)
    plt.yscale("log")
    plt.show()
```

The last two blocks can be commented/uncommented to compare the solvers.

Edit: I also tested this JAX-based BDF solver, and it runs quickly while producing the same results as SciPy's solve_ivp. However, unlike Diffrax, it was probably not designed to be differentiable, so for now I can't compute gradients for my neural network training with it. I will try to dig deeper into this.

Edit 2: I removed the Equinox files and just used the random initialization.

patrick-kidger commented 8 months ago

Hi -- best to avoid loading from untrusted files, as the format isn't designed for security. I'm afraid I'd need a MWE without loading from an external file. (Would just the random weights at initialisation work?)

mayor-slash commented 8 months ago

Hi, that makes sense. I removed them, and now it's just the random initialisation. This doesn't make much sense from a physical point of view, but I tested it and it still shows that solve_ivp solves it quickly while Diffrax gives up after 1e6 steps. FYI: I pretrain the two ANNs in a classic feed-forward style on some data, then put them into the model shown above and train it against some other data. At the start of that second training the solver can still handle it; after some iterations it gives up.

patrick-kidger commented 8 months ago

Thanks! I've just tried running your script but it looks like scipy fails as well -- sol.success is False and sol.status is -1. As such I suspect the problem is with the ODE you're trying to solve.

I've not looked closely at the vector field you're using, but common problems here are things like logs (which are undefined for negative numbers) and divisions (which produce very large values for divisors close to zero), or very large values (that produce rapidly changing behaviour), or powers (which mean that the vector field may fail the conditions of the Picard existence theorem and the solution does not exist even theoretically).
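As a minimal illustration of the last point, assuming the toy equation dy/dt = y² with y(0) = 1 (not the issue's system): its exact solution 1/(1 - t) blows up at t = 1, so no solver can integrate past that point, and solve_ivp reports this through `sol.success` and `sol.status`, the same fields that flagged the failing run above.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Toy ODE with a power nonlinearity: dy/dt = y**2, y(0) = 1, exact solution y = 1 / (1 - t).
def rhs(t, y):
    return y ** 2

sol = solve_ivp(rhs, (0.0, 2.0), [1.0], method="BDF", rtol=1e-8, atol=1e-10)
print(sol.success, sol.status, sol.message)
print("last time reached:", sol.t[-1])
```

The step size shrinks as the solution grows, so the integration stops just short of t = 1 instead of returning a (meaningless) value at t = 2.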