rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.

RuntimeError in odeint_adjoint #226

Open chooron opened 1 year ago

chooron commented 1 year ago

Hello, I can run my code successfully with odeint, but when I switch to odeint_adjoint it raises: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn. Here is my code:

class M50_Func(nn.Module):
    def __init__(self, ET_net, Q_net, params, interps, ode_lib='torchdiffeq'):
        super().__init__()
        self.f, self.Smax, self.Qmax, self.Df, self.Tmax, self.Tmin = params
        self.ET_net = ET_net
        self.ET_net.train()
        self.ode_lib = ode_lib
        self.Q_net = Q_net
        self.Q_net.train()
        self.precp_interp, self.temp_interp, self.lday_interp = interps

    def forward(self, t, S):
        from models.common_net import Ps, Pr, M, step_fct
        S_snow, S_water = S[0][0], S[0][1]
        precp = self.precp_interp.evaluate(t).to(torch.float32)
        temp = self.temp_interp.evaluate(t).to(torch.float32)
        lday = self.lday_interp.evaluate(t).to(torch.float32)
        # precp = torch.from_numpy(self.precp_interp(t.numpy()).astype(np.float32)).to(device)
        # temp = torch.from_numpy(self.temp_interp(t.numpy()).astype(np.float32)).to(device)
        # lday = torch.from_numpy(self.lday_interp(t.numpy()).astype(np.float32)).to(device)
        ET_output = self.ET_net(torch.tensor([S_snow, S_water, temp]))
        Q_output = self.Q_net(torch.tensor([S_water, precp]))

        melt_output = M(S_snow, temp, self.Df, self.Tmax)
        dS_1 = Ps(precp, temp, self.Tmin) - melt_output
        dS_2 = Pr(precp, temp, self.Tmin) + melt_output - step_fct(S_water) * lday * torch.exp(
            ET_output) - step_fct(S_water) * torch.exp(Q_output)
        return torch.tensor([dS_1, dS_2]).unsqueeze(0)
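
A likely cause, for anyone hitting the same error: torch.tensor([...]) copies its inputs into a new leaf tensor, so the lines above that build ET_output, Q_output, and the returned derivative all detach from the autograd graph, leaving the adjoint pass with no grad_fn to differentiate through. A minimal sketch of the same forward using torch.stack instead, assuming Ps, Pr, M, and step_fct are themselves differentiable and the shapes line up as in the original:

def forward(self, t, S):
    from models.common_net import Ps, Pr, M, step_fct
    S_snow, S_water = S[0][0], S[0][1]
    precp = self.precp_interp.evaluate(t).to(torch.float32)
    temp = self.temp_interp.evaluate(t).to(torch.float32)
    lday = self.lday_interp.evaluate(t).to(torch.float32)
    # torch.stack keeps S_snow, S_water, temp connected to the graph;
    # torch.tensor([...]) would copy the values and detach them
    ET_output = self.ET_net(torch.stack([S_snow, S_water, temp]))
    Q_output = self.Q_net(torch.stack([S_water, precp]))
    melt_output = M(S_snow, temp, self.Df, self.Tmax)
    dS_1 = Ps(precp, temp, self.Tmin) - melt_output
    dS_2 = (Pr(precp, temp, self.Tmin) + melt_output
            - step_fct(S_water) * lday * torch.exp(ET_output)
            - step_fct(S_water) * torch.exp(Q_output))
    # stack the derivatives too, so the returned tensor has a grad_fn
    return torch.stack([dS_1, dS_2]).unsqueeze(0)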

class M50_Solver(BaseLearner):
    def __init__(self, solve_func: nn.Module, rtol=1e-6, atol=1e-6, ode_lib='torchdiffeq',
                 loss_metric=torch.nn.MSELoss(), eval_metric_list=None, lr=0.01, optimizer=None):
        super().__init__(solve_func, loss_metric, eval_metric_list, lr, optimizer)
        self.solve_func = solve_func
        self.solve_func.train()
        self.ode_lib = ode_lib
        self.rtol = rtol
        self.atol = atol

    def forward(self, x, t_eval):
        if len(x.shape) > 2:
            x = x[0]
        if len(t_eval.shape) > 1:
            t_eval = t_eval[0]
        t_eval = t_eval.to(torch.float32)
        y0 = torch.tensor([[x[0, 0], x[0, 1]]])
        sol = odeint_adjoint(self.solve_func, y0=y0, t=t_eval, rtol=self.rtol, atol=self.atol,
                             adjoint_options={"norm": "seminorm"})
        # adjoint_params=list(self.solve_func.ET_net.parameters())
        #                + list(self.solve_func.Q_net.parameters()))
        # sol = odeint(self.solve_func, y0=y0, t=t_eval, rtol=self.rtol, atol=self.atol)
        sol_1 = sol[:, 0, 1]
        y_hat = torch.exp(self.solve_func.Q_net(torch.concat([sol_1.unsqueeze(1), x[:, 2].unsqueeze(1)], dim=1)))
        return y_hat
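
If the dynamics keep the graph intact, odeint_adjoint should not need adjoint_params at all: by default it differentiates with respect to the parameters of the nn.Module passed as func. If you do want to restrict the adjoint to the two networks, as the commented-out lines attempt, it expects a tuple of tensors; a sketch:

sol = odeint_adjoint(
    self.solve_func, y0=y0, t=t_eval, rtol=self.rtol, atol=self.atol,
    adjoint_options={"norm": "seminorm"},
    # optional: defaults to all parameters of self.solve_func
    adjoint_params=tuple(self.solve_func.ET_net.parameters())
                   + tuple(self.solve_func.Q_net.parameters()),
)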

BaseLearner extends pytorch_lightning.LightningModule.
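
For reference, the detachment alone reproduces the error. In this toy example (all names made up for illustration), the graph-preserving return works, while returning torch.tensor([dy[0], dy[1]]) instead would fail with the same RuntimeError:

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint

class Dynamics(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(2, 2)

    def forward(self, t, y):
        dy = self.net(y)
        # returning torch.tensor([dy[0], dy[1]]) here would copy the values
        # into a leaf tensor with no grad_fn, and backward() below would raise
        # "element 0 of tensors does not require grad ..."
        return dy

func = Dynamics()
y0 = torch.zeros(2)
t = torch.linspace(0., 1., 5)
sol = odeint_adjoint(func, y0, t)
sol.sum().backward()
print(func.net.weight.grad is not None)  # True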

haonanhe commented 1 year ago

I ran into the same problem... Have you solved it?