titu1994 / tfdiffeq

Tensorflow implementation of Ordinary Differential Equation Solvers with full GPU support
MIT License

Poor performance with a 'steady-state' population #6

Open Jhsmit opened 4 years ago

Jhsmit commented 4 years ago

Hi, thanks for making this lovely package.

I have a simple ODE model with the following code:


import tensorflow as tf
import matplotlib.pyplot as plt
from tfdiffeq import odeint

class Kinetics(tf.keras.Model):
    """Three-state kinetic model: s0 <-> s1 -> s2."""

    def __init__(self, k1, k2, k3):
        super().__init__()
        self.k1, self.k2, self.k3 = k1, k2, k3

    @tf.function
    def call(self, t, y):
        s0, s1, s2 = y[0], y[1], y[2]

        # Rate equations for the three populations
        d_0 = -self.k1 * s0 + self.k2 * s1
        d_1 = -self.k2 * s1 - self.k3 * s1 + self.k1 * s0
        d_2 = self.k3 * s1

        return tf.stack([d_0, d_1, d_2])

with tf.device('/gpu:0'):
    tf.keras.backend.set_floatx('float64')
    k1 = 1.
    k2 = 2.
    k3 = 3.
    NUM_SAMPLES = 100

    t = tf.linspace(0., 10., num=NUM_SAMPLES)
    y_init = tf.constant([1.,0., 0.], dtype=tf.float64)

    func = Kinetics(k1, k2, k3)
    result = odeint(func, y_init, t)

plt.figure()
for r in tf.transpose(result):
    plt.plot(r)
plt.show()

This works fine and gives me the expected result directly. However, when I change the parameters to:

    k1 = 1.
    k2 = 2000.
    k3 = 3.

Now I have a steady-state population, as the population of s1 is almost zero over the time interval. The calculation now takes a very long time, while scipy's odeint is still very fast.

Any suggestions on why this is and what I can do to improve this?

I've tried tensorflow 2.1.0 and 2.2.0, adding the @tf.function decorator, and @KuzMenachem's fork.

titu1994 commented 4 years ago

Scipy's odeint uses other solvers than dopri5 under the hood, written in Fortran and built to switch between stiff and non-stiff solvers, so its general performance is excellent. You can't take gradients through it, however.

I suggest, if you only want to compute the forward pass and are not interested in taking gradients of the operation, that you use the scipy solvers directly.
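
For instance, a minimal forward-only sketch with scipy (assuming the same kinetics as your snippet; scipy's odeint wraps LSODA, which switches between stiff and non-stiff modes automatically):

import numpy as np
from scipy.integrate import odeint as sp_odeint  # aliased to avoid clashing with tfdiffeq's odeint

def kinetics(y, t, k1, k2, k3):
    # Same rate equations as the tf.keras model above
    s0, s1, s2 = y
    return [-k1 * s0 + k2 * s1,
            -k2 * s1 - k3 * s1 + k1 * s0,
            k3 * s1]

t = np.linspace(0., 10., 100)
y0 = [1., 0., 0.]
# The stiff parameter set that is slow with dopri5:
result = sp_odeint(kinetics, y0, t, args=(1., 2000., 3.))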

In case you do need gradients, I'd suggest trying different solvers, though dopri5 is the most stable one at the moment. I see that torchdiffeq has ported a few more solvers; I'll try getting them ported to tfdiffeq, but I don't think it's going to have much of an effect on your speed issues.

Another suggestion: force the device to CPU rather than GPU. You are computing just a 3-variable state space, so the shuttling to and from GPU memory will be much more expensive than the actual compute on the GPU.
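
For example, reusing the objects from your snippet:

with tf.device('/cpu:0'):
    result = odeint(func, y_init, t)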

Also, that fork looks interesting. Wonder if the author is interested in submitting a PR.

Jhsmit commented 4 years ago

Thank you for your swift and detailed reply. This field looks very dynamic and fast-moving. Very interesting stuff; since it's not really my field it is difficult to figure out what the best implementation or approach is, but what is clear to me is that these methods are very powerful and likely quite useful for my application.

Basically, what I'm doing is using a tensorflow layer as more of a classical kind of 'curve fitting', with a function f(t, k1, k2, k3) where k1, k2, k3 are fit parameters (trainable weights) and the function value is the population of state s2 in the example above. So the call of this layer involves a forward pass of the ODE system. One of the reasons I'm using tensorflow for the fitting and not a scipy minimizer is that the k values are not scalars but rather vectors of length up to ~1000. These are interdependent in the sense that they go into the ODE system independently but the result of

In my current implementation I use a steady-state approximation of the rate equations so that I don't need to use ODE integration. However, because I know that for certain values of k1, k2, k3 the steady-state approximation introduces large errors, I would like to use the exact solution of the ODE model instead. This is why I was looking into using odeint in tensorflow.

I can see a few approaches going forward with this, and although this might not be a support forum, perhaps you could advise me on this:

  1. Use scipy.odeint in my layer's call to calculate the layer output given a set of k1, k2, k3. However, since they are 'vectors' rather than scalars, I would have to loop through k to calculate all values. I don't know exactly how this works in terms of a tensorflow graph, but it doesn't sound very efficient.
  2. Precalculate how (k1, k2, k3, t) maps to s2(t) and use tfp.math.batch_interp_regular_nd_grid to interpolate the values at the current iteration of the fit (see the sketch after this list).
  3. This is sort of like 2., but more in line with this paper if I understood it correctly: train a network which then does the mapping from k-space to populations. The examples I looked at do this for fixed rate parameters k1, k2, k3, but what I would need is a network which describes the whole space and takes k1, k2, k3 (and time) as input. Is this possible? The idea would be that if I have a trained NN that solves the ODE system, I can load that network and forward-pass my data through it when I fit to find the k1, k2, k3 that best describe my data.
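
For 2., here is a rough sketch of what I have in mind (the table values and grid bounds are placeholders; in practice the table would be precomputed by solving the ODE on the grid):

import tensorflow as tf
import tensorflow_probability as tfp

# Hypothetical precomputed table: s2 at one fixed time, evaluated on a
# regular (k1, k2, k3) grid (placeholder values here).
nk = 20
s2_table = tf.random.uniform((nk, nk, nk), dtype=tf.float64)

k_min = tf.constant([0.1, 1.0, 0.1], dtype=tf.float64)   # assumed grid bounds
k_max = tf.constant([10., 4000., 10.], dtype=tf.float64)

# Query N parameter triplets at once, shape (N, 3); the interpolation is
# differentiable with respect to k_query, so it can sit inside a fit.
k_query = tf.constant([[1., 2000., 3.],
                       [1., 2., 3.]], dtype=tf.float64)

s2_interp = tfp.math.batch_interp_regular_nd_grid(
    x=k_query, x_ref_min=k_min, x_ref_max=k_max, y_ref=s2_table, axis=-3)
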
Jhsmit commented 4 years ago

I forgot to mention that I've tried the different solvers and also changing from GPU to CPU. I found another odeint over at astroNN which, when using the 'rk4' method, seems to perform better.

titu1994 commented 4 years ago

With k2 = 2000., you are presenting a stiff ODE. Solvers like VODE and CVODE (inside scipy) can handle such stiff equations by performing a backward (implicit) solve through the ODE, rather than the forward solve used by the explicit solvers this library offers.

Are you sure RK4 gives the correct solution? Here's the output with rk4 in this library (pass method='rk4' to odeint). It's plenty fast, but the solution is absolutely incorrect.

https://colab.research.google.com/drive/1QLp_A_bY1_19FUzUg0P93p-O3jgeIDJh#scrollTo=WfWigFLE724H

Whereas the actual solution from the Julia DifferentialEquations library (I trust this more than anything else, really) -

using DifferentialEquations
using Plots

k1 = 1.0
k2 = 2.0
k3 = 3.0

function kinetics(du, u, p, t)
    s0, s1, s2 = u

    du[1] = -k1 * s0 + k2 * s1
    du[2] = -k2 * s1 - k3 * s1 + k1 * s0
    du[3] = k3 * s1

end

u0 = [1.0; 0.0; 0.0]
tspan = (0.0, 10.0)

prob = ODEProblem(kinetics, u0, tspan)
sol = solve(prob, reltol=1e-8, abstol=1e-8)

plot(sol)

k1 = 1.0
k2 = 2000.0
k3 = 3.0

function kinetics2(du, u, p, t)
    s0, s1, s2 = u

    du[1] = -k1 * s0 + k2 * s1
    du[2] = -k2 * s1 - k3 * s1 + k1 * s0
    du[3] = k3 * s1

end

u0 = [1.0; 0.0; 0.0]
tspan = (0.0, 10.0)

prob2 = ODEProblem(kinetics2, u0, tspan)
sol2 = solve(prob2, reltol=1e-8, abstol=1e-8)

plot(sol2)

This solves both the stiff and non-stiff versions in under half a second after the first one-time compile cost. So I would suggest going with this if speed and accuracy are your concerns. You do need Julia though, and that's a non-trivial step.

I don't quite understand how you propose to use a matrix K[N, 3] rather than a vector K[3]. Do you mean to do a single forward pass through each Ki? That's an exorbitant cost in terms of the gradients and forward passes required. Think of some way to vectorize this.

If you can explain what this matrix K[N, 3] represents, I can possibly be more help.

Finally, I do recommend learning Julia (if time is not a major constraint; it's a pretty unique language) in order to work with differential equations. There is perhaps no software stack so uniquely suited and tailored to every class of differential equations as the DifferentialEquations suite of libraries available in Julia.

TimFelixBeyer commented 4 years ago

I'm not sure I understand your goal exactly, but here is my best try.

As @titu1994 mentioned, depending on your choice of k1, k2, k3 the system might become stiff. In that case, explicit methods like RK4/5, DOPRI5, etc. don't work well. This tends to happen when the eigenvalue ratio of your system becomes too large. The system can be written in matrix form:

A = [[-k1,     k2, 0],
     [ k1, -k2-k3, 0],
     [  0,     k3, 0]]

then the system looks like: s_dot = A * s

Using k2 = 2000:

A' = [[-1,  2000, 0],
      [ 1, -2003, 0],
      [ 0,     3, 0]]

The eigenvalues of A' are lambda0 = 0, lambda1 = -0.0015, lambda2 = -2004. The eigenvalue ratio is very large: 2004/0.0015 is roughly 1,300,000. This is characteristic of a stiff system.
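
You can check the eigenvalues quickly, e.g. with numpy:

import numpy as np

A = np.array([[-1., 2000., 0.],
              [1., -2003., 0.],
              [0., 3., 0.]])
print(np.linalg.eigvals(A))  # approximately [-2004, -0.0015, 0]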

You can also see this by looking at the number of steps an adaptive solver takes. In MATLAB (k1 = 1, k3 = 3): using k2 = 2, the RK4/5 solver needs 109 timesteps, while the ode23s solver (which is implicit and designed for stiff equations) needs 43 steps. Using k2 = 2000, RK4/5 now uses 24173 (!!) timesteps, while ode23s uses just 27. The solvers in tfdiffeq behave similarly to RK4/5 in MATLAB and therefore perform badly on this problem.
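
The same effect is reproducible in Python with scipy's solve_ivp (a sketch; exact step counts will differ from MATLAB's):

import numpy as np
from scipy.integrate import solve_ivp

k1, k2, k3 = 1., 2000., 3.

def rhs(t, y):
    s0, s1, s2 = y
    return [-k1 * s0 + k2 * s1,
            -k2 * s1 - k3 * s1 + k1 * s0,
            k3 * s1]

y0 = [1., 0., 0.]
explicit = solve_ivp(rhs, (0., 10.), y0, method='RK45')   # explicit, like dopri5
implicit = solve_ivp(rhs, (0., 10.), y0, method='Radau')  # implicit, stiff-friendly
# The explicit solver needs orders of magnitude more steps on the stiff system:
print(len(explicit.t), len(implicit.t))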

May I suggest you look into the area of ODE-curve-fitting: https://www.mathworks.com/help/optim/ug/fit-differential-equation-ode.html (I think this is what you would like to do)

You can also do a form of this using tfdiffeq: e.g.

import tensorflow as tf
import matplotlib.pyplot as plt
from tfdiffeq import odeint

class Kinetics(tf.keras.Model):

    def __init__(self, k1, k2, k3):
        super().__init__()
        self.k1, self.k2, self.k3 = k1, k2, k3

    @tf.function
    def call(self, t, y):
        s0, s1, s2 = y[0], y[1], y[2]

        d_0 = - self.k1 * s0 + self.k2 * s1
        d_1 = - self.k2 * s1 - self.k3 * s1 + self.k1 * s0
        d_2 = self.k3 * s1

        return tf.stack([d_0, d_1, d_2])

with tf.device('/gpu:0'):
    tf.keras.backend.set_floatx('float64')
    NUM_SAMPLES = 100
    t = tf.linspace(0., 10., num=NUM_SAMPLES)
    y_init = tf.constant([1.,0., 0.], dtype=tf.float64)

    # Compute the reference trajectory
    ref_func = Kinetics(1., 2., 3.)
    ref_traj = odeint(ref_func, y_init, t)
    # Set up the model which we want to fit to the reference
    k1 = tf.Variable(1., trainable=True, dtype=tf.float64)
    k2 = tf.Variable(1., trainable=True, dtype=tf.float64)
    k3 = tf.Variable(1., trainable=True, dtype=tf.float64)
    func = Kinetics(k1, k2, k3)
    optimizer = tf.keras.optimizers.Adam(0.1, clipvalue=0.5)

    # Do 100 gradient descent steps
    for i in range(100):
        with tf.GradientTape() as tape:
            result = odeint(func, y_init, t)
            loss = tf.reduce_mean(tf.math.square(ref_traj-result))
        grads = tape.gradient(loss, func.trainable_variables)
        grad_vars = zip(grads, func.trainable_variables)
        optimizer.apply_gradients(grad_vars)
        print(func.k1.numpy(), func.k2.numpy(), func.k3.numpy())
        plt.figure()
        for r in tf.transpose(result):
            plt.plot(r)
        plt.plot(ref_traj)
        plt.savefig('test_plot{:03}.png'.format(i))
        plt.close()

You can watch the guessed parameters approach the correct parameters over time.

In your system, it is actually possible to compute an exact solution to the ODE. Using MATLAB and s(t=0) = [1, 0, 0], I got this solution, writing w = sqrt(k1^2 + 2*k1*k2 - 2*k1*k3 + k2^2 + 2*k2*k3 + k3^2) (equivalently, w = sqrt((k1+k2+k3)^2 - 4*k1*k3)):

s2(t) = 1 - exp(-(k1+k2+k3)*t/2) * (cosh(w*t/2) + ((k1+k2+k3)/w) * sinh(w*t/2))

You could now fit the parameters k1, k2, k3 of this curve to your observation using standard non-linear regression.
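
Here is a NumPy transcription of that closed form (a sketch: I've regrouped it in terms of the two decaying exponentials lam = (-(k1+k2+k3) +/- w)/2, because cosh(w*t/2) overflows in float64 for the stiff parameter set):

import numpy as np

def s2_exact(t, k1, k2, k3):
    S = k1 + k2 + k3
    w = np.sqrt(S**2 - 4. * k1 * k3)  # same radical as in the MATLAB expression
    # The two nonzero eigenvalues of the system matrix; both are <= 0
    lam_p, lam_m = (-S + w) / 2., (-S - w) / 2.
    return (1. - 0.5 * (1. + S / w) * np.exp(lam_p * t)
               - 0.5 * (1. - S / w) * np.exp(lam_m * t))

t = np.linspace(0., 10., 100)
print(s2_exact(t, 1., 2000., 3.)[-1])  # stable even for the stiff parameter set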

Finally, I'm definitely interested in creating a pull request; I'll clean up my code first. At the moment some solvers still throw dtype errors, so I'll try to fix those.

Jhsmit commented 4 years ago

Thanks both for your very detailed replies. I've been busy with other things last couple days and haven't yet had time to look into this in more detail.

@titu1994, I can't open your Google Colab link; it might have expired. I've checked the solution against your Julia implementation: the rk4 solution of the astroNN library is correct, but the rk4 solution of this library is indeed incorrect.

Although Julia might be the best option for differential equations, I would prefer the final package that I'm developing to be python only.

I think indeed what I want is a single forward pass for each Ki. I'm not sure how to vectorize this. I'm planning to put the paper associated with my project on biorxiv soon; when this is done it will be easier to provide an example with some more details on what the K[N, 3] matrix represents.

@KuzMenachem I've indeed also calculated the exact solution in Mathematica, but at the time I decided against this in favor of trying neural ODEs. This was mostly because in the future I might want to extend the ODE system to include more states, and this would be much easier if the numerical solution is in place.

The example you provide is indeed what I'd like to do, but then for N values of k1, k2, k3 in parallel. In this example, though, that would take an exorbitant amount of time.

TimFelixBeyer commented 4 years ago

I believe this might do what you want:

import tensorflow as tf
import matplotlib.pyplot as plt
from tfdiffeq import odeint

class Kinetics(tf.keras.Model):

    def __init__(self, k1, k2, k3):
        super().__init__()
        self.k1, self.k2, self.k3 = k1, k2, k3

    @tf.function
    def call(self, t, y):
        s0, s1, s2 = y[:,0], y[:,1], y[:,2]

        d_0 = - self.k1 * s0 + self.k2 * s1
        d_1 = - self.k2 * s1 - self.k3 * s1 + self.k1 * s0
        d_2 = self.k3 * s1
        return tf.stack([d_0, d_1, d_2], axis=-1)

with tf.device('/gpu:0'):
    tf.keras.backend.set_floatx('float64')
    NUM_SAMPLES = 100
    N = 200
    t = tf.linspace(0., 10., num=NUM_SAMPLES)
    # create a (N, 3) matrix of initial conditions
    y_init = tf.ones(shape=(N, 1), dtype=tf.float64) * tf.constant([1.,0., 0.], dtype=tf.float64)
    reference_k1s = tf.ones(shape=(N), dtype=tf.float64)
    reference_k2s = 2. * tf.ones(shape=(N), dtype=tf.float64) + tf.cast(tf.linspace(0., 1., N), dtype=tf.float64)
    reference_k3s = 3. * tf.ones(shape=(N), dtype=tf.float64)

    # Compute the reference trajectory
    ref_func = Kinetics(reference_k1s, reference_k2s, reference_k3s)
    ref_traj = odeint(ref_func, y_init, t)
    # Set up the model which we want to fit to the reference
    k1 = tf.Variable(tf.ones(shape=(N), dtype=tf.float64), trainable=True, dtype=tf.float64)
    k2 = tf.Variable(tf.ones(shape=(N), dtype=tf.float64), trainable=True, dtype=tf.float64)
    k3 = tf.Variable(tf.ones(shape=(N), dtype=tf.float64), trainable=True, dtype=tf.float64)
    func = Kinetics(k1, k2, k3)
    optimizer = tf.keras.optimizers.Adam(0.1, clipvalue=0.5)

    # Do 100 gradient descent steps
    for i in range(100):
        with tf.GradientTape() as tape:
            result = odeint(func, y_init, t)
            loss = tf.reduce_mean(tf.math.square(ref_traj-result))
        grads = tape.gradient(loss, func.trainable_variables)
        grad_vars = zip(grads, func.trainable_variables)
        optimizer.apply_gradients(grad_vars)
        print(func.k1.numpy(), func.k2.numpy(), func.k3.numpy())
        plt.figure()
        for r in tf.transpose(result[:, 0]):
            plt.plot(r)
        plt.plot(ref_traj[:, 0])
        plt.savefig('test_plot{:03}.png'.format(i))
        plt.close()

Jhsmit commented 4 years ago

Yes! I believe so too. I've just posted over at astroNN (https://github.com/henrysky/astroNN/issues/10), where I concatenated the k's together; I didn't realize odeint was shape-compatible like that. Can't wait to try it out for real, but I've got to finish my homework first.

Thanks!