DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

Training stalls #135

Open alokwarey opened 2 years ago

alokwarey commented 2 years ago

I am noticing that training just stalls at a certain epoch, with no errors or explanation. Screenshot below. It got stuck at this epoch and has been there for the last 2 hours, and I am not sure why.

[Screenshot: training progress bar stuck at a single epoch]

Zymrael commented 2 years ago

This is likely due to PyTorch Lightning; could you raise min_epochs to 100 or so?
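
For reference, a minimal sketch of where that change would go, assuming a standard PyTorch Lightning setup (the Trainer arguments below are illustrative, not taken from your script):

```python
import pytorch_lightning as pl

# min_epochs is a Trainer argument, so the change goes wherever the Trainer
# is constructed; max_epochs stays at whatever you were already using.
trainer = pl.Trainer(min_epochs=100, max_epochs=300)
# trainer.fit(learner)  # `learner` is your LightningModule
```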

alokwarey commented 2 years ago

Tried that. No change. Training stalled at epoch 30.

Zymrael commented 2 years ago

Could you share your model (and ideally the args passed to NeuralODE during init, i.e. which solver you are using), or is the setup the same as the tutorial? This could be caused by an unlucky combination of adaptive solver, learning rate, and integration time. One quick way to check whether the problem is torchdyn or PyTorch Lightning related would be to replace the solver with a fixed-step alternative (use "euler" or "rk4") and see if training stalls again.
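
A minimal sketch of that check, assuming a recent torchdyn where NeuralODE is imported from torchdyn.core and the solver is selected by name at init (the vector field and t_span below are placeholders, not your model):

```python
import torch
import torch.nn as nn
from torchdyn.core import NeuralODE

# placeholder vector field; substitute your own module
vector_field = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 2))

# model = NeuralODE(vector_field, solver='tsit5')  # adaptive solver that stalls
model = NeuralODE(vector_field, solver='rk4')      # fixed-step alternative

t_span = torch.linspace(0., 1., 20)
y0 = torch.randn(16, 2)
t_eval, trajectory = model(y0, t_span)  # then train exactly as before
```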

alokwarey commented 2 years ago

Using a fixed-step solver ("euler" or "rk4") works! I was using tsit5 with Adam(lr=0.01). I have a follow-up question: how can I solve an IVP or time-series problem with control variables at each time step?

Instead of training a neural differential equation of the form dy/dt = f(y(t)), I want to train one of the form dy/dt = f(y(t), x(t)). How can I feed in x(t) (the control variables) at each time step for each mini-batch? Is there an example of a time-series prediction problem with control variables?
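
One pattern that should work, shown here as a sketch rather than a torchdyn-provided API (it assumes your torchdyn version passes t to vector fields whose forward accepts it; the names ControlledVF and set_control are hypothetical): store the mini-batch's control signal on the vector-field module and look up x(t) inside forward, e.g. piecewise-constant on the control grid.

```python
import torch
import torch.nn as nn
from torchdyn.core import NeuralODE

class ControlledVF(nn.Module):
    """dy/dt = f(y, x(t)) with a piecewise-constant control signal."""
    def __init__(self, y_dim, x_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(y_dim + x_dim, hidden),
                                 nn.Tanh(), nn.Linear(hidden, y_dim))
        self.t_grid = None  # (T,) times at which controls are given
        self.x_ctrl = None  # (batch, T, x_dim) control values

    def set_control(self, t_grid, x_ctrl):
        # call once per mini-batch, before the ODE solve
        self.t_grid, self.x_ctrl = t_grid, x_ctrl

    def forward(self, t, y):
        idx = torch.argmin((self.t_grid - t).abs())  # nearest control point
        x_t = self.x_ctrl[:, idx, :]
        return self.net(torch.cat([y, x_t], dim=-1))

vf = ControlledVF(y_dim=2, x_dim=3)
model = NeuralODE(vf, solver='rk4')

t_span = torch.linspace(0., 1., 10)
y0, x_ctrl = torch.randn(16, 2), torch.randn(16, 10, 3)
vf.set_control(t_span, x_ctrl)   # refresh per mini-batch during training
t_eval, traj = model(y0, t_span)
```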

laserkelvin commented 2 years ago

I've also encountered a similar issue to @alokwarey's; changing to a fixed-step integrator seems to alleviate it, but as far as I'm aware I can't change the parameters associated with those solvers (e.g. min/max step size).

On a quasi-related note, is it possible to print/log diagnostics for the stiffness of the problem, perhaps between training steps? Since torchdyn is PyTorch Lightning-aware, one could expose such metrics at a higher level for users to log with PL's logging functionality. If this is desirable, I could take a crack at it.
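
To make the idea concrete, here is the kind of thing I have in mind, as a sketch rather than an existing torchdyn feature (CountingVF and the metric name "nfe" are made up, and it again assumes the vector field's forward can take (t, y)): count vector-field evaluations per training step as a cheap proxy for how hard the solve is, and log it through Lightning's self.log.

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchdyn.core import NeuralODE

class CountingVF(nn.Module):
    """Wraps a vector field and counts how often the solver evaluates it."""
    def __init__(self, vf):
        super().__init__()
        self.vf, self.nfe = vf, 0
    def forward(self, t, y):
        self.nfe += 1
        return self.vf(y)

class Learner(pl.LightningModule):
    def __init__(self, t_span):
        super().__init__()
        self.vf = CountingVF(nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 2)))
        self.model = NeuralODE(self.vf, solver='tsit5')
        self.t_span = t_span

    def training_step(self, batch, batch_idx):
        y0, target = batch
        self.vf.nfe = 0                        # reset the counter each step
        _, traj = self.model(y0, self.t_span)
        loss = nn.functional.mse_loss(traj[-1], target)
        self.log("nfe", float(self.vf.nfe), prog_bar=True)  # rough stiffness proxy
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```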