This is the official implementation of Heavy Ball Neural Ordinary Differential Equations (HBNODE). For any questions about the code, please contact hedixia@ucla.edu. The code is based on PyTorch and torchdiffeq, and the default numerical solver in all experiments is dopri5.
Download the walker2d data by running:
python data_download.py
Data format: tensors have shape [timestamps, batch, channels (derivatives), feature dimension].
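A minimal PyTorch sketch of this layout, using made-up sizes. The convention that channel 0 holds the state and channel 1 its first derivative (as a heavy-ball model would need) is an assumption for illustration, not something read from the data files.

```python
import torch

# Hypothetical sizes, for illustration only.
T, B, C, D = 50, 16, 2, 8  # timestamps, batch, channels (derivatives), feature dimension

# A random batch in the layout [timestamps, batch, channels, feature dimension].
x = torch.randn(T, B, C, D)

# Assumed convention: channel 0 is the state h, channel 1 is its derivative m.
h = x[:, :, 0, :]  # [T, B, D]
m = x[:, :, 1, :]  # [T, B, D]

# The initial condition handed to an ODE solver is the first time slice: [B, C, D].
x0 = x[0]

print(tuple(x.shape), tuple(h.shape), tuple(x0.shape))  # (50, 16, 2, 8) (50, 16, 8) (16, 2, 8)
```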
Usage:
First, create a NODE-type module with
cell = NODE(...)
or an HBNODE with
cell = HBNODE(...)
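To make the difference between the two cell types concrete, here is a hypothetical sketch of the vector fields they could compute, assuming the heavy-ball formulation h'' + γh' = f(t, h) rewritten as the first-order system dh/dt = m, dm/dt = -γm + f(t, h). The names `Field`, `SketchNODE` and `SketchHBNODE`, the sigmoid parametrization of γ, and the [batch, 2, d] channel layout are illustrative assumptions; they are not the repository's `NODE`/`HBNODE` constructors.

```python
import torch
import torch.nn as nn


class Field(nn.Module):
    """Hypothetical learnable right-hand side f(t, h): a linear map followed by tanh."""

    def __init__(self, d: int):
        super().__init__()
        self.lin = nn.Linear(d, d)

    def forward(self, t, h):
        return torch.tanh(self.lin(h))


class SketchNODE(nn.Module):
    """First-order dynamics dh/dt = f(t, h)."""

    def __init__(self, df: nn.Module):
        super().__init__()
        self.df = df

    def forward(self, t, x):
        return self.df(t, x)


class SketchHBNODE(nn.Module):
    """Heavy-ball dynamics as dh/dt = m, dm/dt = -gamma * m + f(t, h).

    The state x has shape [batch, 2, d]: h in channel 0, m in channel 1.
    """

    def __init__(self, df: nn.Module, gamma_init: float = -3.0):
        super().__init__()
        self.df = df
        # Unconstrained parameter; the sigmoid keeps the damping gamma in (0, 1).
        self.gamma = nn.Parameter(torch.tensor(gamma_init))

    def forward(self, t, x):
        h, m = torch.split(x, 1, dim=1)  # each [batch, 1, d]
        dh = m
        dm = self.df(t, h) - torch.sigmoid(self.gamma) * m
        return torch.cat((dh, dm), dim=1)


x = torch.randn(4, 2, 3)  # [batch, channels (h, m), d]
hb = SketchHBNODE(Field(3))
dx = hb(0.0, x)
print(dx.shape)  # torch.Size([4, 2, 3])
```

Doubling the channels lets any first-order solver such as dopri5 integrate the second-order heavy-ball dynamics unchanged.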
Then turn it into a time-series model with
model = NODEintegrate(cell)
It can also be used as the continuous analogue of a residual network layer with
model = NODElayer(cell)
For NODE-RNN hybrids, use
model = ODE_RNN(ode, cell, nhid, ic)
where nhid is the hidden shape (the same as the input and output shape of ode and cell), and ic is the initial condition.
As Jupyter Notebooks:
As Python files: