DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

Number of steps of adaptive solver #131

Open shekshaa opened 2 years ago

shekshaa commented 2 years ago

Hi, first of all, thank you for this nice library. I really enjoyed using it. I was wondering how I can find out the number of steps taken by an adaptive solver during each forward pass of an ODE. At first I thought I should monitor the vf.nfe attribute of the NeuralODE class, but its value increases monotonically, so I assume it keeps the total number of function evaluations.

Zymrael commented 2 years ago

Hi @shekshaa! Normally, odeint calls return t_eval, sol, where t_eval has the same length as your desired evaluation points t_span. However, you can do the following:

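import torch
from torchdyn.numerics import odeint

f = lambda t, x: torch.sin(t)      # vector field dx/dt = sin(t)
t_span = torch.linspace(0, 1, 10)  # 10 requested output times in [0, 1]
x = torch.ones(1, 1)               # initial condition
# return_all_eval=True: also return every evaluation point the solver needed
t_eval, sol = odeint(f, x, t_span, solver='dopri5', return_all_eval=True)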

The return_all_eval flag forces odeint to return all the evaluation points, and the corresponding x(t), that it needed to produce the solution. With the above, len(t_eval) != len(t_span), which is exactly what we want!
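To make the difference between requested points and solver steps concrete without depending on torch or torchdyn, here is a minimal pure-Python sketch. It is not torchdyn's dopri5: it swaps in a simple embedded Heun-Euler pair on a scalar ODE, and its return_all_eval argument is a hypothetical stand-in that only mimics the flag described above. The step is clipped so every t_span point is hit exactly; with return_all_eval=True every accepted step is returned, so the number of steps is len(t_eval) - 1.

```python
import math


def heun_euler_adaptive(f, x0, t_span, rtol=1e-6, atol=1e-6, return_all_eval=False):
    """Integrate dx/dt = f(t, x) for a scalar x with an embedded Heun-Euler pair.

    The step is clipped so that every point of t_span is hit exactly.
    Returns (t_eval, sol): only the t_span points by default, or every
    accepted step when return_all_eval=True.
    """
    t, x = t_span[0], x0
    t_all, x_all = [t], [x]  # every accepted step (plus the initial point)
    t_out, x_out = [t], [x]  # only the requested output points
    h = (t_span[-1] - t_span[0]) / 100.0  # initial step size: an arbitrary guess
    for t_target in t_span[1:]:
        while t < t_target:
            hit = h >= t_target - t
            if hit:
                h = t_target - t  # shorten the step to land on the output point
            k1 = f(t, x)
            k2 = f(t + h, x + h * k1)
            x_high = x + h * (k1 + k2) / 2.0  # Heun step, order 2 (propagated)
            x_low = x + h * k1  # Euler step, order 1 (only for the error estimate)
            err = abs(x_high - x_low)
            tol = atol + rtol * max(abs(x), abs(x_high))
            if err <= tol:  # accept the step and record it
                t = t_target if hit else min(t + h, t_target)
                x = x_high
                t_all.append(t)
                x_all.append(x)
            # Grow or shrink h; exponent 1/2 because the error estimate is O(h^2).
            factor = 0.9 * math.sqrt(tol / err) if err > 0 else 2.0
            h *= min(2.0, max(0.2, factor))
        t_out.append(t)
        x_out.append(x)
    if return_all_eval:
        return t_all, x_all
    return t_out, x_out


# Usage, mirroring the torchdyn snippet: dx/dt = sin(t), x(0) = 1, 10 output times.
f = lambda t, x: math.sin(t)
t_span = [i / 9 for i in range(10)]
t_eval, sol = heun_euler_adaptive(f, 1.0, t_span)
t_all, sol_all = heun_euler_adaptive(f, 1.0, t_span, return_all_eval=True)
print(len(t_eval))  # 10, one per requested time
print("accepted steps:", len(t_all) - 1)
```

Each call counts only its own steps, so logging len(t_all) - 1 per forward pass gives a per-call step count rather than a running total.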

NFEs are another option; remember to reset vf.nfe = 0 after every call. The meaning is slightly different: NFEs count the number of times the vector field has been called within odeint, not the number of steps. The two are related, roughly NFE ≈ solver_order * n_steps, because some solvers call f multiple times for a single step.
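A pure-Python sketch of the count-and-reset pattern follows. CountingField is a hypothetical wrapper standing in for the vf.nfe attribute (it is not torchdyn's class), and the solver is a classic fixed-step RK4, which calls f exactly four times per step, so here NFE = 4 * n_steps holds exactly.

```python
class CountingField:
    """Hypothetical wrapper: calls f(t, x) and counts the evaluations in self.nfe."""

    def __init__(self, f):
        self.f = f
        self.nfe = 0

    def __call__(self, t, x):
        self.nfe += 1
        return self.f(t, x)


def rk4(f, x0, t0, t1, n_steps):
    """Classic fixed-step RK4 on [t0, t1]: four vector-field calls per step."""
    h = (t1 - t0) / n_steps
    t, x = t0, x0
    for _ in range(n_steps):
        k1 = f(t, x)
        k2 = f(t + h / 2, x + h / 2 * k1)
        k3 = f(t + h / 2, x + h / 2 * k2)
        k4 = f(t + h, x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + h
    return x


vf = CountingField(lambda t, x: -x)  # dx/dt = -x
rk4(vf, 1.0, 0.0, 1.0, n_steps=10)
print(vf.nfe)  # 40: 4 evaluations per step * 10 steps
rk4(vf, 1.0, 0.0, 1.0, n_steps=10)
print(vf.nfe)  # 80: without a reset the counter keeps accumulating
vf.nfe = 0     # reset before the next solve to get a per-call count
```

With an adaptive solver such as dopri5 the relation is only approximate, because rejected trial steps also call the vector field.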

shekshaa commented 2 years ago

Thanks for the clarification. I will take the first approach. By the way, I suggest adding an attribute like return_all_eval to the NeuralODE class (and perhaps other similar classes), so that forward can decide whether or not to return all evaluation points. This would be especially helpful for monitoring how the number of steps taken by adaptive solvers changes during training.