shifttttttt opened 10 months ago
Yeah, just define it anywhere. In order to use `odeint_adjoint`, it's good practice to define them as part of the module:

```python
import torch.nn as nn

global_params = ...  # some tensor defined outside the module

class ODEfunc(nn.Module):
    def __init__(self):
        super().__init__()
        # the tensor we want to optimize, i.e. compute gradients for
        self.p = nn.Parameter(some_tensor_we_want_to_optimize_ie_compute_gradients_for)

    def forward(self, t, x):
        p = self.p
        external_p = global_params
        # some ops involving t, x, p, external_p
        return ...
```

If you use `odeint`, the gradient will be computed w.r.t. `external_p`, but `odeint_adjoint` will only do it for `p`.
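For concreteness, here is a minimal runnable sketch of that difference (the dynamics, shapes, and tensor names are made up for illustration):

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint, odeint_adjoint

# Hypothetical external tensor that requires grad.
global_params = torch.randn(2, requires_grad=True)

class ODEfunc(nn.Module):
    def __init__(self):
        super().__init__()
        self.p = nn.Parameter(torch.randn(2))

    def forward(self, t, x):
        return self.p * x + global_params

func = ODEfunc()
x0 = torch.randn(2)
t = torch.linspace(0.0, 1.0, 10)

# Plain odeint backpropagates through the solver's internal ops,
# so every requires_grad tensor used in forward() gets a gradient.
odeint(func, x0, t).sum().backward()
print(func.p.grad is not None)         # True
print(global_params.grad is not None)  # True

# odeint_adjoint solves the adjoint ODE backwards and accumulates
# gradients only for func.parameters(), so global_params is skipped.
func.zero_grad()
global_params.grad = None
odeint_adjoint(func, x0, t).sum().backward()
print(func.p.grad is not None)         # True
print(global_params.grad)              # None
```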
Thanks for your answer!
Hi, can you elaborate a bit on why the gradient (in the case of using `odeint`) will be computed w.r.t. `external_p` if we define `external_p` from `global_params`? I thought that, as long as we set `requires_grad=False`, the parameter would not be involved in the gradient calculation (here, at `loss.backward()`).
Thank you very much!
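As a quick check of that premise (a made-up minimal example, independent of torchdiffeq):

```python
import torch

external_p = torch.randn(2, requires_grad=False)
x = torch.randn(2, requires_grad=True)

(x * external_p).sum().backward()

print(x.grad)           # populated
print(external_p.grad)  # None: a tensor with requires_grad=False never receives gradients
```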
I looked at the definition of `odeint` in torchdiffeq, but I do not find a parameter for passing extra parameters like the `args` parameter in `scipy.integrate.odeint`. Is there any other way to pass parameters to `odeint` besides defining a global variable?
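As far as I can tell there is no `args=` equivalent; two common workarounds (a sketch with made-up dynamics, not an official torchdiffeq feature) are to store the extra arguments on the module, or to close over them:

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint

# Option 1: keep the extra arguments as plain attributes of the module.
class ODEfunc(nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.a = a  # stored state, not an nn.Parameter, so no gradients
        self.b = b

    def forward(self, t, x):
        return self.a * x + self.b

x0 = torch.randn(3)
t = torch.linspace(0.0, 1.0, 5)
sol = odeint(ODEfunc(a=2.0, b=0.5), x0, t)

# Option 2: close over the arguments with a lambda, mirroring scipy's args=.
a, b = 2.0, 0.5
sol = odeint(lambda t, x: a * x + b, x0, t)
```

The lambda form works with plain `odeint`, but not, as far as I know, with `odeint_adjoint`, which wants an `nn.Module` so it can collect the parameters to differentiate.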