DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

Passing additional arguments to odeint #118

Closed: MaxH1996 closed this issue 2 years ago

MaxH1996 commented 2 years ago

Hi, I am currently learning about Neural ODEs, and to this end I have been experimenting with torchdyn. One problem I have encountered is passing multiple arguments to the forward of the neural network. For example, I have a neural net called Net that I have instantiated like so:

func = Net()

Now I have the callable, but say the forward pass of Net takes four arguments:


def forward(self, t, x, u, v):
    .....my_code.....
    return output

How can I now pass this func to odeint? In other words, can I pass additional arguments to odeint the way you can with SciPy's solve_ivp (e.g. via its args keyword)? I have tried wrapping the forward with functools.partial (see the sketch below), but then I cannot use the adjoint method.
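
For reference, here is a minimal sketch of what I tried; the Net below is just a placeholder single linear layer standing in for my actual network:

import torch
import torch.nn as nn
from functools import partial
from torchdyn.numerics import odeint

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 1)
    def forward(self, t, x, u, v):
        # the extra inputs u and v enter the vector field alongside the state x
        return self.lin(x + u + v)

func = Net()
u = v = torch.randn(1, 1)
f = partial(func.forward, u=u, v=v)  # fix the extra arguments

x0 = torch.randn(1, 1)
t_eval, sol = odeint(f, x0, torch.linspace(0, 1, 10), solver='euler')
# plain integration works, but wrapping the module this way hides its
# parameters, so the adjoint-based sensitivities cannot reach them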

Any help would be appreciated :)

Zymrael commented 2 years ago

I've modified ODEProblem and NeuralODE with a new optimizable_params argument that can be used to pass parameters explicitly. That way you'll be able to use functools.partial even with NeuralODE and ODEProblem.

Could you check whether the following works with the latest version of torchdyn installed from source (not from pip)? If so, we'll include these changes in the next pip release.

import torch
import torch.nn as nn
from torch.autograd import grad

import torchdyn
from torchdyn.numerics import odeint
from torchdyn.core import ODEProblem, NeuralODE
from functools import partial

l = nn.Linear(1, 1)

class TFunc(nn.Module):
    def __init__(self, l):
        super().__init__()
        self.l = l
    def forward(self, t, x, u, v, z):
        return self.l(x + u + v + z)

tfunc = TFunc(l)

u = v = z = torch.randn(1, 1)
f = partial(tfunc.forward, u=u, v=v, z=z)  # fix the extra arguments; f(t, x) now matches the signature the solvers expect
x0 = torch.randn(1, 1, requires_grad=True)

# functional (`odeint`) and `NeuralODE` integration of the same wrapped vector field

t_eval, sol1 = odeint(f, x0, torch.linspace(0, 5, 10), solver='euler')

# `optimizable_params` registers the parameters hidden by `partial`, so the adjoint can differentiate through them
odeprob = NeuralODE(f, 'euler', sensitivity='interpolated_adjoint', optimizable_params=tfunc.parameters())
t_eval, sol2 = odeprob(x0, t_span=torch.linspace(0, 5, 10))

# should be the same, and the second now can be used to backprop to `tfunc.parameters()`
(sol1==sol2).all()
grad(sol2.sum(), x0)
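
As a follow-up sketch (not part of the snippet above), the registered parameters can then be trained in the usual way; the loss below is purely illustrative:

opt = torch.optim.Adam(tfunc.parameters(), lr=1e-2)

t_eval, sol = odeprob(x0, t_span=torch.linspace(0, 5, 10))
loss = sol.pow(2).mean()  # illustrative loss on the integrated trajectory
opt.zero_grad()
loss.backward()           # gradients reach tfunc.parameters() through the adjoint
opt.step()
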
MaxH1996 commented 2 years ago

Thanks for the quick response! Yes, this works for me. Just to make sure I am understanding everything correctly: sol2 can now be used for backprop, and the network I use can be more or less arbitrary, correct? That is, TFunc could for example be a GCN with multiple inputs in its forward?

Zymrael commented 2 years ago

Correct! Thanks for using torchdyn, closing the issue for now.