# MetaSolver for Neural Ordinary Differential Equations
Towards robust neural ODEs using parametrized solvers.
## Main idea
Each Runge-Kutta (RK) solver with `s` stages and of order `p` is defined by a table of coefficients (Butcher tableau). For `s = p = 2`, `s = p = 3`, and `s = p = 4`, all coefficients in the tableau can be parametrized with no more than two variables [1].
Usually, during neural ODE training an RK solver with a fixed Butcher tableau is used, and only the right-hand side (RHS) function is trained. We propose to use the whole parametric family of RK solvers to improve the robustness of neural ODEs.
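For illustration, one step of this parametric family for `s = p = 2` can be written down directly. The sketch below implements the standard one-parameter RK2 family (it is not the repository's solver code); `u` plays the same role as the parameter `u` used in the experiments below.

```python
def rk2_step(f, t, y, h, u=0.5):
    # One step of the one-parameter family of 2-stage, 2nd-order RK solvers.
    # Butcher tableau: c2 = u, a21 = u, b1 = 1 - 1/(2u), b2 = 1/(2u), u != 0.
    # u = 0.5 recovers the midpoint method, u = 1.0 recovers Heun's method.
    k1 = f(t, y)
    k2 = f(t + u * h, y + u * h * k1)
    return y + h * ((1.0 - 1.0 / (2.0 * u)) * k1 + (1.0 / (2.0 * u)) * k2)
```

For example, `rk2_step(lambda t, y: -y, 0.0, 1.0, 0.1)` performs one midpoint step of `y' = -y`.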
## Requirements
- pytorch==1.7
- apex==0.1 (for training)
## Examples
For the CIFAR10 and MNIST demos, please check the `examples` folder.
## Meta Solver Regimes
In the notebook `examples/cifar10/Evaluate model.ipynb` we show how to perform the forward pass through the Neural ODE using different types of Meta Solver regimes, namely:
- Standalone
- Solver switching/smoothing
- Solver ensembling
- Model ensembling
In more detail, the regimes work as follows (a code sketch of these regimes is given after the list).

### Standalone
- Use one solver during inference.
- This regime is applied in the training and testing stages.

### Solver switching / smoothing
- For each batch, one solver is chosen from a group of solvers with a finite (in the switching regime) or infinite (in the smoothing regime) number of candidates.
- This regime is applied in the training stage.

### Solver ensembling
- Use several solvers during inference.
- Outputs of the ODE Block (obtained with different solvers) are averaged before propagating through the next layer.
- This regime is applied in the training and testing stages.

### Model ensembling
- Use several solvers during inference.
- Model probabilities obtained via propagation with different solvers are averaged to get the final result.
- This regime is applied in the training and testing stages.
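The sketch below illustrates these regimes for the parametrized RK2 family. Note that `ode_block(x, u=...)` and `model(x, u=...)` are hypothetical interfaces, and the candidate `u` values are illustrative; this is not the repository's actual API.

```python
import random
import torch

def switched_u(candidates=(0.5, 1.0)):
    # Solver switching: choose one solver from a finite candidate set
    # for each batch.
    return random.choice(candidates)

def smoothed_u(u0=0.5, sigma=0.0125):
    # Solver smoothing: perturb a base parameter with Gaussian noise,
    # i.e. draw from an infinite family of candidate solvers.
    return u0 + sigma * torch.randn(()).item()

def solver_ensemble(ode_block, x, u_values=(0.5, 1.0)):
    # Solver ensembling: average ODE block outputs (one per solver)
    # before propagating through the next layer.
    return torch.stack([ode_block(x, u=u) for u in u_values]).mean(dim=0)

def model_ensemble(model, x, u_values=(0.5, 1.0)):
    # Model ensembling: average class probabilities obtained by running
    # the whole model with different solvers.
    probs = [torch.softmax(model(x, u=u), dim=-1) for u in u_values]
    return torch.stack(probs).mean(dim=0)
```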
## Selected results
### Different solver parameterizations yield different robustness
We trained a neural ODE model several times, using different values of `u` in the parametrization of the 2nd-order Runge-Kutta solver. The image below depicts robust accuracies for the MNIST classification task. We use a PGD attack (eps=0.3, lr=2/255, iters=7). The mean values of robust accuracy (bold lines) ± the standard error of the mean (shaded region), computed across 9 random seeds, are shown in the image.
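For reference, the attack used here is standard L∞ PGD; a minimal sketch with the stated parameters is shown below (a generic implementation assuming inputs in `[0, 1]`, not the repository's evaluation code):

```python
import torch
import torch.nn.functional as F

def pgd_attack(model, x, y, eps=0.3, lr=2 / 255, iters=7):
    # Projected gradient descent in the L-infinity ball of radius eps.
    x_adv = x.clone().detach()
    for _ in range(iters):
        x_adv.requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), y)
        grad = torch.autograd.grad(loss, x_adv)[0]
        x_adv = x_adv.detach() + lr * grad.sign()        # ascent step
        x_adv = x + torch.clamp(x_adv - x, -eps, eps)    # project into eps-ball
        x_adv = torch.clamp(x_adv, 0.0, 1.0)             # keep valid pixel range
    return x_adv.detach()
```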
### Solver smoothing improves robustness
We compare the results of neural ODE adversarial training on the CIFAR10 dataset with and without solver smoothing (using a normal distribution with mean=0 and sigma=0.0125). We choose the 8-step RK2 solver with `u=0.5` for this experiment.
- We perform training using the FGSM random technique described in https://arxiv.org/abs/2001.03994 (with eps=8/255, alpha=10/255); a sketch of this step is given after the list.
- We use a cyclic learning rate schedule with one cycle (36 epochs, max_lr=0.1, base_lr=1e-7).
- We measure the robust accuracy of the resulting models under FGSM (eps=8/255) and PGD (eps=8/255, lr=2/255, iters=7) attacks.
- We use the `premetanode10` architecture from `sopa/src/models/odenet_cifar10/layers.py`, which has the following form: Conv -> PreResNet block -> ODE block -> PreResNet block -> ODE block -> GeLU -> Average Pooling -> Fully Connected.
- We compute the mean and standard error across 3 random seeds.
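For context, one training step of the FGSM random technique from the cited paper can be sketched as follows (a generic implementation with the stated eps and alpha, assuming inputs in `[0, 1]`; not the repository's training code):

```python
import torch
import torch.nn.functional as F

def fgsm_random_batch(model, x, y, eps=8 / 255, alpha=10 / 255):
    # FGSM with random initialization (Wong et al., 2020): start from a
    # random point in the eps-ball, take one signed-gradient step of size
    # alpha, and project back into the eps-ball.
    delta = torch.empty_like(x).uniform_(-eps, eps)
    delta.requires_grad_(True)
    loss = F.cross_entropy(model(x + delta), y)
    grad = torch.autograd.grad(loss, delta)[0]
    delta = torch.clamp(delta + alpha * grad.sign(), -eps, eps).detach()
    return torch.clamp(x + delta, 0.0, 1.0)
```

The returned adversarial batch is then used for an ordinary optimization step; the one-cycle schedule above could be reproduced with, e.g., `torch.optim.lr_scheduler.OneCycleLR` (an assumption, not necessarily how the repository implements it).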