rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

How to pass extra parameters of func to odeint? #246

Open shifttttttt opened 10 months ago

shifttttttt commented 10 months ago

I looked at the definition of odeint in torchdiffeq, but I could not find a parameter for passing extra arguments, like the args parameter in scipy.integrate.odeint. Is there any way to pass parameters to odeint other than defining a global variable?
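(For context, a minimal sketch of the scipy.integrate.odeint args pattern referenced here, with a toy right-hand side and constants that are not from this thread:)

import numpy as np
from scipy.integrate import odeint

def rhs(y, t, a, b):
    # a and b are extra constants forwarded through the args tuple
    return -a * y + b

y0 = 1.0
t = np.linspace(0., 1., 10)
ys = odeint(rhs, y0, t, args=(2.0, 0.5))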

rtqichen commented 8 months ago

Yeah, just define them anywhere. If you want to use odeint_adjoint, it's good practice to define them as part of the module.


import torch
import torch.nn as nn

global_params = ...

class ODEfunc(nn.Module):

  def __init__(self):
    super().__init__()
    # Note: don't assign to self.parameters -- that would shadow
    # nn.Module.parameters(). Use another attribute name instead.
    self.p = nn.Parameter(some_tensor_we_want_to_optimize_ie_compute_gradients_for)

  def forward(self, t, x):
    p = self.p
    external_p = global_params
    # some ops regarding t, x, p, external_p
    return ...

If you use odeint, gradients will be computed w.r.t. external_p as well, but odeint_adjoint will only compute them for p.
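(A minimal, self-contained sketch of one way to observe this difference; the toy dynamics and shapes are hypothetical, while odeint and odeint_adjoint are the actual torchdiffeq entry points:)

import torch
import torch.nn as nn
from torchdiffeq import odeint, odeint_adjoint

global_params = torch.randn(2, requires_grad=True)

class ToyODEfunc(nn.Module):
    def __init__(self):
        super().__init__()
        self.p = nn.Parameter(torch.randn(2))

    def forward(self, t, x):
        # uses both a registered parameter and an external tensor
        return self.p * x + global_params

func = ToyODEfunc()
y0 = torch.zeros(2)
t = torch.linspace(0., 1., 10)

# Plain odeint: autograd traces every op, so external_p gets a gradient.
odeint(func, y0, t).sum().backward()
print(global_params.grad is not None)  # True

# Adjoint: gradients are reconstructed only for func's registered parameters.
func.zero_grad(); global_params.grad = None
odeint_adjoint(func, y0, t).sum().backward()
print(func.p.grad is not None)         # True
print(global_params.grad is None)      # True: the adjoint never saw it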

shifttttttt commented 8 months ago

Thanks for your answer!

HoangMinhPhan commented 2 months ago

> If you use odeint, gradients will be computed w.r.t. external_p as well, but odeint_adjoint will only compute them for p.

Hi, can you elaborate a bit on why the gradient (in the case of using odeint) is computed w.r.t. external_p if we define external_p from global_params? I thought that, as long as we set requires_grad=False, the parameter would not be involved in the gradient computation (here, at loss.backward()). Thank you very much.
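(For reference, the requires_grad behavior described in this question can be checked in isolation; a minimal sketch with toy tensors, not from this thread:)

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3)  # requires_grad defaults to False

loss = (a * b).sum()
loss.backward()

print(a.grad)  # populated: a is a leaf tensor tracked by autograd
print(b.grad)  # None: b is treated as a constant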