rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

Support higher order autodiff? #29

Open woct0rdho opened 5 years ago

woct0rdho commented 5 years ago

Thanks a lot for your work! However, it seems that calling backward after grad (i.e. second-order differentiation) is not supported yet. Here is a minimal example:

# dy/dt = a b y, t = 0...1
# y1 = y0 exp(a b)
# dy1/da = b y0 exp(a b)
# dy1/dy0 = exp(a b)

import torch
from torch import nn
from torch.autograd import grad
from torchdiffeq import odeint_adjoint as odeint

class Func(nn.Module):
    def __init__(self):
        super(Func, self).__init__()
        self.a = nn.Parameter(torch.tensor(2.0))
        self.b = nn.Parameter(torch.tensor(3.0))

    def forward(self, t, y):
        return self.a * self.b * y

if __name__ == '__main__':
    func = Func()
    y0 = torch.tensor(4.0, requires_grad=True)
    t = torch.tensor([0.0, 1.0])
    y1 = odeint(func, y0, t)[1]
    print(y1)
    y1.backward(retain_graph=True)

    dy1_da = grad(y1, func.a, create_graph=True)[0]
    print(dy1_da)
    dy1_da.backward(retain_graph=True)

    dy1_dy0 = grad(y1, y0, create_graph=True)[0]
    print(dy1_dy0)
    dy1_dy0.backward(retain_graph=True)

Neither dy1_da nor dy1_dy0 has a grad_fn, so dy1_da.backward and dy1_dy0.backward throw errors. It would be great if you could support these operations; then we could build more complex applications on top of your package.
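
For reference, with a = 2, b = 3, y0 = 4 the closed-form values in the comments above evaluate to:

import math

a, b, y0 = 2.0, 3.0, 4.0
print(y0 * math.exp(a * b))      # y1      = y0 exp(ab)   ≈ 1613.72
print(b * y0 * math.exp(a * b))  # dy1/da  = b y0 exp(ab) ≈ 4841.15
print(math.exp(a * b))           # dy1/dy0 = exp(ab)      ≈  403.43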

rtqichen commented 5 years ago

Yes, this is on the TODO list for now. https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/adjoint.py#L31 I'd need to get a bit finicky with PyTorch's Function.
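
(For context: a custom autograd.Function only supports double backward if its backward pass is itself built from differentiable ops. A toy sketch of the pattern, not torchdiffeq code:)

import torch

class Square(torch.autograd.Function):
    # Toy Function whose backward is differentiable, so grad-of-grad works.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Built from differentiable ops, so with create_graph=True autograd
        # can differentiate through this backward pass as well.
        return 2 * x * grad_out

x = torch.tensor(3.0, requires_grad=True)
y = Square.apply(x)
g, = torch.autograd.grad(y, x, create_graph=True)  # dy/dx = 2x -> 6.0
h, = torch.autograd.grad(g, x)                     # d2y/dx2  ->  2.0
print(g.item(), h.item())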

For now, I think the non-adjoint version should support higher-order autodiff.
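
For example, a minimal sketch reusing Func from the first post (note the non-adjoint version backpropagates through the solver internals and stores the full graph, so it gives up the O(1)-memory property):

import torch
from torch.autograd import grad
from torchdiffeq import odeint  # plain, non-adjoint version

func = Func()  # Func as defined in the first post
y0 = torch.tensor(4.0, requires_grad=True)
t = torch.tensor([0.0, 1.0])

y1 = odeint(func, y0, t)[1]
dy1_da = grad(y1, func.a, create_graph=True)[0]  # now carries a grad_fn
d2y1_da2 = grad(dy1_da, func.a)[0]               # b^2 y0 exp(ab) ≈ 14523.4
print(dy1_da.item(), d2y1_da2.item())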

Sceki commented 2 years ago

Hi @rtqichen, thank you for your work! How can we currently take higher-order gradients of the integration outputs with respect to the inputs (e.g. the initial conditions)?

EyalRozenberg1 commented 1 year ago

Hello @rtqichen. Has there been any progress on higher-order autodiff with the adjoint method?

rtqichen commented 1 year ago

No sorry, zero progress has been made since 2019. If anyone wants to submit a PR for this, I can approve it.

EyalRozenberg1 commented 1 year ago

Thanks for your comment, Ricky. Are there any action items that should be taken? Eyal

wangmiaowei commented 5 months ago

@rtqichen Thanks for your work. However, GPU memory keeps growing when I use import torch.autograd as ag. My code is as follows:

for itr in range(1, 5):
    print('iteration: ', itr)
    print("start_memory_allcoated(MB) {}".format(torch.cuda.memory_allocated() / 1048576))
    optimizer.zero_grad()
    batch_y0, batch_t, batch_y = get_batch()
    batch_y0.requires_grad = True
    print('batch_y0.shape: ', batch_y0.shape)
    pred_y = odeint(func, batch_y0, batch_t).to(device)
    # ... gradient computation with ag.grad elided in the original ...
    print("end_memory_allcoated(MB) {}".format(torch.cuda.memory_allocated() / 1048576))

My results are as follows:

iteration:  1
start_memory_allcoated(MB) 0.0146484375
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.0244140625
iteration:  2
start_memory_allcoated(MB) 0.0244140625
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.025390625
iteration:  3
start_memory_allcoated(MB) 0.025390625
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.0263671875
iteration:  4
start_memory_allcoated(MB) 0.0263671875
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.02734375
iteration:  5
start_memory_allcoated(MB) 0.02734375
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.0283203125
iteration:  6
start_memory_allcoated(MB) 0.0283203125
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.029296875
iteration:  7
start_memory_allcoated(MB) 0.029296875
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.0302734375
iteration:  8
start_memory_allcoated(MB) 0.0302734375
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.03125
iteration:  9
start_memory_allcoated(MB) 0.03125
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.0322265625

And it eventually causes an out-of-memory error! I hope to get your reply as soon as possible.
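
For what it's worth, one common cause of this growth pattern (hypothetical here, since the gradient step is elided above) is storing tensors whose create_graph=True graphs are still attached, which keeps every iteration's graph alive. A self-contained toy that reproduces it on a CUDA device:

import torch

w = torch.randn(1000, 1000, device='cuda', requires_grad=True)
kept = []
for itr in range(1, 6):
    y = (w * w).sum()
    g, = torch.autograd.grad(y, w, create_graph=True)
    gnorm = g.norm()
    kept.append(gnorm)             # grows: gnorm's graph pins g and buffers
    # kept.append(gnorm.detach())  # stays flat: the graph is freed each loop
    print(itr, torch.cuda.memory_allocated() / 1048576, 'MB')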