tk-rusch / unicornn

Official code for UnICORNN (ICML 2021)

UnICORNN
(Undamped Independent Controlled Oscillatory RNN)
[ICML 2021]

This repository contains the implementation needed to reproduce the numerical experiments of the ICML 2021 paper "UnICORNN: A recurrent model for learning very long time dependencies".

Requirements

This code runs on GPUs only, as the recurrent part of UnICORNN is implemented directly in CUDA. The CUDA extension is compiled using pynvrtc. Make sure all of the packages below are installed.

python 3.7.4
cupy 7.6.0
pynvrtc 9.2
pytorch 1.5.1+cu101 
torchvision 0.6.1+cu101
torchtext 0.6.0
numpy 1.17.3
spacy 2.3.2
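
A quick way to confirm that an environment matches the versions pinned above is a small check script. The sketch below is not part of the repository. The names `PINNED` and `check_versions` are hypothetical, and it assumes that the PyPI distribution names are `torch` for pytorch and plain `cupy` for cupy (CUDA-specific wheels such as `cupy-cuda101` would need a different key). It uses `importlib.metadata`, which requires Python 3.8+; on the pinned Python 3.7.4 the `importlib_metadata` backport would be needed instead.

```python
from importlib import metadata

# Versions pinned in the requirements list above.
# Assumption: keys are PyPI distribution names ("torch" for pytorch).
PINNED = {
    "cupy": "7.6.0",
    "pynvrtc": "9.2",
    "torch": "1.5.1+cu101",
    "torchvision": "0.6.1+cu101",
    "torchtext": "0.6.0",
    "numpy": "1.17.3",
    "spacy": "2.3.2",
}


def check_versions(pinned, get_version=metadata.version):
    """Return {package: (expected, installed)} for every mismatch.

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in check_versions(PINNED).items():
        found = installed if installed is not None else "not installed"
        print(f"{name}: expected {expected}, found {found}")
```

An empty result means every pinned package is installed at exactly the listed version. The Python interpreter version itself is not checked here.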

Comment: If you are using cupy 9.0.0+, you can compile the CUDA extension (i.e. UnICORNN_CODE in network.py) directly with cupy.RawModule, which makes pynvrtc obsolete. To do so, write mod = cupy.RawModule(code=UnICORNN_CODE, options=('--std=c++11',), name_expressions=('unicornn_fwd', 'unicornn_bwd')) and delete the pynvrtc parts in the class UnICORNN_compile().

Speed

The recurrent part of UnICORNN is implemented directly in pure CUDA (as a PyTorch extension to the remaining standard PyTorch code), where each dimension of the underlying dynamical system is computed on an independent CUDA thread. This leads to a large speed-up over using PyTorch on GPUs directly (around 30-50 times faster, depending on the data set). The set-up of the benchmark that compares our UnICORNN implementation to the fastest available RNN implementations can be found in the main paper.
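
Why one thread per hidden dimension is possible can be illustrated in plain Python. The sketch below is not the repository's CUDA kernel. It assumes a single-layer update of the form z_n = z_{n-1} - dt * s(c) * (tanh(w * y_{n-1} + u_n) + alpha * y_{n-1}) and y_n = y_{n-1} + dt * s(c) * z_n, with s(c) = 0.5 + 0.5 * tanh(c / 2), which is a simplified reading of the paper. The function names and the precomputed input drive `u` (standing in for the input projection plus bias) are hypothetical. Because the recurrent weight `w` acts elementwise, no dimension ever reads another dimension's state.

```python
import math


def sigma_hat(c):
    # Squashes the time-scale parameter c into (0, 1).
    return 0.5 + 0.5 * math.tanh(c / 2.0)


def step(y, z, u, w, c, dt, alpha):
    """One update of a single hidden dimension (assumed form, see above).

    u is the precomputed input drive for this dimension at this time step.
    """
    z_new = z - dt * sigma_hat(c) * (math.tanh(w * y + u) + alpha * y)
    y_new = y + dt * sigma_hat(c) * z_new
    return y_new, z_new


def run_time_major(inputs, w, c, dt, alpha):
    """Advance all dimensions together, one time step at a time."""
    m = len(w)
    y, z = [0.0] * m, [0.0] * m
    for u_n in inputs:
        for i in range(m):
            y[i], z[i] = step(y[i], z[i], u_n[i], w[i], c[i], dt, alpha)
    return y


def run_dimension_major(inputs, w, c, dt, alpha):
    """Advance each dimension through the whole sequence on its own,
    mimicking the work of one independent thread per dimension."""
    final_y = []
    for i in range(len(w)):
        y, z = 0.0, 0.0
        for u_n in inputs:
            y, z = step(y, z, u_n[i], w[i], c[i], dt, alpha)
        final_y.append(y)
    return final_y


# Toy data: 5 time steps, 3 hidden dimensions.
inputs = [[0.1 * (n + 1) * (i + 1) for i in range(3)] for n in range(5)]
w, c = [0.5, -0.3, 1.2], [0.0, 1.0, -1.0]
same = run_time_major(inputs, w, c, 0.1, 0.9) == run_dimension_major(inputs, w, c, 0.1, 0.9)
```

Both loop orders perform the same arithmetic per dimension, so `same` is `True`. Only the input projection couples dimensions, and it does not depend on the hidden state, so in this sketch it can be computed up front for all time steps.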

Datasets

This repository contains the code to reproduce the results of the proposed UnICORNN on the following experiments: psMNIST, noise-padded CIFAR10, Eigenworms, Healthcare AI (RR and HR), and IMDB.

Results

The results of the UnICORNN for each of the experiments are:

| Experiment | Result |
| --- | --- |
| psMNIST | 98.4% test accuracy |
| Noise-padded CIFAR10 | 62.4% test accuracy |
| Eigenworms | 94.9% test accuracy |
| Healthcare AI: RR | 1.00 L2 loss |
| Healthcare AI: HR | 1.31 L2 loss |
| IMDB | 88.4% test accuracy |

Citation

If you found this work useful, please consider citing:

@inproceedings{pmlr-v139-rusch21a,
  title =    {UnICORNN: A recurrent model for learning very long time dependencies},
  author =       {Rusch, T. Konstantin and Mishra, Siddhartha},
  booktitle =    {Proceedings of the 38th International Conference on Machine Learning},
  pages =    {9168--9178},
  year =     {2021},
  volume =   {139},
  series =   {Proceedings of Machine Learning Research},
  publisher =    {PMLR},
}