rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

No grad_fn error for Adjoint method! #209

Closed nkur closed 2 years ago

nkur commented 2 years ago

Hi! I'm getting this error with the adjoint method, where it does not preserve the grad_fn. The plain odeint works fine, though.

Essentially, what's happening is that the initial conditions are passed in with requires_grad=True, and the output preserves the grad_fn on the first call. On the second call, however, the states passed to the sys function no longer have a grad_fn. So by the time they reach the control function, they neither have a grad_fn nor have requires_grad enabled. I can manually set requires_grad=True for q in control or for y in sys, but I believe that starts a new graph from that point on, thus rendering backprop useless.
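Here's a minimal standalone sketch of that behaviour (a throwaway RHS module, not my actual model), which just prints what the solver passes in. As the traceback further down shows, odeint_adjoint runs its forward solve under torch.no_grad(), so the states arrive without a grad_fn:

import torch
from torchdiffeq import odeint, odeint_adjoint

# Throwaway RHS used only to inspect what the solver passes in.
class RHS(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.tensor(1.0))

    def forward(self, t, y):
        print('grad_fn:', y.grad_fn, 'requires_grad:', y.requires_grad)
        return -self.scale * y

rhs = RHS()
y0 = torch.ones(2, requires_grad=True)
t = torch.linspace(0., 1., 5)

odeint(rhs, y0, t)          # after the first internal step, y arrives with a grad_fn
odeint_adjoint(rhs, y0, t)  # y arrives with grad_fn=None and requires_grad=False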

Here's the code, with the error

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from torchdiffeq import odeint
from torchdiffeq import odeint_adjoint

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 100
batch_sz = torch.tensor(batch_size, requires_grad=True, dtype=torch.float32).to(device)

# Defining the NN model
class Vnn(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Vnn, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size[0], bias=False)
        self.softplus = nn.Softplus()
        self.fc2 = nn.Linear(hidden_size[0], hidden_size[1], bias=False)
        self.tanh = nn.Tanh()
        self.fc3 = nn.Linear(hidden_size[1], output_size, bias=False)

        # Weight initialization
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.zeros_(self.fc3.weight) 

    def forward(self, x):
        out1 = self.fc1(x)
        out2 = self.softplus(out1)
        out3 = self.fc2(out2)
        out4 = self.tanh(out3)
        out5 = self.fc3(out4)
        out6 = self.tanh(out5)
        return out6

class controlFun(nn.Module):

    def __init__(self, vnn):
        super(controlFun, self).__init__()
        self.vnn = vnn

    def forward(self, t, q):
        with torch.enable_grad():
            v_q = self.vnn(q.unsqueeze(dim=0))

            dv = torch.autograd.grad(v_q, q, create_graph=True, retain_graph=True)[0]

            return -dv

class sysModel(nn.Module):

    def __init__(self):
        super(sysModel, self).__init__()

    def forward(self, t, y):
        with torch.enable_grad():
            x, z = torch.split(y, [1, 1], dim=0)
#             print(x, y, z)
            u = control(t, x)

            op = torch.cat((-9.82 * torch.abs(u), z), dim=0)
            return op

v_nn = Vnn(input_size=1, hidden_size=[64, 64], output_size=1).to(device)
control = controlFun(v_nn).to(device)
sys = sysModel().to(device)
params = control.parameters()
optimizer = optim.Adam(params, lr=0.001)

u0 = np.random.uniform(-2*np.pi, 2*np.pi, (batch_size, 2, 1))
init_states = torch.tensor(u0, requires_grad=True, dtype=torch.float32).to(device)
t_eval = torch.tensor(np.arange(0, 3, 0.01), requires_grad=True, dtype=torch.float32).to(device)

u_opt = 10

for i in range(init_states.shape[0]):

    u_sol = odeint_adjoint(sys, init_states[i, :, :].clone().detach().requires_grad_(True),
                              t_eval, method='dopri5').squeeze(2).to(device) 

    loss = u_sol.sum() - u_opt

print(loss)
loss.backward()
optimizer.step(); optimizer.zero_grad()

Error:

RuntimeError                              Traceback (most recent call last)
Input In [8], in <cell line: 83>()
     81 u_opt = 10
     83 for i in range(init_states.shape[0]):
---> 85     u_sol = odeint_adjoint(sys, init_states[i, :, :].clone().detach().requires_grad_(True),
     86                               t_eval, method='dopri5').squeeze(2).to(device)
     88     loss = u_sol.sum() - u_opt
     90 print(loss)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\adjoint.py:198, in odeint_adjoint(func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method, adjoint_options, adjoint_params)
    195 state_norm = options["norm"]
    196 handle_adjoint_norm_(adjoint_options, shapes, state_norm)
--> 198 ans = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol,
    199                                 adjoint_method, adjoint_options, t.requires_grad, *adjoint_params)
    201 if event_fn is None:
    202     solution = ans

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\adjoint.py:25, in OdeintAdjointMethod.forward(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method, adjoint_options, t_requires_grad, *adjoint_params)
     22 ctx.event_mode = event_fn is not None
     24 with torch.no_grad():
---> 25     ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)
     27 if event_fn is None:
     28     y = ans

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\odeint.py:77, in odeint(func, y0, t, rtol, atol, method, options, event_fn)
     74 solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options)
     76 if event_fn is None:
---> 77     solution = solver.integrate(t)
     78 else:
     79     event_t, solution = solver.integrate_until_event(t[0], event_fn)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\solvers.py:28, in AdaptiveStepsizeODESolver.integrate(self, t)
     26 solution[0] = self.y0
     27 t = t.to(self.dtype)
---> 28 self._before_integrate(t)
     29 for i in range(1, len(t)):
     30     solution[i] = self._advance(t[i])

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\rk_common.py:163, in RKAdaptiveStepsizeODESolver._before_integrate(self, t)
    161 f0 = self.func(t[0], self.y0)
    162 if self.first_step is None:
--> 163     first_step = _select_initial_step(self.func, t[0], self.y0, self.order - 1, self.rtol, self.atol,
    164                                       self.norm, f0=f0)
    165 else:
    166     first_step = self.first_step

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\misc.py:62, in _select_initial_step(func, t0, y0, order, rtol, atol, norm, f0)
     59 h0 = 0.01 * d0 / d1
     61 y1 = y0 + h0 * f0
---> 62 f1 = func(t0 + h0, y1)
     64 d2 = norm((f1 - f0) / scale) / h0
     66 if d1 <= 1e-15 and d2 <= 1e-15:

File ~\anaconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\misc.py:189, in _PerturbFunc.forward(self, t, y, perturb)
    186 else:
    187     # Do nothing.
    188     pass
--> 189 return self.base_func(t, y)

File ~\anaconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\misc.py:189, in _PerturbFunc.forward(self, t, y, perturb)
    186 else:
    187     # Do nothing.
    188     pass
--> 189 return self.base_func(t, y)

File ~\anaconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

Input In [8], in sysModel.forward(self, t, y)
     62 x, z = torch.split(y, [1, 1], dim=0)
     63 # print(x, y, z)
---> 64 u = control(t, x)
     66 op = torch.cat((-9.82 * torch.abs(u), z), dim=0)
     67 return op

File ~\anaconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

Input In [8], in controlFun.forward(self, t, q)
     46 with torch.enable_grad():
     47     # q, p = torch.split(y, [1, 1], dim=0)
     48     v_q = self.vnn(q.unsqueeze(dim=0))
---> 50     dv = torch.autograd.grad(v_q, q, create_graph=True, retain_graph=True)[0]
     52     return -dv

File ~\anaconda3\envs\torch\lib\site-packages\torch\autograd\__init__.py:276, in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched)
    274     return _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(grad_outputs)
    275 else:
--> 276     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    277         t_outputs, grad_outputs, retain_graph, create_graph, t_inputs,
    278         allow_unused, accumulate_grad=False)

RuntimeError: One of the differentiated Tensors does not require grad


Thanks!

rtqichen commented 2 years ago

You just need to add q.requires_grad_(True) before computing v_q inside the control function.

Also, it's good practice to put all the parameters inside a single Module. This helps ensure odeint_adjoint backpropagates into the v_nn network. (It's kind of dumb in that it won't track parameters outside of the ODE function, in this case sys.)
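(If restructuring is inconvenient, odeint_adjoint also accepts an explicit adjoint_params argument; here's a rough sketch reusing the objects from your original script:)

# Rough sketch: pass the control parameters explicitly so the adjoint tracks them,
# even though they live outside the ODE function `sys`.
u_sol = odeint_adjoint(sys, init_states[i, :, :], t_eval, method='dopri5',
                       adjoint_params=tuple(control.parameters()))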

This code seems to run:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from torchdiffeq import odeint
from torchdiffeq import odeint_adjoint

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 100
batch_sz = torch.tensor(batch_size, requires_grad=True, dtype=torch.float32).to(device)

# Defining the NN model
class Vnn(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Vnn, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size[0], bias=False)
        self.softplus = nn.Softplus()
        self.fc2 = nn.Linear(hidden_size[0], hidden_size[1], bias=False)
        self.tanh = nn.Tanh()
        self.fc3 = nn.Linear(hidden_size[1], output_size, bias=False)

        # Weight initialization
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.zeros_(self.fc3.weight) 

    def forward(self, x):
        out1 = self.fc1(x)
        out2 = self.softplus(out1)
        out3 = self.fc2(out2)
        out4 = self.tanh(out3)
        out5 = self.fc3(out4)
        out6 = self.tanh(out5)
        return out6

class sysModel(nn.Module):

    def __init__(self, vnn):
        super(sysModel, self).__init__()
        self.vnn = vnn

    def control_fn(self, t, q):
        with torch.enable_grad():
            q = q.requires_grad_(True)
            v_q = self.vnn(q.unsqueeze(dim=0))

            dv = torch.autograd.grad(v_q, q, create_graph=True, retain_graph=True)[0]

            return -dv

    def forward(self, t, y):
        with torch.enable_grad():
            x, z = torch.split(y, [1, 1], dim=0)
#             print(x, y, z)
            u = self.control_fn(t, x)

            op = torch.cat((-9.82 * torch.abs(u), z), dim=0)
            return op

v_nn = Vnn(input_size=1, hidden_size=[64, 64], output_size=1).to(device)
sys = sysModel(v_nn).to(device)
optimizer = optim.Adam(sys.parameters(), lr=0.001)

u0 = np.random.uniform(-2*np.pi, 2*np.pi, (batch_size, 2, 1))
init_states = torch.tensor(u0, requires_grad=True, dtype=torch.float32).to(device)
t_eval = torch.tensor(np.arange(0, 3, 0.01), requires_grad=True, dtype=torch.float32).to(device)

u_opt = 10

for i in range(init_states.shape[0]):

    u_sol = odeint_adjoint(sys, init_states[i, :, :], t_eval, method='dopri5').squeeze(2)

    loss = u_sol.sum() - u_opt

print(loss)
optimizer.zero_grad()
loss.backward()
optimizer.step()

nkur commented 2 years ago

I tried that, but the issue now is that backprop doesn't work (the loss stays the same across epochs). It seems that after I set q.requires_grad_(True), it creates a new graph from that point onwards and cannot backprop through the whole trajectory, but I could be wrong.

The difference between odeint and odeint_adjoint seems to be that the states (y in this case) passed to sys are not part of any graph on the second call with the adjoint method. The loss calculated at the end is therefore just the difference between u_opt and the last value of u_sol, which is not connected to any of its previous values through a graph.

The code I posted above is just boilerplate for a larger project, but it reproduces the same error.

Sorry, my explanation is a bit sloppy, but I hope I got the idea across.

rtqichen commented 2 years ago

Yes, check out my modification to the way sys and control_fn are constructed. The original code won't be able to differentiate through to the parameters of the control fn when using odeint_adjoint.

q.requires_grad_(True) won't mess with the gradients. Only detach will modify gradients.
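Here's a tiny standalone illustration (a generic linear layer, nothing from your model):

import torch

lin = torch.nn.Linear(2, 1)

# requires_grad_ just flips a flag on a leaf tensor; gradients still reach the parameters.
q = torch.randn(2)
q.requires_grad_(True)
lin(q).sum().backward()
print(lin.weight.grad is not None)  # True

# detach() is what actually cuts the graph: nothing flows back past the detached copy.
x = torch.randn(2, requires_grad=True)
lin(x.detach()).sum().backward()
print(x.grad)  # None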

Edit: does everything work when using odeint?

nkur commented 2 years ago

IT WORKS NOW!! Thank you so much.

odeint was able to track the parameters in the separate function; for the adjoint, putting everything into one module was the way to go. I feel stupid for having tried everything but this. Thanks again!

rtqichen commented 2 years ago

No problem. Sorry the documentation around these gotchas is not very clear.

wangmiaowei commented 9 months ago

Sorry, as of now is there any solution for this when using from torchdiffeq import odeint_adjoint?