rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

Is there a more efficient way of sending additional inputs to func? #94

Closed · ghost closed this issue 4 years ago

ghost commented 4 years ago

Hello,

I was wondering whether there's a more efficient way of sending additional input parameters to ODEfunc in the forward pass than appending the additional parameter to the tuple-state.

The additional input parameter I'm trying to send does not have a "temporal" evolution, but it is required for computing every integration step. Furthermore, the input parameter is the output of an earlier neural network, which needs to be backpropagated through in order for it to be trained.

Issues #23 and #64 mention adding the additional parameter b to the state in tuple form, but the problem with this for me is that the tensor b is orders of magnitude larger than the original state h (b.shape = [batch_size, 25600, 20], h.shape = [batch_size, 100, 20]), so it's very inefficient to include b as part of the state and solve a larger, sparser ODE just to make sure gradients are automatically propagated to b.

Is there a more efficient way to send additional input parameters that require gradients to ODEfunc? For instance, what if I create a new func module initialized with b on every forward pass of the overarching model, like the following?

class Model(nn.Module):
    def forward(self, x):
        b = neural_net(x)
        h_init = Init(x)              # Init = differentiable function that computes the initial state
        func = ode_func(b)            # new ode_func module initialized with input parameter b
        output = odeint(func, h_init, t)
        return output

class ode_func(nn.Module):
    def __init__(self, b):
        super().__init__()
        self.b = b
    def forward(self, t, h):
        return ode_function(t, h, self.b)

Thank you in advance :)

rtqichen commented 4 years ago

What if you dynamically create a nn.Module that takes b as an argument for its initialization? Then feed this into torchdiffeq as the odefunc.

ghost commented 4 years ago

@rtqichen That was exactly what I was thinking, but I was not sure whether the torchdiffeq package supports gradient backpropagation through parameters that aren't given explicitly as part of the state. Would anything in the package prevent this from happening? For example, the forward pass of odeint is wrapped in a with torch.no_grad(): block.

rtqichen commented 4 years ago

(Oops, I realized I didn't actually fully read your issue before replying!)

The adjoint method only provides gradients for odefunc.parameters(), so I think you might have to wrap b inside an nn.Parameter. This cuts b off from the rest of PyTorch's automatic differentiation graph, but a simple workaround is to do your own chain rule using torch.autograd.grad. It should also be possible to do all of this inside a custom autograd.Function, so you can write it once and not have to think about the gradient stitching again.

The non-adjoint odeint, which backpropagates through the solver's internals, should work without any of this, but it sounds like that would use up too much memory in your use case.
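
For reference, a minimal sketch of the nn.Parameter plus torch.autograd.grad stitching described above. The names ODEFunc and neural_net, the nn.Linear dynamics, the shapes, and the loss are illustrative assumptions, not torchdiffeq internals or the original model:

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint

class ODEFunc(nn.Module):
    def __init__(self, b):
        super().__init__()
        # Register a detached copy of b as a parameter so the adjoint pass
        # accumulates d(loss)/d(b) into self.b.grad.
        self.b = nn.Parameter(b.detach().clone())
        self.net = nn.Linear(20, 20)          # placeholder dynamics network

    def forward(self, t, h):
        # Placeholder dynamics that use both the state h and the side input b.
        return self.net(h) + self.b.mean(dim=1, keepdim=True)

neural_net = nn.Linear(20, 20)                # placeholder upstream network producing b
x = torch.randn(4, 25600, 20)
b = neural_net(x)                             # b depends on neural_net's parameters
h_init = torch.randn(4, 100, 20)              # placeholder initial state
t = torch.linspace(0., 1., 10)

func = ODEFunc(b)
h_traj = odeint(func, h_init, t)
loss = h_traj[-1].sum()                       # placeholder loss
loss.backward()                               # adjoint pass fills func.b.grad

# Manual chain rule: push d(loss)/d(b) back through the upstream network.
grads = torch.autograd.grad(b, list(neural_net.parameters()), grad_outputs=func.b.grad)
for p, g in zip(neural_net.parameters(), grads):
    p.grad = g if p.grad is None else p.grad + g

After this, an optimizer step over neural_net.parameters() trains the upstream network as usual; wrapping the whole pattern in a custom autograd.Function would hide the manual stitching.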

ghost commented 4 years ago

@rtqichen thank you! The nn.Parameter trick worked, but the code hasn't seen any noticeable speed boost. The bottleneck is the backward pass, which strangely hasn't sped up after this change :((

rtqichen commented 4 years ago

Ah right. Well, the backward pass puts all parameters into the (augmented) state anyway, so even the memory gained from doing this only shows up in the forward pass.
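
A back-of-the-envelope count with the shapes from the original question illustrates why: the adjoint backward integrates an augmented state holding roughly the state, the adjoint, and one gradient-accumulator entry per parameter element, so a b registered as a parameter still dominates the backward solve. The batch size below is a placeholder:

batch_size = 4
n_state = batch_size * 100 * 20          # elements in h
n_b = batch_size * 25600 * 20            # elements in b, now a parameter
backward_state = 2 * n_state + n_b       # roughly: state + adjoint + d(loss)/d(b) accumulator
print(backward_state / n_state)          # ~258x the forward state size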