There are mismatched arguments in `problems.ODEProblem.odeint`
My torchdyn version is 1.0.3
Steps to Reproduce
I want to see how many steps the adaptive `dopri5` solver took, so I looked for the `return_all_eval` argument mentioned in issue https://github.com/DiffEqML/torchdyn/issues/131. The `NeuralODE` class does not expose such a keyword argument, so after digging into the source code a bit I decided to pass `args={'return_all_eval': True}`. However, this still does not give the desired result. The code snippet is:
```python
from torchdyn.core import NeuralODE
import torch
import torch.nn as nn


class VectorField(nn.Module):
    def __init__(self):
        super(VectorField, self).__init__()
        self.net = nn.Linear(2, 2)

    def forward(self, t, x):
        print(f"In VectorField, t is fed as {t}")
        return self.net(t + x)


vf = VectorField()
ode = NeuralODE(vf, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
time = torch.linspace(0, 1, 10)
initial = torch.randn(16, 20, 2)
# Attempt to switch on return_all_eval through the args dict
eval_time, sol = ode(initial, time, args={'return_all_eval': True})
print(sol.shape)
```
Then I found that the `return_all_eval` value I set never actually reaches the `numerics.odeint.odeint` function. That function's signature does declare `return_all_eval` as an explicit parameter, but in `numerics.sensitivity._gather_odefunc_adjoint._ODEProblemFunc.forward` it is hard-coded as `False`.
So there is basically no way for me to switch it on without modifying the source code.
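To make the difference concrete, here is a minimal sketch using hypothetical stand-in functions (`odeint_stub`, `forward_hardcoded`, `forward_exposed`; none of these are torchdyn's API). It assumes only what is described above: the solver accepts a `return_all_eval` parameter, and the adjoint `forward` always passes `False` no matter what the caller supplies.

```python
# Hypothetical stand-ins for illustration only; these are NOT torchdyn functions.
def odeint_stub(x, t_span, return_all_eval=False, save_at=(), args=None):
    """Pretend solver that just reports which options reached it."""
    return {"return_all_eval": return_all_eval, "save_at": save_at, "args": args}


def forward_hardcoded(x, t_span, save_at=(), args=None):
    # Behaviour described in this report: the flag is always False,
    # whatever the caller put into `args`.
    return odeint_stub(x, t_span, return_all_eval=False, save_at=save_at, args=args)


def forward_exposed(x, t_span, save_at=(), args=None, return_all_eval=False):
    # One possible fix: accept the flag as an explicit keyword and forward it.
    return odeint_stub(x, t_span, return_all_eval=return_all_eval,
                       save_at=save_at, args=args)


user_args = {"return_all_eval": True}
print(forward_hardcoded("x0", "t_span", args=user_args)["return_all_eval"])  # False
print(forward_exposed("x0", "t_span", return_all_eval=True)["return_all_eval"])  # True
```

An explicit keyword is only one option. Reading the flag from `args` would also work, but `args` seems to be meant for extra inputs to the vector field, so mixing solver options into it could be confusing.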
Another problem is an argument mismatch in the `numerics.sensitivity._gather_odefunc_adjoint._ODEProblemFunc.forward` function. When it is called from `odeint` (see https://github.com/DiffEqML/torchdyn/blob/a0d0fc537e3da7ecdfb46ee02074fc33028b366a/torchdyn/core/problems.py#L85), the arguments passed do not match the signature of that `forward` function. As a result, the `save_at` parameter actually receives a dict, while the `B` parameter (whose purpose I do not understand) receives the real `save_at`. So far this has not caused any problems in my code, but I don't believe it is the intended behavior. I suggest someone debug this part of the code more thoroughly.
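The mismatch can be reproduced in plain Python with a hypothetical stand-in that has the parameter order described above (`vf_params, x, t_span, B, save_at`; the autograd `ctx` argument is omitted). `inspect.signature(...).bind` shows which value each parameter receives when the caller passes `save_at` and then the `args` dict positionally. This is a sketch, not torchdyn's code; `forward_aligned` is a hypothetical signature whose order matches the call.

```python
import inspect


# Hypothetical stand-in with the parameter order described in this report.
def forward(vf_params, x, t_span, B=None, save_at=()):
    return B, save_at


user_save_at = [0.5, 1.0]              # what the caller means as save_at
user_args = {"return_all_eval": True}  # the extra args dict

# Positional call in the order (..., save_at, args), as described above.
bound = inspect.signature(forward).bind("params", "x0", "t_span", user_save_at, user_args)
print(dict(bound.arguments))
# B receives the real save_at, and save_at receives the args dict.


# Hypothetical fix: a signature whose positional order matches the call.
def forward_aligned(vf_params, x, t_span, save_at=(), args=None):
    return save_at, args


print(forward_aligned("params", "x0", "t_span", user_save_at, user_args))
```

As far as I know, `torch.autograd.Function.apply` is called with positional arguments, so the order at the call site and the order in `forward` have to agree.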
Screenshots
The following traceback shows the problem:
![image](https://github.com/DiffEqML/torchdyn/assets/58417810/7dc3ef3e-b15f-4e11-8fb8-e68f16480d7a)
Expected behavior
The `return_all_eval` option should be exposed to the user and should control whether the ODE solver returns the solution at every time point it evaluated.
Also, the meaning of these arguments and the functionality they provide are barely documented. For example, I only realized that all the evaluation time stamps can be returned after finding that GitHub issue.
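For context on what I expect `return_all_eval` to expose, here is a toy sketch in plain Python. It is neither torchdyn's code nor `dopri5`: it uses a much simpler Heun-Euler embedded pair (orders 2 and 1) on a scalar ODE, and the function name `heun_euler_adaptive` and its `return_all_eval` flag are hypothetical. It only illustrates the idea that an adaptive solver chooses its own internal steps, and returning every accepted step is what makes it possible to count them.

```python
import math


def heun_euler_adaptive(f, t0, t1, y0, rtol=1e-4, atol=1e-4, h=0.1, return_all_eval=False):
    """Toy adaptive integrator for a scalar ODE y' = f(t, y) (Heun-Euler pair).

    With return_all_eval=False only the initial and final points are returned;
    with return_all_eval=True every accepted step is returned, so
    len(times) - 1 equals the number of accepted steps.
    """
    t, y = t0, y0
    times, values = [t0], [y0]
    while t < t1:
        h = min(h, t1 - t)                  # never step past the final time
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        y_high = y + h * (k1 + k2) / 2      # Heun, 2nd order
        y_low = y + h * k1                  # explicit Euler, 1st order
        err = abs(y_high - y_low)           # local error estimate
        tol = atol + rtol * max(abs(y), abs(y_high))
        if err <= tol:                      # accept the step
            t = t1 if h == t1 - t else t + h
            y = y_high
            times.append(t)
            values.append(y)
        # Rescale h (safety factor 0.9, change clamped to [0.2, 2.0]);
        # this shrinks rejected steps and enlarges easy ones.
        factor = 0.9 * math.sqrt(tol / err) if err > 0 else 2.0
        h *= min(2.0, max(0.2, factor))
    if return_all_eval:
        return times, values
    return [times[0], times[-1]], [values[0], values[-1]]


def decay(t, y):
    # dy/dt = -y, exact solution exp(-t) for y(0) = 1
    return -y


all_t, all_y = heun_euler_adaptive(decay, 0.0, 1.0, 1.0, return_all_eval=True)
end_t, end_y = heun_euler_adaptive(decay, 0.0, 1.0, 1.0)
print("accepted steps:", len(all_t) - 1)
print("endpoint-only output:", end_t)
print("error at t=1:", abs(all_y[-1] - math.exp(-1.0)))
```

With the flag off, the caller sees only the requested time points and cannot tell how many steps were taken; with it on, the step count falls out of the length of the returned time vector.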