What if you dynamically create a `nn.Module` that takes `b` as an argument for its initialization? Then feed this into torchdiffeq as the odefunc.
@rtqichen That was exactly what I was thinking, but I was not sure whether the torchdiffeq package supports gradient backpropagation through parameters that aren't given explicitly as part of the state. Would there be anything in the package that prevents this from happening? For example, the forward pass of `odeint` is wrapped inside a `with torch.no_grad():` block.
(Oops, I realized I didn't actually fully read your issue before replying!)
The adjoint method only provides gradients for `odefunc.parameters()`, so I think you might have to wrap `b` inside a `nn.Parameter`. This might break the PyTorch automatic differentiation framework, but a simple workaround is to do your own chain rule using `torch.autograd.grad`. It should also be possible to do all this inside a custom `autograd.Function`, so you can write it once and not have to think about the gradient stitching anymore.
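Roughly, the stitching could look like this (the encoder, the dynamics, and the shapes below are placeholders for illustration, not part of torchdiffeq):

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint

class ODEFunc(nn.Module):
    def __init__(self, b):
        super().__init__()
        # Register a detached copy of b as a parameter so the adjoint method,
        # which only sees func.parameters(), returns a gradient for it.
        self.b = nn.Parameter(b.detach())
        self.net = nn.Linear(20, 20)      # placeholder dynamics

    def forward(self, t, h):
        # Placeholder coupling between the state h and the conditioning b.
        return self.net(h) + self.b.mean(dim=1, keepdim=True)

encoder = nn.Linear(20, 20)               # placeholder for the earlier network
x = torch.randn(4, 25600, 20)
b = encoder(x)                            # keeps its autograd graph back to encoder

h0 = torch.randn(4, 100, 20)
t = torch.linspace(0., 1., 10)

func = ODEFunc(b)                         # rebuilt on every forward pass
h_traj = odeint_adjoint(func, h0, t)
loss = h_traj[-1].pow(2).mean()
loss.backward()                           # adjoint pass fills func.b.grad

# Manual chain rule: push dL/db back through the encoder by hand.
enc_grads = torch.autograd.grad(
    b, tuple(encoder.parameters()), grad_outputs=func.b.grad)
for p, g in zip(encoder.parameters(), enc_grads):
    p.grad = g if p.grad is None else p.grad + g
```

The `detach()` is what severs the automatic graph between the encoder and the solve; the explicit `torch.autograd.grad` call at the end stitches the gradient back onto the encoder.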
The non-adjoint `odeint` (backprop through the solver) should work without any of this, but it sounds like that would use up too much memory in your use case.
@rtqichen thank you! The `nn.Parameter` trick worked, but the code hasn't seen any noticeable speed boost. The bottleneck is in the backward pass, which strangely hasn't sped up after making this change :((
Ah right. Well the backward pass puts all parameters into the state anyway. So even the memory gained from doing this is only evident in the forward pass.
Hello,
I was wondering whether there's a more efficient way of sending additional input parameters to `ODEfunc` in the forward pass than appending the additional parameter to the tuple state. The additional input parameter I'm trying to send has no "temporal" evolution but is required at every integration step. Furthermore, it is the output of an earlier neural network, which needs to be backpropagated through in order for that network to be trained.
Issues #23 and #64 mention adding the additional parameter `b` to the state in tuple form, but the problem with this for me is that the tensor `b` is orders of magnitude larger than the original state `h` (`b.shape = [batch_size, 25600, 20]`, `h.shape = [batch_size, 100, 20]`), so it's very inefficient to include `b` as part of the state and solve a larger, sparser ODE just to make sure gradients are automatically propagated to `b`.

Is there a more efficient way to send additional input parameters that require gradients to `ODEfunc`? For instance, what if I create a new `func` module initialized with `b` on every forward pass of the overarching model, like the sketch below?

Thank you in advance :)