rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

Support of Inverse Problems (e.g. Lorenz System) #190

Open kochlisGit opened 2 years ago

kochlisGit commented 2 years ago

Does torchdiffeq support inverse problem solving?

For example, can it recover the sigma, rho, beta parameters of the Lorenz system? The system, with the known parameters (10, 28, 8/3), can be defined and solved as follows:

import torch
from torch import nn
from torchdiffeq import odeint

class LorenzODE(nn.Module):

    def __init__(self):
        super().__init__()
        # Register sigma, rho, beta as parameters so they can be optimized.
        self.sigma = nn.Parameter(torch.as_tensor([10.0]))
        self.rho = nn.Parameter(torch.as_tensor([28.0]))
        self.beta = nn.Parameter(torch.as_tensor([8.0 / 3.0]))

    def forward(self, t, u):
        x, y, z = u[0], u[1], u[2]
        du1 = self.sigma[0] * (y - x)
        du2 = x * (self.rho[0] - z) - y
        du3 = x * y - self.beta[0] * z
        return torch.stack([du1, du2, du3])

u0 = torch.tensor([1.0, 0.0, 0.0])
t = torch.linspace(0, 100, 1001)
odeint(LorenzODE(), u0, t, rtol=1e-8, atol=1e-8)
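
For concreteness, here is a minimal sketch of the kind of fit I have in mind, reusing the LorenzODE module above: generate a trajectory from the true parameters, perturb them, and recover them by gradient descent through odeint. The choice of Adam, the short time window, and the perturbations are illustrative assumptions on my part, not something torchdiffeq prescribes.

import torch
from torchdiffeq import odeint

true_model = LorenzODE()
u0 = torch.tensor([1.0, 0.0, 0.0])
# Keep the window short: the Lorenz system is chaotic, and long
# trajectories make the loss surface extremely ill-behaved.
t = torch.linspace(0, 2, 201)

with torch.no_grad():
    u_true = odeint(true_model, u0, t, rtol=1e-8, atol=1e-8)

fit_model = LorenzODE()
with torch.no_grad():
    fit_model.sigma += 2.0  # start away from the true values
    fit_model.rho -= 3.0

optimizer = torch.optim.Adam(fit_model.parameters(), lr=0.05)
for step in range(200):
    optimizer.zero_grad()
    u_pred = odeint(fit_model, u0, t, rtol=1e-6, atol=1e-6)
    loss = torch.nn.functional.mse_loss(u_pred, u_true)
    loss.backward()  # gradients flow back to sigma, rho, beta
    optimizer.step()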
Saltsmart commented 2 years ago

You can take a look at the UDE paper (Universal Differential Equations), which discusses a similar system.
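
The UDE idea is to keep the known mechanistic structure and let a small neural network stand in for the unknown parts of the dynamics. A minimal, hypothetical sketch of what that could look like for the Lorenz system (the network size and the choice of which terms are replaced are arbitrary assumptions of mine):

import torch
from torch import nn

class LorenzUDE(nn.Module):

    def __init__(self):
        super().__init__()
        self.sigma = nn.Parameter(torch.tensor([10.0]))  # known mechanistic parameter
        # Small network standing in for the unknown y- and z-dynamics.
        self.net = nn.Sequential(nn.Linear(3, 32), nn.Tanh(), nn.Linear(32, 2))

    def forward(self, t, u):
        x, y, z = u[0], u[1], u[2]
        du1 = self.sigma[0] * (y - x)  # known structure kept as-is
        du2, du3 = self.net(u)         # learned remainder of the dynamics
        return torch.stack([du1, du2, du3])

Such a module can then be trained through odeint just like an ordinary parameter fit.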

bhalazs commented 2 years ago

Hi,

I've found a way to make ODE parameter fitting work. Below is a minimal working example that fits the parameter 'p' in the differential equation dx/dt = 1 - p * exp(x). I've used the pytorch-minimize library (https://github.com/rfeinman/pytorch-minimize) to solve the problem with an exact Newton optimizer. For ODE parameter fitting it is highly recommended to use something faster and more accurate than the (mostly) gradient-descent-based optimizers in torch.optim, unless the number of parameters is very large (e.g. >1000).

The main idea, based on #129, is to store 'x0' and 't' as class attributes, define the ODE in the forward call, and add a class method 'solve_ode' that takes only the parameter as its input. This can then be used to construct an objective function whose only input is the parameters to be optimized. (Strictly speaking, the wrapper function around the inner objective may not be necessary.)

import torch
import torchmin
from torchdiffeq import odeint

class ode_fn(torch.nn.Module):

    def __init__(self, x0, t):
        super().__init__()
        # Store the initial condition and time grid so that solve_ode
        # only needs the parameter as input.
        self.p = None
        self.x0 = x0
        self.t = t

    def forward(self, t, x):
        # dx/dt = 1 - p * exp(x)
        return 1 - self.p * torch.exp(x)

    def solve_ode(self, p):
        self.p = p
        return odeint(self.forward, self.x0, self.t)

# Generate data
x0 = torch.ones(1)
t = torch.linspace(0, 1, 4)
model = ode_fn(x0, t)

with torch.no_grad():
    x_true = model.solve_ode(p=torch.ones(1) * 0.1)

# Optimization
def obj_fn(x_true):
    def obj_fn_(p):
        x_pred = model.solve_ode(p)
        loss = torch.nn.functional.mse_loss(x_pred, x_true)
        return loss
    return obj_fn_

# x0 here is the initial guess for 'p'
results = torchmin.minimize(obj_fn(x_true), x0=torch.ones(1), method="newton-exact",
                            tol=1e-6, max_iter=50, disp=2)
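
For completeness, the same pattern applied to the Lorenz system from the original question might look roughly like the following. The short time window and the initial guess are assumptions of mine; fitting a long chaotic trajectory this way is unlikely to converge.

import torch
import torchmin
from torchdiffeq import odeint

class lorenz_fn(torch.nn.Module):

    def __init__(self, u0, t):
        super().__init__()
        self.p = None  # p = (sigma, rho, beta) as one flat tensor
        self.u0 = u0
        self.t = t

    def forward(self, t, u):
        sigma, rho, beta = self.p[0], self.p[1], self.p[2]
        x, y, z = u[0], u[1], u[2]
        return torch.stack([sigma * (y - x),
                            x * (rho - z) - y,
                            x * y - beta * z])

    def solve_ode(self, p):
        self.p = p
        return odeint(self.forward, self.u0, self.t)

u0 = torch.tensor([1.0, 0.0, 0.0])
t = torch.linspace(0, 2, 51)  # short window to keep the loss well-behaved
model = lorenz_fn(u0, t)

with torch.no_grad():
    u_true = model.solve_ode(torch.tensor([10.0, 28.0, 8.0 / 3.0]))

def obj_fn(p):
    return torch.nn.functional.mse_loss(model.solve_ode(p), u_true)

# Start near (but not at) the true parameters.
results = torchmin.minimize(obj_fn, x0=torch.tensor([9.0, 26.0, 3.0]),
                            method="newton-exact", tol=1e-6, max_iter=50, disp=2)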