DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
Apache License 2.0

Argument mismatch and hard-coded `return_all_eval` #195

Open cantabile-kwok opened 1 year ago


There are mismatched arguments in problems.ODEProblem.odeint. My torchdyn version is 1.0.3.

Steps to Reproduce

I want to see how many steps the adaptive dopri5 solver takes, so I looked for the return_all_eval argument mentioned in issue https://github.com/DiffEqML/torchdyn/issues/131. The NeuralODE class does not expose such a keyword argument, so after some digging into the source code I tried passing args={'return_all_eval': True}. However, this still does not give the desired result. The code snippet is:

from torchdyn.core import NeuralODE
import torch
import torch.nn as nn

class VectorField(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(2, 2)

    def forward(self, t, x):
        # Log every evaluation time so the solver's internal steps are visible
        print(f"In VectorField, t is fed as {t}")
        return self.net(t + x)

vf = VectorField()
ode = NeuralODE(vf, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
t_span = torch.linspace(0, 1, 10)
initial = torch.randn(16, 20, 2)  # batch of 16 x 20 two-dimensional states
# Attempt to request all solver evaluation times; the flag does not take effect
eval_time, sol = ode(initial, t_span, args={'return_all_eval': True})

I then found that the return_all_eval keyword is never actually passed into the numerics.odeint.odeint function, whose signature is:

def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, nn.Module], atol:float=1e-3, rtol:float=1e-3,
           t_stops:Union[List, Tensor, None]=None, verbose:bool=False, interpolator:Union[str, Callable, None]=None, return_all_eval:bool=False,
           save_at:Union[List, Tensor]=(), args:Dict={}, seminorm:Tuple[bool, Union[int, None]]=(False, None)) -> Tuple[Tensor, Tensor]:
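The signature above declares return_all_eval and args as two separate parameters, so even if the dict from the reproducer reached this function, it would bind only args and leave the flag at its default. A minimal pure-Python sketch of that binding, using a hypothetical stand-in (odeint_stub) that mirrors the parameter names and defaults shown; it is not torchdyn code and is never actually called:

```python
import inspect


def odeint_stub(f, x, t_span, solver, atol=1e-3, rtol=1e-3,
                t_stops=None, verbose=False, interpolator=None,
                return_all_eval=False, save_at=(), args={},
                seminorm=(False, None)):
    """Hypothetical stand-in mirroring the parameters of numerics.odeint.odeint."""
    raise NotImplementedError  # only the signature matters for this sketch


# Bind (without calling) a request that hides the flag inside `args`.
bound = inspect.signature(odeint_stub).bind(
    None, None, [0.0, 1.0], "dopri5", args={"return_all_eval": True}
)
bound.apply_defaults()  # fill in every parameter that was not passed

print(bound.arguments["return_all_eval"])  # False
print(bound.arguments["args"])             # {'return_all_eval': True}
```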

As you can see, return_all_eval is an explicit parameter of that function, but in numerics.sensitivity._gather_odefunc_adjoint._ODEProblemFunc.forward it is hard-coded as False:

def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
    # The literal False below occupies the return_all_eval slot
    t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B,
                                False, maxiter, fine_steps, save_at)
    ctx.save_for_backward(sol, t_sol)
    return t_sol, sol

So there is currently no way to switch it on short of modifying the source code.
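Because the excerpt passes a literal False where return_all_eval is expected, no caller-side option can change the solver's behavior. A minimal pure-Python sketch of that pattern and of one possible fix, forwarding a value captured from the enclosing scope instead; solve, make_forward_hardcoded and make_forward_threaded are hypothetical names for illustration, not torchdyn code:

```python
def solve(t_span, return_all_eval=False):
    # Hypothetical solver: just reports which output mode it was asked for.
    return "all evaluations" if return_all_eval else "t_span only"


def make_forward_hardcoded(return_all_eval):
    def forward(t_span):
        # The option is captured but ignored: False is passed unconditionally.
        return solve(t_span, False)
    return forward


def make_forward_threaded(return_all_eval):
    def forward(t_span):
        # The captured option is forwarded to the solver.
        return solve(t_span, return_all_eval)
    return forward


print(make_forward_hardcoded(True)([0.0, 1.0]))  # t_span only
print(make_forward_threaded(True)([0.0, 1.0]))   # all evaluations
```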

A second problem is an argument mismatch in numerics.sensitivity._gather_odefunc_adjoint._ODEProblemFunc.forward. When it is called from ODEProblem.odeint (see https://github.com/DiffEqML/torchdyn/blob/a0d0fc537e3da7ecdfb46ee02074fc33028b366a/torchdyn/core/problems.py#L85), the positional arguments do not line up with the signature of that forward function. As a result, the save_at parameter receives a dict, while B (whose purpose I do not understand) receives the real save_at. This has not caused any problems in my code so far, but I don't believe it is the intended behavior, and I suggest someone debug this code path more thoroughly.
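The mismatch is ordinary positional binding. A pure-Python sketch using a stand-in with the forward signature quoted above, assuming (as the report describes) that the caller passes save_at followed by a dict; extra is a hypothetical placeholder for that dict, and the stand-in only echoes what it received. The last call shows one positional fix: filling B explicitly so save_at lands in its own slot.

```python
def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
    """Stand-in with the parameter list quoted in this issue; echoes its inputs."""
    return {"B": B, "save_at": save_at}


ctx, vf_params, x, t_span = None, None, None, [0.0, 1.0]
user_save_at = [0.5]
extra = {"some_option": 1}  # hypothetical dict passed after save_at

# Positional call in the order the report describes: save_at, then a dict.
print(forward(ctx, vf_params, x, t_span, user_save_at, extra))
# {'B': [0.5], 'save_at': {'some_option': 1}}

# Passing B explicitly puts every value in its intended slot.
print(forward(ctx, vf_params, x, t_span, None, user_save_at))
# {'B': None, 'save_at': [0.5]}
```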

Screenshots

There is a traceback that shows the problem.

[Image: traceback screenshot]

Expected behavior

The return_all_eval option should be controllable by the user and should determine whether the ODE solver returns all evaluation time points. Also, the meaning of these arguments and the functionality they provide is barely documented: until I found that GitHub issue, I did not realize there was any way to return all the evaluation timestamps.
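To make the expected behavior concrete, here is a toy adaptive solver for a scalar ODE with a user-facing return_all_eval flag. It is a hedged sketch, not torchdyn's dopri5: it swaps in a Heun-Euler embedded pair for brevity, and the function adaptive_heun and its parameters are hypothetical. With the flag off it returns the state only at the requested t_span points; with it on, it returns every accepted step.

```python
import math
from typing import Callable, List, Tuple


def adaptive_heun(f: Callable[[float, float], float], x0: float, t_span: List[float],
                  atol: float = 1e-4, rtol: float = 1e-4,
                  return_all_eval: bool = False) -> Tuple[List[float], List[float]]:
    """Toy adaptive Heun-Euler solver for dx/dt = f(t, x) on an increasing t_span."""
    t, x = t_span[0], x0
    dt = (t_span[-1] - t_span[0]) / 100  # initial step-size guess
    all_t, all_x = [t], [x]              # every accepted step
    span_t, span_x = [t], [x]            # only the requested t_span points
    for t_next in t_span[1:]:
        while t < t_next:
            hit = dt >= t_next - t       # clip the step so t_next is hit exactly
            h = t_next - t if hit else dt
            k1 = f(t, x)
            k2 = f(t + h, x + h * k1)
            x_high = x + h * (k1 + k2) / 2  # Heun step (order 2)
            x_low = x + h * k1              # Euler step (order 1), for the error estimate
            err = abs(x_high - x_low)
            tol = atol + rtol * max(abs(x), abs(x_high))
            if err <= tol:                  # accept the step
                t = t_next if hit else t + h
                x = x_high
                all_t.append(t)
                all_x.append(x)
            # Step-size update with safety factor 0.9 and exponent 1/2, clamped to [0.2, 2]
            factor = 0.9 * math.sqrt(tol / err) if err > 0 else 2.0
            dt = h * min(2.0, max(0.2, factor))
        span_t.append(t)
        span_x.append(x)
    return (all_t, all_x) if return_all_eval else (span_t, span_x)


t_eval, sol = adaptive_heun(lambda t, x: -x, 1.0, [0.0, 0.5, 1.0], return_all_eval=True)
print(len(t_eval) - 1)  # number of accepted steps
```

With return_all_eval=True, len(t_eval) - 1 is the number of accepted steps, which is the quantity this report is after; the clipping at each t_span point means those points always appear in the full output as well.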