What can linearized neural networks actually say about generalization?

This is the source code to reproduce the experiments of the NeurIPS 2021 paper "What can linearized neural networks actually say about generalization?" by Guillermo Ortiz-Jimenez, Seyed-Mohsen Moosavi-Dezfooli and Pascal Frossard.

Dependencies

To run the code, please install all its dependencies by running:

$ pip install -r requirements.txt

This assumes that you have access to a Linux machine with an NVIDIA GPU and CUDA>=11.1. Otherwise, please check the JAX repository for instructions on installing JAX for your setup.
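
As a quick sanity check that JAX can see your GPU, you can run, for example:

$ python -c "import jax; print(jax.devices())"

If only CPU devices are listed, revisit your JAX installation.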

In general, all scripts are parameterized using hydra and their configuration files can be found in the config/ folder.
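
Any option in these configuration files can be overridden directly from the command line using hydra's key=value syntax, with dots for nested options (as in the examples below), e.g.:

$ python compute_ntk.py model=mlp data.dataset=mnist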

Experiments

The repository contains code to reproduce the following experiments:

Spectral decomposition of NTK

To generate our new benchmark, consisting of the eigenfunctions of the NTK at initialization, please run the Python script compute_ntk.py, selecting the desired model (e.g., mlp, lenet or resnet18) and supporting dataset (e.g., cifar10 or mnist). This can be done by running

$ python compute_ntk.py model=lenet data.dataset=cifar10

This script will save the eigenvalues, eigenfunctions and weights of the model under artifacts/eigenfunctions/{data.dataset}/{model}/.
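
For intuition, the object being diagonalized is the empirical NTK Gram matrix of the network at initialization. The following is a minimal, illustrative sketch in JAX (not the repository's implementation; apply_fn, params and x_train are placeholders for a scalar-output model and a training batch):

import jax
import jax.numpy as jnp

def ntk_gram(apply_fn, params, x):
    # Empirical NTK on a batch: K[i, j] = <grad_params f(x_i), grad_params f(x_j)>.
    grad_f = jax.grad(lambda p, xi: apply_fn(p, xi[None]).squeeze())
    per_example_grads = jax.vmap(lambda xi: grad_f(params, xi))(x)
    flat = jnp.concatenate([g.reshape(x.shape[0], -1)
                            for g in jax.tree_util.tree_leaves(per_example_grads)], axis=1)
    return flat @ flat.T  # (n, n) kernel matrix

# The (discretized) eigenfunctions and eigenvalues then follow from:
# eigenvalues, eigenfunctions = jnp.linalg.eigh(ntk_gram(apply_fn, params, x_train))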

For other configuration options, please consult the configuration file config/compute-ntk/config.yaml.

Warning

Take into account that, for large models, this computation can take a very long time. For example, it took us two days to compute the full eigenvalue decomposition of the NTK of one randomly initialized ResNet18 using 4 NVIDIA V100 GPUs. The estimation of eigenvectors for the MLP or the LeNet, on the other hand, can be done in a matter of minutes, depending on the number of GPUs available and the selected batch_size.

Training on binary eigenfunctions

Once you have estimated the eigenfunctions of the NTK, you should be able to train on any of them. To that end, select the desired label_idx (i.e. eigenfunction index), model and dataset, and run

$ python train_ntk.py label_idx=100 model=lenet data.dataset=cifar10 linearize=False

You can choose to train either the original non-linear network or its linearized approximation by setting the linearize flag. For the non-linear models, this script also computes the alignment of the final NTK (i.e., after training) with the target function, which it stores under artifacts/eigenfunctions/{data.dataset}/{model}/alignment_plots/.
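
For reference, the linearized model that linearize=True refers to is the first-order Taylor expansion of the network around its initial parameters. A minimal sketch of this construction in JAX (an illustration, not the repository's code; apply_fn(params, x) stands for the model's forward pass):

import jax

def linearize(apply_fn, params0):
    # f_lin(params, x) = f(params0, x) + J_f(params0, x) @ (params - params0)
    def f_lin(params, x):
        dparams = jax.tree_util.tree_map(lambda p, p0: p - p0, params, params0)
        out0, delta = jax.jvp(lambda p: apply_fn(p, x), (params0,), (dparams,))
        return out0 + delta
    return f_lin

The neural-tangents library provides an equivalent utility (nt.linearize) that can be used for the same purpose.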

To see the different supported training options, please consult the configuration file config/train-ntk/config.yaml.

Estimation of NADs

We also provide code to compute the NADs (neural anisotropy directions) of a CNN architecture (e.g., lenet or resnet18) using the alignment with the NTK at initialization. To do so, please run

$ python compute_nads.py model=lenet

This script will save the eigenvalues, NADs and weights of the model under artifacts/nads/{model}/.

For other configuration options, please consult the configuration file config/compute-nads/config.yaml.

Training on linearly separable datasets

Once you have estimated the NADs of a network, you should be able to train on linearly separable datasets with a single NAD as the discriminative feature. To that end, select the desired label_idx (i.e. NAD index) and model, and run

$ python train_nads.py label_idx=100 model=lenet linearize=False

You can choose to train either the original non-linear network or its linearized approximation by setting the linearize flag.
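
As a rough illustration of the dataset construction (an assumption, not necessarily the repository's exact pipeline), a linearly separable dataset whose only discriminative feature is a single direction nad can be generated as follows:

import jax
import jax.numpy as jnp

def linear_dataset(key, nad, n):
    # Inputs are isotropic noise; the label only depends on the projection onto nad.
    x = jax.random.normal(key, (n, nad.size))
    y = (x @ nad.ravel() > 0).astype(jnp.int32)  # side of the hyperplane <nad, x> = 0
    return x.reshape((n,) + nad.shape), y

# e.g., x, y = linear_dataset(jax.random.PRNGKey(0), nads[100], 10_000), where nads
# is assumed to hold the directions computed by compute_nads.py.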

To see the different supported training options, please consult the configuration file config/train-nads/config.yaml.

Comparison of training dynamics with pretrained NTK

We also provide code to compare the training dynamics of the linearized network at initialization and after non-linear pretraining, when trained to estimate a particular eigenfunction of the NTK at initialization. To do this, please run

$ python pretrained_ntk_comparison.py label_idx=100 model=lenet data.dataset=cifar10
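
Conceptually, the comparison is between the tangent (linearized) models taken at two different points in parameter space. In terms of the illustrative linearize helper sketched above (params_init and params_pretrained are placeholders):

f_lin_init = linearize(apply_fn, params_init)              # tangent model at initialization
f_lin_pretrained = linearize(apply_fn, params_pretrained)  # tangent model after non-linear pretraining
# Both are then trained on the same NTK eigenfunction and their training dynamics compared.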

To see the different supported training options, please consult the configuration file config/pretrained_ntk_comparison/config.yaml.

Training on CIFAR2

Finally, you can train a neural network and its linearized approximation on the binary version of CIFAR10, i.e., CIFAR2. To do this, please run

$ python train_cifar.py model=lenet linearize=False
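
As an illustration of the preprocessing involved (the actual class grouping is defined by the script and its configuration, so the split below is only an assumed example), the ten CIFAR10 labels can be binarized as:

import numpy as np

def to_cifar2(labels, positive_classes=(0, 1, 2, 3, 4)):
    # Map the ten CIFAR10 classes onto two groups; positive_classes is a placeholder split.
    return np.isin(np.asarray(labels), positive_classes).astype(np.int64)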

To see the different supported training options, please consult the configuration file config/binary-cifar/config.yaml.

Reference

If you use this code, please cite the following paper:

@InCollection{Ortiz-JimenezNeurIPS2021,
  title = {What can linearized neural networks actually say about generalization?},
  author = {{Ortiz-Jimenez}, Guillermo and {Moosavi-Dezfooli}, Seyed-Mohsen and Frossard, Pascal},
  booktitle = {Advances in Neural Information Processing Systems 35},
  month = Dec,
  year = {2021}
}