Closed: nkur closed this issue 2 years ago.
You just need to add q.requires_grad_(True) before computing v_q inside the control function.

Also, it's good practice to put all the parameters inside a single Module. This helps ensure odeint_adjoint backpropagates into the v_nn network. (It's kind of dumb in that it won't track parameters outside of the ODE function, in this case sys.)
This code seems to run:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint
from torchdiffeq import odeint_adjoint

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 100
batch_sz = torch.tensor(batch_size, requires_grad=True, dtype=torch.float32).to(device)


# Defining the NN model
class Vnn(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Vnn, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size[0], bias=False)
        self.softplus = nn.Softplus()
        self.fc2 = nn.Linear(hidden_size[0], hidden_size[1], bias=False)
        self.tanh = nn.Tanh()
        self.fc3 = nn.Linear(hidden_size[1], output_size, bias=False)

        # Weight initialization
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.zeros_(self.fc3.weight)

    def forward(self, x):
        out1 = self.fc1(x)
        out2 = self.softplus(out1)
        out3 = self.fc2(out2)
        out4 = self.tanh(out3)
        out5 = self.fc3(out4)
        out6 = self.tanh(out5)
        return out6


class sysModel(nn.Module):
    def __init__(self, vnn):
        super(sysModel, self).__init__()
        # Register the control network as a submodule so that odeint_adjoint
        # picks up its parameters via sys.parameters().
        self.vnn = vnn

    def control_fn(self, t, q):
        with torch.enable_grad():
            # The adjoint's forward solve runs under torch.no_grad(), so q arrives
            # without requires_grad; enable it locally so autograd.grad can be used.
            q = q.requires_grad_(True)
            v_q = self.vnn(q.unsqueeze(dim=0))
            dv = torch.autograd.grad(v_q, q, create_graph=True, retain_graph=True)[0]
            return -dv

    def forward(self, t, y):
        with torch.enable_grad():
            x, z = torch.split(y, [1, 1], dim=0)
            # print(x, y, z)
            u = self.control_fn(t, x)
            op = torch.cat((-9.82 * torch.abs(u), z), dim=0)
            return op


v_nn = Vnn(input_size=1, hidden_size=[64, 64], output_size=1).to(device)
sys = sysModel(v_nn).to(device)
optimizer = optim.Adam(sys.parameters(), lr=0.001)

u0 = np.random.uniform(-2 * np.pi, 2 * np.pi, (batch_size, 2, 1))
init_states = torch.tensor(u0, requires_grad=True, dtype=torch.float32).to(device)
t_eval = torch.tensor(np.arange(0, 3, 0.01), requires_grad=True, dtype=torch.float32).to(device)

u_opt = 10

for i in range(init_states.shape[0]):
    u_sol = odeint_adjoint(sys, init_states[i, :, :], t_eval, method='dopri5').squeeze(2)
    loss = u_sol.sum() - u_opt
    print(loss)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
I tried that, but the issue now is that backprop doesn't work (the loss stays the same through epochs). It seems that after I set q.requires_grad_(True), it creates a new graph from that point onwards and cannot backprop through the whole trajectory, but I could be wrong.
The difference between odeint and odeint_adjoint seems to be that the states (y in this case) passed to sys are not part of any graph in the second call with the adjoint method. Thus the loss calculated at the end is just a difference between u_opt and the last value of u_sol, which is not related to any of its previous values through any graph.
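(For anyone landing here later: the reason putting everything on one Module matters is that odeint_adjoint can only differentiate with respect to tensors it knows about. Roughly speaking, when func is an nn.Module it uses func.parameters(); anything else has to be handed over explicitly via the adjoint_params keyword, which appears in the traceback below. A minimal sketch, using an illustrative one-layer stand-in for the control network:)

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint

v_nn = nn.Linear(1, 1)                      # illustrative stand-in for the control network

class Sys(nn.Module):                       # hypothetical minimal ODE func
    def __init__(self, vnn):
        super().__init__()
        self.vnn = vnn                      # registered submodule -> included in self.parameters()
    def forward(self, t, y):
        return -self.vnn(y)

sys_fn = Sys(v_nn)
y0 = torch.ones(1)
t = torch.linspace(0., 1., 5)

# Because v_nn is a submodule of the ODE func, its weights are among the
# tensors the adjoint pass computes gradients for.
sol = odeint_adjoint(sys_fn, y0, t, method='dopri5')
sol.sum().backward()
print(v_nn.weight.grad)                     # populated

# Alternative, if the network cannot be made a submodule: pass its parameters
# explicitly so the adjoint still differentiates with respect to them.
# sol = odeint_adjoint(sys_fn, y0, t, adjoint_params=tuple(v_nn.parameters()))

(In this sketch the gradient arrives because Sys.parameters() includes v_nn's weights; a control network kept in a separate, unregistered module would be invisible to the adjoint and its parameters would never receive gradients.)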
The code I posted above is just boilerplate for a larger piece of code, but it reproduces the same error.
Sorry, my explanation is a bit sloppy, but I hope I got the idea across.
Yes, check out my modification to the way sys and control_fn are constructed. The original code won't be able to differentiate through to the parameters of the control function when using odeint_adjoint.
q.requires_grad_(True) won't mess with the gradients. Only detach will modify gradients.
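(A small standalone sketch of that distinction, in plain PyTorch with an illustrative Linear standing in for the control network: requires_grad_ is in-place and just tells autograd to start recording from that tensor, whereas detach returns a new tensor that is cut off from the graph.)

import torch

net = torch.nn.Linear(1, 1)                  # illustrative stand-in for v_nn

# State as it arrives during the adjoint forward solve: produced under
# torch.no_grad(), so it is a plain leaf tensor with no grad_fn.
with torch.no_grad():
    q = torch.randn(1)

with torch.enable_grad():
    q.requires_grad_(True)                   # in-place: autograd starts recording from q
    dv = torch.autograd.grad(net(q), q, create_graph=True)[0]
print(dv.requires_grad)                      # True: dv is still connected to net's weights

q_det = q.detach()                           # detach returns a new, disconnected tensor
print(q_det.requires_grad, q_det.grad_fn)    # False None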
Edit: does everything work when using odeint?
IT WORKS NOW!! Thank you so much.
Odeint was able to track the parameters in the separate function; for the adjoint, putting it all in one Module was the way to go. I feel stupid for having tried everything but this. Thanks again!
No problem. Sorry the documentation around these gotchas is not very clear.
Sorry, but as of now, is there any solution for this when using from torchdiffeq import odeint_adjoint?
Hi! I'm getting this error with the adjoint method, where it does not preserve the grad_fn. The normal odeint works fine, though.
Essentially, what's happening is that the initial conditions are passed with requires_grad=True, and the output preserves the grad_fn in the first call. In the second call, however, the states passed to the sys function no longer have a grad_fn. So when they reach the control function, they neither have a grad_fn nor have requires_grad enabled. Now, I can manually set requires_grad=True for q in control or for y in sys, but I believe that creates a new graph from that point on, rendering backprop useless.
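(For reference, the error itself can be reproduced outside torchdiffeq in a few lines: a tensor produced under torch.no_grad(), which is roughly how the adjoint's forward solve calls the ODE function, neither requires grad nor carries a grad_fn, so torch.autograd.grad refuses to differentiate with respect to it. The Linear net and the names below are only illustrative.)

import torch

net = torch.nn.Linear(1, 1)                # illustrative stand-in for the Vnn network

def control(q):
    # mimics the control function: differentiate the net's output w.r.t. its input q
    v_q = net(q.unsqueeze(0))
    return -torch.autograd.grad(v_q, q, create_graph=True, retain_graph=True)[0]

q = torch.randn(1, requires_grad=True)
print(control(q))                          # fine: q requires grad (the plain odeint case)

with torch.no_grad():                      # roughly what the adjoint forward solve does
    q2 = q * 1.0                           # plain tensor: requires_grad=False, grad_fn=None

try:
    control(q2)
except RuntimeError as e:
    print(e)                               # One of the differentiated Tensors does not require grad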
Here's the code, with the error
Error:
RuntimeError                              Traceback (most recent call last)
Input In [8], in <cell line: 83>()
     81 u_opt = 10
     83 for i in range(init_states.shape[0]):
---> 85     u_sol = odeint_adjoint(sys, init_states[i, :, :].clone().detach().requires_grad_(True),
     86                            t_eval, method='dopri5').squeeze(2).to(device)
     88     loss = u_sol.sum() - u_opt
     90     print(loss)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\adjoint.py:198, in odeint_adjoint(func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method, adjoint_options, adjoint_params)
--> 198 ans = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol,
    199                                 adjoint_method, adjoint_options, t.requires_grad, *adjoint_params)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\adjoint.py:25, in OdeintAdjointMethod.forward(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method, adjoint_options, t_requires_grad, *adjoint_params)
     24 with torch.no_grad():
---> 25     ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\odeint.py:77, in odeint(func, y0, t, rtol, atol, method, options, event_fn)
---> 77 solution = solver.integrate(t)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\solvers.py:28, in AdaptiveStepsizeODESolver.integrate(self, t)
---> 28 self._before_integrate(t)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\rk_common.py:163, in RKAdaptiveStepsizeODESolver._before_integrate(self, t)
    161 f0 = self.func(t[0], self.y0)
--> 163 first_step = _select_initial_step(self.func, t[0], self.y0, self.order - 1, self.rtol, self.atol, self.norm, f0=f0)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\misc.py:62, in _select_initial_step(func, t0, y0, order, rtol, atol, norm, f0)
     61 y1 = y0 + h0 * f0
---> 62 f1 = func(t0 + h0, y1)

File ~\anaconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
-> 1130 return forward_call(*input, **kwargs)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\misc.py:189, in _PerturbFunc.forward(self, t, y, perturb)
--> 189 return self.base_func(t, y)

File ~\anaconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
-> 1130 return forward_call(*input, **kwargs)

File ~\anaconda3\envs\torch\lib\site-packages\torchdiffeq\_impl\misc.py:189, in _PerturbFunc.forward(self, t, y, perturb)
--> 189 return self.base_func(t, y)

File ~\anaconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
-> 1130 return forward_call(*input, **kwargs)

Input In [8], in sysModel.forward(self, t, y)
     62 x, z = torch.split(y, [1, 1], dim=0)
     63 # print(x, y, z)
---> 64 u = control(t, x)
     66 op = torch.cat((-9.82 * torch.abs(u), z), dim=0)
     67 return op

File ~\anaconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
-> 1130 return forward_call(*input, **kwargs)

Input In [8], in controlFun.forward(self, t, q)
     46 with torch.enable_grad():
     47     # q, p = torch.split(y, [1, 1], dim=0)
     48     v_q = self.vnn(q.unsqueeze(dim=0))
---> 50     dv = torch.autograd.grad(v_q, q, create_graph=True, retain_graph=True)[0]
     52 return -dv

File ~\anaconda3\envs\torch\lib\site-packages\torch\autograd\__init__.py:276, in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched)
--> 276 return Variable._execution_engine.run_backward(
    277     t_outputs, grad_outputs, retain_graph, create_graph, t_inputs,
    278     allow_unused, accumulate_grad=False)

RuntimeError: One of the differentiated Tensors does not require grad
Thanks!