hedixia / HeavyBallNODE


Heavy Ball Neural Ordinary Differential Equations

This is the official implementation of Heavy Ball Neural Ordinary Differential Equations. For questions about the code, please contact hedixia@ucla.edu. The code is built on PyTorch and torchdiffeq, and the default numerical solver in all experiments is dopri5.

Usage

Download the Walker2d data with

python data_download.py

Data format (tensor shape): [timestamps, batch, channels (derivatives), feature dimension]
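The following is a minimal NumPy sketch of how an array in this layout can be sliced. The sizes are hypothetical placeholders. The sketch also assumes that channel 0 holds the state and channel 1 its first time derivative. That reading of "channels (derivatives)" is an assumption, not something the repository documents.

```python
import numpy as np

# Hypothetical sizes, for illustration only.
T, B, C, D = 100, 16, 2, 3  # timestamps, batch, channels, feature dimension

# Placeholder array in the documented layout:
# [timestamps, batch, channels (derivatives), feature dimension]
data = np.zeros((T, B, C, D), dtype=np.float32)

# Assumption: channel 0 is the state, channel 1 its first derivative.
states = data[:, :, 0, :]      # shape (T, B, D)
velocities = data[:, :, 1, :]  # shape (T, B, D)

# All timestamps and channels of the first sequence in the batch.
first_sequence = data[:, 0]    # shape (T, C, D)
print(states.shape, velocities.shape, first_sequence.shape)
```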

Model usage:

First, create a NODE-type module:

cell = NODE(...)

or an HBNODE:

cell = HBNODE(...)
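The two cell types differ in their dynamics. This minimal NumPy sketch assumes the second-order heavy-ball form h'' + gamma h' = f(t, h), rewritten as the first-order system h' = m, m' = -gamma m + f(t, h). The function names, the fixed damping value gamma, and the stacked [h, m] state layout are illustrative assumptions. They are not the repository's NODE/HBNODE API.

```python
import numpy as np

def node_rhs(f, t, h):
    """Plain NODE dynamics: dh/dt = f(t, h)."""
    return f(t, h)

def hbnode_rhs(f, t, x, gamma=0.5):
    """Heavy-ball dynamics on the stacked state x = [h, m], shape (..., 2, D).

    Implements h' = m and m' = -gamma * m + f(t, h), the first-order form
    of h'' + gamma * h' = f(t, h). gamma is fixed here as an assumption.
    """
    h, m = x[..., 0, :], x[..., 1, :]
    dh = m
    dm = -gamma * m + f(t, h)
    return np.stack([dh, dm], axis=-2)

# Toy vector field standing in for a learned network.
f = lambda t, h: -h
x0 = np.array([[1.0, 2.0],   # h
               [0.0, 0.0]])  # m
print(hbnode_rhs(f, 0.0, x0))
```

The momentum variable m is the only extra state. It lets the damping term -gamma m act on the velocity instead of on h directly.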

Then turn it into a time-series model with

model = NODEintegrate(cell)
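A time-series model of this kind returns the state at every requested time, stacked along a leading time axis, which matches the [timestamps, ...] layout above. The following is a NumPy sketch under stated assumptions. The helper integrate_trajectory is hypothetical, not the repository's NODEintegrate. It also uses fixed-step classical RK4 in place of the adaptive dopri5 solver the repository uses by default.

```python
import numpy as np

def rk4_step(rhs, t, x, dt):
    """One classical Runge-Kutta (RK4) step for dx/dt = rhs(t, x)."""
    k1 = rhs(t, x)
    k2 = rhs(t + dt / 2, x + dt / 2 * k1)
    k3 = rhs(t + dt / 2, x + dt / 2 * k2)
    k4 = rhs(t + dt, x + dt * k3)
    return x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def integrate_trajectory(rhs, x0, ts, substeps=10):
    """Hypothetical NODEintegrate-style model: states at every time in ts.

    Uses fixed-step RK4 (substeps per interval), not adaptive dopri5.
    """
    xs = [np.asarray(x0, dtype=float)]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        x, dt = xs[-1], (t1 - t0) / substeps
        for i in range(substeps):
            x = rk4_step(rhs, t0 + i * dt, x, dt)
        xs.append(x)
    return np.stack(xs)  # shape (len(ts), *x0.shape)

ts = np.linspace(0.0, 1.0, 5)
traj = integrate_trajectory(lambda t, x: -x, np.ones(3), ts)
print(traj.shape)  # (5, 3)
```

For dx/dt = -x the last row should be close to exp(-1), which makes a quick sanity check.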

It can also be used as a continuous analogue of a residual network layer with

model = NODElayer(cell)
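The residual-network analogy can be made concrete. If the layer integrates from t = 0 to t = 1 and keeps only the final state, then a single forward-Euler step of size 1 reduces it to the residual update x + f(x). The sketch below assumes that reading of a NODElayer-style map. The ode_layer helper and the tanh vector field are hypothetical and shown only for illustration.

```python
import numpy as np

def ode_layer(rhs, x, t1=1.0, steps=1):
    """Hypothetical NODElayer-style map: integrate dx/dt = rhs(t, x) from
    t = 0 to t1 with forward Euler and return only the final state."""
    x = np.asarray(x, dtype=float)
    dt = t1 / steps
    for i in range(steps):
        x = x + dt * rhs(i * dt, x)
    return x

f = lambda t, x: np.tanh(x)  # toy vector field standing in for a network
x = np.array([0.5, -1.0])

# One Euler step of size 1 is exactly a residual block: x + f(x).
residual = ode_layer(f, x, steps=1)
# More steps approach the continuous-depth (ODE) limit.
continuous = ode_layer(f, x, steps=100)
print(residual, continuous)
```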

For NODE-RNN hybrids, use

model = ODE_RNN(ode, cell, nhid, ic)

Here nhid is the hidden shape (the same as the input and output shape of ode and cell), and ic holds the initial conditions.
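In an ODE-RNN, the hidden state evolves continuously under the ODE between observations, and the RNN cell updates it at each observation time. The following is a minimal NumPy sketch of that loop. The names ode_rnn and euler_evolve, the Euler solver, and the toy ode and cell are hypothetical. The argument order does not reproduce the repository's ODE_RNN(ode, cell, nhid, ic) signature.

```python
import numpy as np

def euler_evolve(ode, h, t0, t1, steps=10):
    """Evolve hidden state h from t0 to t1 under dh/dt = ode(t, h)."""
    dt = (t1 - t0) / steps
    for i in range(steps):
        h = h + dt * ode(t0 + i * dt, h)
    return h

def ode_rnn(ode, cell, ic, xs, ts):
    """Hypothetical ODE-RNN loop, starting from the initial condition ic."""
    h, t_prev, outputs = np.asarray(ic, dtype=float), ts[0], []
    for x, t in zip(xs, ts):
        h = euler_evolve(ode, h, t_prev, t)  # continuous evolution
        h = cell(x, h)                       # discrete update at observation
        outputs.append(h)
        t_prev = t
    return np.stack(outputs)  # shape (len(ts), nhid)

nhid = 4
W = np.full((nhid, 2), 0.1)             # toy input weights
cell = lambda x, h: np.tanh(W @ x + h)  # toy RNN cell
ode = lambda t, h: -0.5 * h             # toy hidden dynamics
xs = np.ones((3, 2))                    # 3 observations, 2 features each
hs = ode_rnn(ode, cell, np.zeros(nhid), xs, ts=[0.0, 0.5, 1.0])
print(hs.shape)  # (3, 4)
```

Note that the hidden state keeps the same shape (nhid) throughout. That matches the requirement that ode and cell share input and output shapes.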

Experiments

As Jupyter notebooks:

As Python files: