rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

Follow up on issue #129 passing arguments into odeint #178

Open sagastra opened 2 years ago

sagastra commented 2 years ago

Hi,

I have instantiated a class as in issue #129, where I pass in w and update it.

class BNM(nn.Module):

    def __init__(self, w):
        super(BNM, self).__init__()
        self.w = nn.Parameter(w)

    def forward(self, t, x):
        return -x + torch.matmul(x, self.w)

    def solve_ode(self, x0, t, w):
        self.w = w
        return odeint(self, x0, t)

In another class, I declare self.SC as a parameter with requires_grad=True and, in the forward method, call ode_adjoint like this:

def __init__(self, ode_func):
    self.SC = nn.Parameter(torch.randn(10, 10), requires_grad=True)
    self.ode_func = ode_func

def forward(self, x):
    x0 = LSTM(x)
    out = ode_adjoint(self.ode_func, x0, t=torch.linspace(0, 0.72, 10), adjoint_params=self.SC)

PyTorch, however, does not attach self.SC to the computational graph, and the gradients for self.SC are all zero. The gradients for out are fine. Do you have a suggestion on what I might be doing wrong?

Thanks, Amrit

rtqichen commented 2 years ago

Are both examples not giving what you expect? I think both might be problematic:

class BNM(nn.Module):

    def __init__(self, w):
        super(BNM, self).__init__()
        self.w = nn.Parameter(w)

    def forward(self, t, x):
        return -x + torch.matmul(x, self.w)

    def solve_ode(self, x0, t, w):
        self.w = w
        return odeint(self, x0, t)

Once you assign self.w = w, self.w no longer points to the initial parameter nn.Parameter(w); from then on, gradients flow into whatever w you passed in. That may be exactly what you want, but be careful if you expect the initial parameter to receive gradients (e.g. if you've passed it to an optimizer).
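
As a minimal sketch of that pitfall (reusing the BNM class from above with plain odeint; the loss here is just a placeholder): the gradient lands on the tensor that self.w points to at solve time, not on the parameter created in __init__.

import torch
import torch.nn as nn
from torchdiffeq import odeint

class BNM(nn.Module):

    def __init__(self, w):
        super(BNM, self).__init__()
        self.w = nn.Parameter(w)

    def forward(self, t, x):
        return -x + torch.matmul(x, self.w)

    def solve_ode(self, x0, t, w):
        self.w = w                      # replaces the registered parameter
        return odeint(self, x0, t)

func = BNM(torch.randn(3, 3))
w_orig = func.w                          # the nn.Parameter created in __init__
w_new = nn.Parameter(torch.randn(3, 3))  # the tensor actually used in the solve

x0 = torch.randn(1, 3)
t = torch.linspace(0, 1, 5)

loss = func.solve_ode(x0, t, w_new).pow(2).mean()
loss.backward()

print(w_new.grad is not None)   # True: gradients reach the tensor passed in
print(w_orig.grad)              # None: the original parameter is no longer used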

def __init__(self, ode_func):
    self.SC = nn.Parameter(torch.randn(10, 10), requires_grad=True)
    self.ode_func = ode_func

def forward(self, x):
    x0 = LSTM(x)
    out = ode_adjoint(self.ode_func, x0, t=torch.linspace(0, 0.72, 10), adjoint_params=self.SC)

What I'm confused about is whether self.ode_func depends on self.SC. If it doesn't, then out wouldn't depend on self.SC either, and zero gradients would be correct. Based on this __init__, I wouldn't expect ode_func to depend on SC. Note that in the first example the ODE function's forward always references self.w, which is updated every time; but in this second example, the ODE function ode_func doesn't reference self.SC at all.
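
A minimal sketch of the working pattern (the ODEFunc name and shapes are illustrative): let the ODE function own SC and use it in its forward, so that odeint_adjoint picks it up through func.parameters().

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint

class ODEFunc(nn.Module):

    def __init__(self):
        super().__init__()
        self.SC = nn.Parameter(torch.randn(10, 10))  # trainable coupling matrix

    def forward(self, t, x):
        return -x + torch.matmul(x, self.SC)         # SC is actually used here

func = ODEFunc()
x0 = torch.randn(4, 10)
t = torch.linspace(0, 0.72, 10)

# adjoint_params defaults to tuple(func.parameters()), which now contains SC.
out = odeint_adjoint(func, x0, t)
out.pow(2).mean().backward()
print(func.SC.grad.abs().sum())  # non-zero once SC participates in the dynamics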

sagastra commented 2 years ago

Sorry, all of that code belongs to one example (not sure if that was misleading). Yes, self.ode_func depends on self.SC. I basically want to pass a matrix into the ODE solve and have it updated. I used to use the TensorFlow odeint, where you can just pass args with the tensor you want to train on, and I'm trying to achieve the same functionality. Maybe I should remove the self.w = w and build a new ODE function on every forward call with the updated self.SC?

sagastra commented 2 years ago

Hey, still having trouble with this. I am basically trying to get the functionality of y = odeint(f, xi, args=w), where w is a trainable parameter. The value of w gets passed down to f, and the gradients returned via the initial conditions xi work perfectly. The only issue is that w does not seem to be connected to the autograd graph and returns zero gradients. Do you have any suggestions for achieving this functionality? Thanks, Amrit
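
A sketch of one way to emulate that args-style call with torchdiffeq (f here is just a placeholder): close over w in the ODE function and pass it explicitly as adjoint_params; this only produces non-zero gradients if f actually uses w.

import torch
from torchdiffeq import odeint_adjoint

def f(x, w):
    return -x + torch.matmul(x, w)   # placeholder dynamics that use w

w = torch.randn(10, 10, requires_grad=True)   # the "args" tensor to train
x0 = torch.randn(4, 10)
t = torch.linspace(0, 0.72, 10)

# Close over w and tell the adjoint which tensors to differentiate against.
out = odeint_adjoint(lambda t, x: f(x, w), x0, t, adjoint_params=(w,))
out.pow(2).mean().backward()
print(w.grad is not None)  # True: w is now part of the graph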

dgbsg commented 2 years ago

I have the same idea. Have you solved this problem?

sagastra commented 2 years ago

Nah, I gave up and went back to TensorFlow

dgbsg commented 2 years ago

I think the code in your question uses the wrapped odeint and odeint_adjoint functions, but the adjoint sensitivity method in Algorithm 2 of the appendix of "Neural Ordinary Differential Equations" seems to be able to handle this kind of parameter update. I'd like to know your opinion on this.
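
For reference, the parameter gradient computed by the adjoint method in that paper is

$$\frac{dL}{d\theta} = -\int_{t_1}^{t_0} a(t)^\top \frac{\partial f(z(t), t, \theta)}{\partial \theta}\,dt, \qquad a(t) = \frac{\partial L}{\partial z(t)},$$

so odeint_adjoint can return gradients for any tensor θ, provided f actually uses θ and θ is reachable via func.parameters() or passed through adjoint_params.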

bhalazs commented 2 years ago

Hi,

I've found a way to make ODE parameter fitting work. Below is a minimal working example that fits parameter 'p' in the differential equation dx/dt = 1 - p * exp(x). I've used the pytorch-minimize library (https://github.com/rfeinman/pytorch-minimize) to solve the problem with an exact Newton optimizer. For ODE parameter fitting, it is highly recommended to use something faster and more accurate than the (mostly) gradient-descent-based optimizers in torch.optim, unless the number of parameters is really large (e.g. >1000).

The main idea, based on #129, is to store 'x0' and 't' as class attributes, define the ODE in the forward call, and add a class method 'solve_ode' that takes only the parameter as its input. This can then be used to construct an objective function whose only input is the parameters to be optimized. (Strictly speaking, the wrapper function around obj_fn may not be necessary.) It is definitely a workaround though; an "args" argument for passing parameters would indeed be nicer.

If you want an NN to give the value of a parameter, that can be done in a more PyTorch-like way (a very rough sketch of that variant follows the example below); the implementation here is more applicable to fitting physical models. Please let me know if you need a fuller example like that; I have a project I can use to construct a minimal one.

import torch
import torchmin
from torchdiffeq import odeint

class ode_fn(torch.nn.Module):

    def __init__(self, x0, t):

        super(ode_fn, self).__init__()        

        self.p = None
        self.x0 = x0
        self.t = t

    def forward(self, t, x):

        return 1 - self.p * torch.exp(x)

    def solve_ode(self, p):

        self.p = p

        return odeint(self.forward, self.x0, self.t)

#Generate data

x0 = torch.ones(1)
t = torch.linspace(0, 1, 4)
model = ode_fn(x0, t)

with torch.no_grad():
    x_true = model.solve_ode(p = torch.ones(1) * 0.1)

#Optimization

def obj_fn(x_true):
    def obj_fn_(p):
        x_pred = model.solve_ode(p)
        loss = torch.nn.functional.mse_loss(x_pred, x_true)
        return loss
    return obj_fn_

#x0 is the initial guess for 'p'
results = torchmin.minimize(obj_fn(x_true), x0 = torch.ones(1), method="newton-exact", 
                            tol=1e-6, max_iter=50, disp=2)
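
And a very rough sketch of the NN-parameterized variant mentioned above (names such as neural_ode_fn, param_net and feature_dim are purely illustrative; plain odeint is used so that gradients flow back into the network through ordinary autograd):

import torch
import torch.nn as nn
from torchdiffeq import odeint

class neural_ode_fn(nn.Module):

    def __init__(self, x0, t, feature_dim):
        super().__init__()
        self.x0 = x0
        self.t = t
        # Small network that predicts the scalar ODE parameter 'p' from a feature vector.
        self.param_net = nn.Sequential(
            nn.Linear(feature_dim, 16), nn.Tanh(), nn.Linear(16, 1)
        )
        self.p = None

    def ode_rhs(self, t, x):
        # Same model as above: dx/dt = 1 - p * exp(x), with p predicted by the network.
        return 1 - self.p * torch.exp(x)

    def forward(self, features):
        self.p = self.param_net(features)   # p now depends on the network weights
        return odeint(self.ode_rhs, self.x0, self.t)

model = neural_ode_fn(x0=torch.ones(1), t=torch.linspace(0, 1, 4), feature_dim=8)
x_pred = model(torch.randn(8))
loss = x_pred.pow(2).mean()                 # placeholder loss
loss.backward()                             # param_net weights receive non-zero gradients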