rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.

No grad_fn error #151

Open hellomynameisjiji opened 3 years ago

hellomynameisjiji commented 3 years ago

Hi! I'm currently working on an experiment with ffjord, and the Issues page of the ffjord repository seems to be unavailable at the moment, so I'm posting my issue here.

The following error keeps popping up at random while my team and I train our models: RuntimeError: element 1 of tensors does not require grad and does not have a grad_fn.

It apparently involves the ODE solver (in my case, dopri5), but I have no idea how to fix or control it because it occurs so randomly. My guess is that there is some GPU conflict when I run multiple models simultaneously.

Thanks!

rtqichen commented 3 years ago

Sorry, without code that reproduces the error, I don't know where it could come from. Make sure you've wrapped anything that needs to be differentiated through inside a torch.enable_grad() block.
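
For context, here is a minimal sketch of that pattern (illustrative only; PotentialDynamics and its layers are made-up names, not code from this issue). The gradient of a learned scalar is taken with respect to the state inside the dynamics function, with the state re-attached to the graph under torch.enable_grad() in case the solver calls the function while gradients are disabled:

import torch
import torch.nn as nn
from torchdiffeq import odeint

class PotentialDynamics(nn.Module):
    # Toy dynamics dy/dt = -dV/dy for a small learned potential V(y).
    def __init__(self):
        super().__init__()
        self.potential = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 1))

    def forward(self, t, y):
        with torch.enable_grad():
            if not y.requires_grad:
                # The state may arrive detached (e.g. when the function is
                # evaluated under torch.no_grad()); re-attach it locally so
                # torch.autograd.grad has something to differentiate.
                y = y.detach().requires_grad_(True)
            v = self.potential(y).sum()
            dv = torch.autograd.grad(v, y, create_graph=True)[0]
        return -dv

func = PotentialDynamics()
y0 = torch.randn(2)
t = torch.linspace(0.0, 1.0, 10)
ys = odeint(func, y0, t, method='dopri5')
ys.sum().backward()  # gradients reach the parameters of func.potential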

nkur commented 2 years ago

Hey! I'm getting the same error for the adjoint method. The normal odeint works fine though.

Essentially, what's happening is that the initial conditions are passed in with requires_grad=True, and the output preserves the grad_fn on the first call. On the second call, however, the states passed to the sys function no longer carry a grad_fn, so by the time they reach the control function they have neither a grad_fn nor requires_grad enabled. I can manually set requires_grad=True for q in control or for y in sys, but I believe that creates a new graph from that point on, making the backprop useless (see the sketch after the traceback below for one possible workaround).

Here's the code, with the error

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from torchdiffeq import odeint
from torchdiffeq import odeint_adjoint

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 100
batch_sz = torch.tensor(batch_size, requires_grad=True, dtype=torch.float32).to(device)

# Defining the NN model
class Vnn(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Vnn, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size[0], bias=False)
        self.softplus = nn.Softplus()
        self.fc2 = nn.Linear(hidden_size[0], hidden_size[1], bias=False)
        self.tanh = nn.Tanh()
        self.fc3 = nn.Linear(hidden_size[1], output_size, bias=False)

        # Weight initialization
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.zeros_(self.fc3.weight) 

    def forward(self, x):
        out1 = self.fc1(x)
        out2 = self.softplus(out1)
        out3 = self.fc2(out2)
        out4 = self.tanh(out3)
        out5 = self.fc3(out4)
        out6 = self.tanh(out5)
        return out6

class controlFun(nn.Module):

    def __init__(self, vnn):
        super(controlFun, self).__init__()
        self.vnn = vnn

    def forward(self, t, q):
        with torch.enable_grad():
            v_q = self.vnn(q.unsqueeze(dim=0))

            dv = torch.autograd.grad(v_q, q, create_graph=True, retain_graph=True)[0]

            return -dv

class sysModel(nn.Module):

    def __init__(self):
        super(sysModel, self).__init__()

    def forward(self, t, y):
        with torch.enable_grad():
            x, z = torch.split(y, [1, 1], dim=0)
#             print(x, y, z)
            u = control(t, x)

            op = torch.cat((-9.82 * torch.abs(u), z), dim=0)
            return op

v_nn = Vnn(input_size=1, hidden_size=[64, 64], output_size=1).to(device)
control = controlFun(v_nn).to(device)
sys = sysModel().to(device)
params = control.parameters()
optimizer = optim.Adam(params, lr=0.001)

u0 = np.random.uniform(-2*np.pi, 2*np.pi, (batch_size, 2, 1))
init_states = torch.tensor(u0, requires_grad=True, dtype=torch.float32).to(device)
t_eval = torch.tensor(np.arange(0, 3, 0.01), requires_grad=True, dtype=torch.float32).to(device)

u_opt = 10

for i in range(init_states.shape[0]):

    u_sol = odeint_adjoint(sys, init_states[i, :, :].clone().detach().requires_grad_(True),
                              t_eval, method='dopri5').squeeze(2).to(device) 

    loss = u_sol.sum() - u_opt

print(loss)
loss.backward()
optimizer.step(); optimizer.zero_grad()

Error:

RuntimeError                              Traceback (most recent call last)
Input In [8], in <cell line: 83>()
     81 u_opt = 10
     83 for i in range(init_states.shape[0]):
---> 85     u_sol = odeint_adjoint(sys, init_states[i, :, :].clone().detach().requires_grad_(True),
     86                               t_eval, method='dopri5').squeeze(2).to(device)
     88     loss = u_sol.sum() - u_opt
     90 print(loss)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\adjoint.py:198, in odeint_adjoint(func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method, adjoint_options, adjoint_params)
--> 198 ans = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol,
    199                                 adjoint_method, adjoint_options, t.requires_grad, *adjoint_params)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\adjoint.py:25, in OdeintAdjointMethod.forward(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method, adjoint_options, t_requires_grad, *adjoint_params)
     24 with torch.no_grad():
---> 25     ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\odeint.py:77, in odeint(func, y0, t, rtol, atol, method, options, event_fn)
     76 if event_fn is None:
---> 77     solution = solver.integrate(t)
     78 else:
     79     event_t, solution = solver.integrate_until_event(t[0], event_fn)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\solvers.py:28, in AdaptiveStepsizeODESolver.integrate(self, t)
     27 t = t.to(self.dtype)
---> 28 self._before_integrate(t)
     29 for i in range(1, len(t)):
     30     solution[i] = self._advance(t[i])

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\rk_common.py:163, in RKAdaptiveStepsizeODESolver._before_integrate(self, t)
    161 f0 = self.func(t[0], self.y0)
    162 if self.first_step is None:
--> 163     first_step = _select_initial_step(self.func, t[0], self.y0, self.order - 1, self.rtol, self.atol,
    164                                       self.norm, f0=f0)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\misc.py:62, in _select_initial_step(func, t0, y0, order, rtol, atol, norm, f0)
     61 y1 = y0 + h0 * f0
---> 62 f1 = func(t0 + h0, y1)
     64 d2 = norm((f1 - f0) / scale) / h0

File ~\anaconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
-> 1130 return forward_call(*input, **kwargs)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\misc.py:189, in _PerturbFunc.forward(self, t, y, perturb)
--> 189 return self.base_func(t, y)

File ~\anaconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
-> 1130 return forward_call(*input, **kwargs)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\misc.py:189, in _PerturbFunc.forward(self, t, y, perturb)
--> 189 return self.base_func(t, y)

File ~\anaconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
-> 1130 return forward_call(*input, **kwargs)

Input In [8], in sysModel.forward(self, t, y)
     62 x, z = torch.split(y, [1, 1], dim=0)
---> 64 u = control(t, x)
     66 op = torch.cat((-9.82 * torch.abs(u), z), dim=0)
     67 return op

File ~\anaconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
-> 1130 return forward_call(*input, **kwargs)

Input In [8], in controlFun.forward(self, t, q)
     46 with torch.enable_grad():
     48     v_q = self.vnn(q.unsqueeze(dim=0))
---> 50     dv = torch.autograd.grad(v_q, q, create_graph=True, retain_graph=True)[0]
     52     return -dv

File ~\anaconda3\envs\torch\lib\site-packages\torch\autograd\__init__.py:276, in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched)
--> 276 return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    277     t_outputs, grad_outputs_, retain_graph, create_graph, t_inputs,
    278     allow_unused, accumulate_grad=False)

RuntimeError: One of the differentiated Tensors does not require grad
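
For anyone hitting this with odeint_adjoint: the adjoint forward pass evaluates the dynamics under torch.no_grad() (see the OdeintAdjointMethod.forward frame above), so the state reaching controlFun has no grad_fn and requires_grad=False, and torch.autograd.grad then fails. Below is a sketch of one possible workaround applied to the classes from the post above; it is untested here, so treat it as a starting point rather than a confirmed fix. The state is re-attached locally, and the control is registered as a submodule of sysModel so that odeint_adjoint's default adjoint parameters (taken from sys.parameters()) include the control network's weights:

import torch
import torch.nn as nn

class controlFun(nn.Module):
    def __init__(self, vnn):
        super(controlFun, self).__init__()
        self.vnn = vnn

    def forward(self, t, q):
        with torch.enable_grad():
            if not q.requires_grad:
                # The adjoint forward solve calls this under torch.no_grad(),
                # so q arrives detached; re-attach it locally.
                q = q.detach().requires_grad_(True)
            v_q = self.vnn(q.unsqueeze(dim=0))
            dv = torch.autograd.grad(v_q, q, create_graph=True, retain_graph=True)[0]
            return -dv

class sysModel(nn.Module):
    def __init__(self, control):
        super(sysModel, self).__init__()
        # Registering the control as a submodule lets odeint_adjoint collect
        # its parameters through sys.parameters().
        self.control = control

    def forward(self, t, y):
        x, z = torch.split(y, [1, 1], dim=0)
        u = self.control(t, x)
        return torch.cat((-9.82 * torch.abs(u), z), dim=0)

With this change, sys would be constructed as sysModel(control). Alternatively, keep sysModel as posted and pass adjoint_params=tuple(control.parameters()) to odeint_adjoint explicitly. Detaching q during the forward solve should not by itself break training here, because with the adjoint method the gradients with respect to the initial state and the parameters come from the adjoint ODE solved in the backward pass, which re-evaluates the dynamics with gradients enabled.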

ali-pr1 commented 2 months ago

I'm having the same problem with the adjoint method. Has anyone found a solution for it?