DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

Hypersolvers missing `args` kwarg #184

Closed by jcallaham 1 year ago

jcallaham commented 1 year ago

In the Quickstart Jupyter notebook, running this block from the hypersolvers section:

net = nn.Sequential(nn.Linear(3, 64), nn.Softplus(), nn.Linear(64, 64), nn.Softplus(), nn.Linear(64, 3))
hypersolver = HyperEuler(VanillaHyperNet(net))
t_eval, sol = odeint(sys, x0, t_span, solver=hypersolver) # note: this has to be trained!

Produces this error:

File torchdyn/numerics/odeint.py:428, in _fixed_odeint(f, x, t_span, solver, save_at, args)
    426 steps = 1
    427 while steps <= len(t_span) - 1:
--> 428     _, x, _ = solver.step(f, x, t, dt, k1=None, args=args)
    429     t = t + dt
    431     if torch.isclose(t, save_at).sum():

TypeError: HyperEuler.step() got an unexpected keyword argument 'args'

It looks like the step() signature of the "normal" ODE solvers was changed to accept args in cb83c8c, but the hypersolvers in hyper.py were not updated to match. I can put up a PR to fix this.
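To illustrate the mismatch, here is a minimal, torch-free sketch. The classes OldHyperEuler and PatchedHyperEuler, the function fixed_odeint, and the helpers vector_field and zero_hypernet are all hypothetical stand-ins that only mimic the call pattern in the traceback. They are not torchdyn's actual code, and the dt**2-scaled correction term is an assumed form of a HyperEuler-style update. The loop always passes args=..., so any step() that lacks that keyword fails at call time. The patched version accepts args and forwards it to the vector field.

```python
# Hypothetical sketch: mimics the call pattern from the traceback only.
# None of these names are torchdyn's real classes or functions.


class OldHyperEuler:
    """A step() that predates the `args` keyword (reproduces the bug)."""

    def __init__(self, hypernet):
        self.hypernet = hypernet

    def step(self, f, x, t, dt, k1=None):
        f0 = f(t, x) if k1 is None else k1
        # Euler step plus an assumed dt**2-scaled learned correction
        x_sol = x + dt * f0 + dt**2 * self.hypernet(t, x)
        return None, x_sol, None


class PatchedHyperEuler(OldHyperEuler):
    """Accepts `args` and forwards it to the vector field."""

    def step(self, f, x, t, dt, k1=None, args=None):
        f0 = f(t, x, args) if k1 is None else k1
        x_sol = x + dt * f0 + dt**2 * self.hypernet(t, x)
        return None, x_sol, None


def fixed_odeint(f, x, t_span, solver, args=None):
    """Fixed-step loop that, like the traceback's loop, passes args=... to step()."""
    t, dt = t_span[0], t_span[1] - t_span[0]
    traj = [x]
    for _ in range(len(t_span) - 1):
        _, x, _ = solver.step(f, x, t, dt, k1=None, args=args)
        t = t + dt
        traj.append(x)
    return traj


def vector_field(t, x, args=None):
    # dx/dt = -x; args is accepted but unused here
    return -x


def zero_hypernet(t, x):
    # An untrained (zero) correction reduces the update to plain Euler
    return 0.0


t_span = [0.0, 0.1, 0.2]

try:
    fixed_odeint(vector_field, 1.0, t_span, OldHyperEuler(zero_hypernet))
except TypeError as err:
    # The message reports an unexpected keyword argument 'args'
    print("old solver:", err)

traj = fixed_odeint(vector_field, 1.0, t_span, PatchedHyperEuler(zero_hypernet))
print("patched solver:", traj)
```

Adding args=None as a trailing keyword keeps existing positional calls working, which is why this shape of fix is backward compatible for callers that never pass args.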