UnICORNN (Undamped Independent Controlled Oscillatory RNN) [ICML 2021]
This repository contains the implementation to reproduce the numerical experiments of the ICML 2021 paper UnICORNN: A recurrent model for learning very long time dependencies.
Requirements
This code runs on GPUs only, as the recurrent part of UnICORNN is implemented directly in CUDA. The CUDA extension is compiled using pynvrtc. Make sure all of the packages below are installed.
- python 3.7.4
- cupy 7.6.0
- pynvrtc 9.2
- pytorch 1.5.1+cu101
- torchvision 0.6.1+cu101
- torchtext 0.6.0
- numpy 1.17.3
- spacy 2.3.2
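
Before running any experiment, a quick sanity check along the following lines can confirm that a GPU is visible and that the packages above import correctly. This snippet is not part of the repository; it is only a suggested convenience.

```python
# Sanity check for the UnICORNN environment (illustrative, not part of the repo).
import numpy
import cupy
import torch
import torchvision
import torchtext
import spacy
import pynvrtc  # used to compile the CUDA extension; imported here only to verify it is installed

# The recurrent kernel runs on GPUs only, so a CUDA device is required.
assert torch.cuda.is_available(), "UnICORNN requires a CUDA-capable GPU"

print("torch", torch.__version__,
      "| torchvision", torchvision.__version__,
      "| torchtext", torchtext.__version__,
      "| cupy", cupy.__version__,
      "| numpy", numpy.__version__,
      "| spacy", spacy.__version__)
```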
Speed
The recurrent part of UnICORNN is implemented directly in pure CUDA (as a PyTorch extension to the remaining standard PyTorch code), where each dimension of the underlying dynamical system is computed on an independent CUDA thread. This yields a substantial speed-up over implementing the recurrence in plain PyTorch on GPUs (around 30-50 times faster, depending on the data set). Below is a speed comparison of our UnICORNN implementation with the fastest available RNN implementations (the set-up of this benchmark can be found in the main paper):
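
To make the parallelization strategy concrete, here is a minimal sketch of the thread-per-dimension idea: each CUDA thread owns one (batch sample, hidden dimension) pair and integrates a simplified oscillatory recurrence over all time steps, while the input drive is precomputed in batched PyTorch/CuPy code outside the kernel. This is not the repository's actual kernel: the update rule is simplified, the backward pass is omitted, and the repository compiles its kernel via pynvrtc rather than `cupy.RawKernel`; all names here are illustrative assumptions.

```python
# Illustrative sketch only: one CUDA thread per (batch, hidden-dimension) pair
# integrates a simplified oscillatory recurrence over all time steps.
import cupy as cp
import numpy as np

_kernel_src = r'''
extern "C" __global__
void oscillatory_scan(const float* drive,   // (T, B, H) precomputed input drive for each step
                      const float* w,       // (H,) diagonal recurrent weight
                      float* y,             // (B, H) "position" state, updated in place
                      float* z,             // (B, H) "velocity" state, updated in place
                      float* out,           // (T, B, H) hidden states for all steps
                      const float dt, const float alpha,
                      const int T, const int B, const int H)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * H) return;          // one thread per (batch, dimension) pair
    int h = idx % H;
    float yi = y[idx], zi = z[idx];
    for (int t = 0; t < T; ++t) {      // sequential scan over time, parallel over dimensions
        float a = tanhf(w[h] * yi + drive[t * B * H + idx]);
        zi = zi - dt * (a + alpha * yi);   // simplified velocity update
        yi = yi + dt * zi;                 // simplified position update
        out[t * B * H + idx] = yi;
    }
    y[idx] = yi;
    z[idx] = zi;
}
'''

_scan = cp.RawKernel(_kernel_src, 'oscillatory_scan')

def run_scan(drive, w, y, z, dt=0.03, alpha=0.9):
    """Run the sequential scan on the GPU; drive has shape (T, B, H)."""
    T, B, H = drive.shape
    out = cp.empty_like(drive)
    threads = 128
    blocks = (B * H + threads - 1) // threads
    _scan((blocks,), (threads,),
          (drive, w, y, z, out,
           np.float32(dt), np.float32(alpha),
           np.int32(T), np.int32(B), np.int32(H)))
    return out

if __name__ == '__main__':
    T, B, H = 100, 32, 128
    drive = cp.random.randn(T, B, H).astype(cp.float32)
    w = cp.random.randn(H).astype(cp.float32)
    y = cp.zeros((B, H), dtype=cp.float32)
    z = cp.zeros((B, H), dtype=cp.float32)
    print(run_scan(drive, w, y, z).shape)  # (100, 32, 128)
```

Because the loop over time stays inside the kernel while all hidden dimensions run on independent threads, the per-step kernel-launch overhead of a pure PyTorch loop is avoided, which is where the speed-up reported above comes from.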
Datasets
This repository contains the codes to reproduce the results of the following experiments for the proposed UnICORNN:
- Permuted Sequential MNIST
- Noise-padded CIFAR10
- EigenWorms
- Healthcare AI: Respiratory rate (RR)
- Healthcare AI: Heart rate (HR)
- IMDB
Results
The results of the UnICORNN for each of the experiments are:
Experiment | Result
--- | ---
psMNIST | 98.4% test accuracy
Noise-padded CIFAR10 | 62.4% test accuracy
EigenWorms | 94.9% test accuracy
Healthcare AI: RR | 1.00 L2 loss
Healthcare AI: HR | 1.31 L2 loss
IMDB | 88.4% test accuracy
Citation
@inproceedings{pmlr-v139-rusch21a,
title = {UnICORNN: A recurrent model for learning very long time dependencies},
author = {Rusch, T. Konstantin and Mishra, Siddhartha},
booktitle = {Proceedings of the 38th International Conference on Machine Learning},
pages = {9168--9178},
year = {2021},
volume = {139},
series = {Proceedings of Machine Learning Research},
publisher = {PMLR},
}