Learning Neural Network Subspaces

Welcome to the codebase for Learning Neural Network Subspaces by Mitchell Wortsman, Maxwell Horton, Carlos Guestrin, Ali Farhadi, Mohammad Rastegari.

Figure 1 (see the paper).

Abstract

Recent observations have advanced our understanding of the neural network optimization landscape, revealing the existence of (1) paths of high accuracy containing diverse solutions and (2) wider minima offering improved performance. Previous methods observing diverse paths require multiple training runs. In contrast, we aim to leverage both properties (1) and (2) with a single method and in a single training run. With a computational cost similar to that of training one model, we learn lines, curves, and simplexes of high-accuracy neural networks. These neural network subspaces contain diverse solutions that can be ensembled, approaching the ensemble performance of independently trained networks without the training cost. Moreover, using the subspace midpoint boosts accuracy, calibration, and robustness to label noise, outperforming Stochastic Weight Averaging.

Code Overview

In this repository we walk through learning neural network subspaces with PyTorch, grounding the discussion in the example of learning a line of neural networks. In our code, a line is defined by endpoints weight and weight1, and a point on the line is given by w = (1 - alpha) * weight + alpha * weight1 for some alpha in [0, 1].

Algorithm 1 (see paper) works as follows:

  1. weight and weight1 are initialized independently.
  2. For each batch (data, targets), alpha is chosen uniformly from [0, 1] and the weights w = (1 - alpha) * weight + alpha * weight1 are used in the forward pass.
  3. The regularization term is computed (see Eq. 3 in the paper and the formula written out after this list).
  4. With loss.backward() and optimizer.step() the endpoints weight and weight1 are updated.
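
For reference, the Eq. 3 regularizer for a line penalizes the squared cosine similarity between the two endpoints so that they do not collapse onto one another. The formula below is reconstructed from the code later in this README rather than quoted verbatim from the paper, with beta denoting the regularization coefficient (args.beta):

$$
\mathrm{reg}(w, w_1) \;=\; \beta \, \cos^2\!\left(w, w_1\right) \;=\; \beta \, \frac{\langle w, w_1 \rangle^2}{\lVert w \rVert^2 \, \lVert w_1 \rVert^2}
$$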

Instead of a regular nn.Conv2d we use a SubspaceConv (found in modes/modules.py).

class SubspaceConv(nn.Conv2d):
    def forward(self, x):
        w = self.get_weight()
        x = F.conv2d(
            x,
            w,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return x

For each subspace type (lines, curves, and simplexes) the function get_weight must be implemented. For lines we use:

class TwoParamConv(SubspaceConv):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight1 = nn.Parameter(torch.zeros_like(self.weight))

    def initialize(self, initialize_fn):
        initialize_fn(self.weight1)

class LinesConv(TwoParamConv):
    def get_weight(self):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        return w
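
The repository also implements get_weight for curves and simplexes in modes/modules.py, and those implementations are the reference. As a hedged sketch of what a curve variant could look like, the snippet below uses a quadratic Bezier parameterization with a third, learned control point; the class names here (ThreeParamConv, BezierCurvesConv) are illustrative and may differ from the actual modules:

class ThreeParamConv(SubspaceConv):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # two extra parameter tensors in addition to self.weight from nn.Conv2d
        self.weight1 = nn.Parameter(torch.zeros_like(self.weight))
        self.weight2 = nn.Parameter(torch.zeros_like(self.weight))

    def initialize(self, initialize_fn):
        initialize_fn(self.weight1)
        initialize_fn(self.weight2)

class BezierCurvesConv(ThreeParamConv):
    def get_weight(self):
        # quadratic Bezier curve with endpoints weight, weight1 and control point weight2
        a = self.alpha
        w = (
            ((1 - a) ** 2) * self.weight
            + 2 * a * (1 - a) * self.weight2
            + (a ** 2) * self.weight1
        )
        return w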

In all of these variants, note that the other endpoint, weight, is instantiated and initialized by nn.Conv2d itself. An equivalent implementation for batch norm layers can also be found in modes/modules.py.
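
A minimal sketch of what the batch norm counterpart might look like is below; the class names (SubspaceBN, TwoParamBN, LinesBN) are illustrative, and the actual implementation in modes/modules.py is the reference:

class SubspaceBN(nn.BatchNorm2d):
    def forward(self, x):
        # sample the affine parameters from the subspace, then apply standard batch norm
        w, b = self.get_weight()
        return F.batch_norm(
            x,
            self.running_mean,
            self.running_var,
            w,
            b,
            self.training,
            self.momentum,
            self.eps,
        )

class TwoParamBN(SubspaceBN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight1 = nn.Parameter(torch.zeros_like(self.weight))
        self.bias1 = nn.Parameter(torch.zeros_like(self.bias))

class LinesBN(TwoParamBN):
    def get_weight(self):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        b = (1 - self.alpha) * self.bias + self.alpha * self.bias1
        return w, b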

Now we turn to the training logic, which appears in trainers/train_one_dim_subspaces.py. In the snippet below we assume we are not training with the layerwise variant (args.layerwise = False) and that we are drawing only one sample from the subspace (args.num_samples = 1).

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(args.device), target.to(args.device)

    alpha = np.random.uniform(0, 1)
    for m in model.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d):
            setattr(m, "alpha", alpha)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)

All that's left is to compute the regularization term and call backward. For lines, this is given by the snippet below.

    num = 0.0
    norm = 0.0
    norm1 = 0.0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            num += (m.weight * m.weight1).sum()
            norm += m.weight.pow(2).sum()
            norm1 += m.weight1.pow(2).sum()
    loss += args.beta * (num.pow(2) / (norm * norm1))

    loss.backward()

    optimizer.step()
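
For the layerwise variant mentioned above (args.layerwise = True), the natural change is that each layer mixes its endpoints with its own independently sampled alpha instead of a single shared one. A hedged sketch of how the alpha assignment inside the batch loop would change is below; the authoritative logic lives in trainers/train_one_dim_subspaces.py:

    # layerwise variant (sketch): draw a separate alpha per Conv2d / BatchNorm2d module
    for m in model.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d):
            m.alpha = np.random.uniform(0, 1)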

Training Lines, Curves, and Simplexes

We now walk through generating the plots in Figures 4 and 5 of the paper. Before running the code, please install PyTorch and TensorBoard (for making plots you will also need TeX on your computer). Note that this repository differs from the one used to generate the figures in the paper, as the latter leveraged Apple's internal tools. Accordingly there may be some bugs, and we encourage you to submit an issue or send an email if you run into any problems.

In this example walkthrough we consider TinyImageNet, which we download to ~/data using a script such as this. To run standard training and ensemble the trained models, use the following commands:

python experiment_configs/tinyimagenet/ensembles/train_ensemble_members.py
python experiment_configs/tinyimagenet/ensembles/eval_ensembles.py

Note that if your data is not in ~/data please change the paths in these experiment configs. Logs and checkpoints will be saved in learning-subspaces-results, although this path can also be changed.

For one-dimensional subspaces, use the following commands to train:

python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_lines.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_lines_layerwise.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_curves.py

To evaluate (i.e. generate the data for Figure 4) use:

python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_lines.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_lines_layerwise.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_curves.py

We recommend looking at the experiment config files before running; they can be modified to change the model type, the number of random seeds, and more. The default in these configs is 2 random seeds.
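
Conceptually, generating the Figure 4 data amounts to sweeping alpha along the learned line and measuring accuracy at each point. The sketch below illustrates that loop; evaluate is a hypothetical helper standing in for the evaluation code driven by the configs above, and batch norm running statistics may need to be recomputed for each alpha before measuring accuracy:

for alpha in np.linspace(0, 1, 11):
    # select a point on the line by setting alpha on every subspace module
    for m in model.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d):
            m.alpha = float(alpha)
    acc = evaluate(model, val_loader)  # hypothetical evaluation helper
    print(f"alpha={alpha:.2f}: accuracy={acc:.3f}")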

Analogously, to train simplexes use:

python experiment_configs/tinyimagenet/simplexes/train_simplexes.py
python experiment_configs/tinyimagenet/simplexes/train_simplexes_layerwise.py

To generate plots like those in Figures 4 and 5, use:

python analyze_results/tinyimagenet/one_dimensional_subspaces.py
python analyze_results/tinyimagenet/simplexes.py

Equivalent configs exist for other datasets, and the configs can be modified to add label noise, experiment with other models, and more. If there is any functionality missing from this repository that you would like, please submit an issue.

Bibtex

@article{wortsman2021learning,
  title={Learning Neural Network Subspaces},
  author={Wortsman, Mitchell and Horton, Maxwell and Guestrin, Carlos and Farhadi, Ali and Rastegari, Mohammad},
  journal={arXiv preprint arXiv:2102.10472},
  year={2021}
}
Comments
  • Are trained models available?

    Hi,

    great job on the paper and code :) I was hoping you could share your trained models (esp. on the Imagenet dataset). It would be super appreciated.

    Thanks! Lucas

    opened by pclucas14
  • Implementation of other neuron types

    I read the paper and found this approach very interesting. However, I do not see any implementation of other neuron types (e.g., FullyConnected or LSTM).

    So, I'm a little curious. Have you guys done any experiment regarding those layers? Would those kind of layers work as well?

    Edit1. As I went through the code, I see that only the weights of the conv are being sampled but not the biases. Why is that the case?

    opened by 51616
  • Simple variants for `nn.Embedding` and `nn.LSTM`

    Hi, I really enjoyed your work and I am the first one who cited your work! (https://arxiv.org/abs/2109.07628)

    As I made some simple variants of your method for language models, could you please check this out? I think it is not much to be a PR, so I made an issue instead for the request of simple checks from original authors.

    Could you please check? Thank you in advance.

    Best, Adam


    # LSTM layer
    class SubspaceLSTM(nn.LSTM):
        def forward(self, x):
            # call get_weight, which samples from the subspace, then use the corresponding weight.
            weight_dict = self.get_weight()
            mixed_lstm = nn.LSTM(
                input_size=self.input_size, 
                hidden_size=self.hidden_size, 
                num_layers=self.num_layers, 
                batch_first=self.batch_first
            )
            for l in range(self.num_layers):
                setattr(mixed_lstm, f'weight_hh_l{l}', nn.Parameter(weight_dict[f'weight_hh_l{l}_mixed']))
                setattr(mixed_lstm, f'weight_ih_l{l}', nn.Parameter(weight_dict[f'weight_ih_l{l}_mixed']))
                if self.bias:
                    setattr(mixed_lstm, f'bias_hh_l{l}', nn.Parameter(weight_dict[f'bias_hh_l{l}_mixed']))
                    setattr(mixed_lstm, f'bias_ih_l{l}', nn.Parameter(weight_dict[f'bias_ih_l{l}_mixed']))
            return mixed_lstm(x)
    
    class TwoParamLSTM(SubspaceLSTM):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            for l in range(self.num_layers):
                setattr(self, f'weight_hh_l{l}_1', nn.Parameter(torch.zeros_like(getattr(self, f'weight_hh_l{l}'))))
                setattr(self, f'weight_ih_l{l}_1', nn.Parameter(torch.zeros_like(getattr(self, f'weight_ih_l{l}'))))
                if self.bias:
                    setattr(self, f'bias_hh_l{l}_1', nn.Parameter(torch.zeros_like(getattr(self, f'bias_hh_l{l}'))))
                    setattr(self, f'bias_ih_l{l}_1', nn.Parameter(torch.zeros_like(getattr(self, f'bias_ih_l{l}'))))
                    
    class LinesLSTM(TwoParamLSTM):
        def get_weight(self):
            weight_dict = dict()
            for l in range(self.num_layers):
                weight_dict[f'weight_hh_l{l}_mixed'] = (1 - self.alpha) * getattr(self, f'weight_hh_l{l}') + self.alpha * getattr(self, f'weight_hh_l{l}_1') 
                weight_dict[f'weight_ih_l{l}_mixed'] = (1 - self.alpha) * getattr(self, f'weight_ih_l{l}') + self.alpha * getattr(self, f'weight_ih_l{l}_1') 
                if self.bias:
                    weight_dict[f'bias_hh_l{l}_mixed'] = (1 - self.alpha) * getattr(self, f'bias_hh_l{l}') + self.alpha * getattr(self, f'bias_hh_l{l}_1') 
                    weight_dict[f'bias_ih_l{l}_mixed'] = (1 - self.alpha) * getattr(self, f'bias_ih_l{l}') + self.alpha * getattr(self, f'bias_ih_l{l}_1')
            return weight_dict
    
    # Embedding layer
    class SubspaceEmbedding(nn.Embedding):
        def forward(self, x):
            w = self.get_weight()
            x = F.embedding(
                x,
                w,
            )
            return x
    
    class TwoParamEmbedding(SubspaceEmbedding):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.weight1 = nn.Parameter(torch.zeros_like(self.weight))
                    
    class LinesEmbedding(TwoParamEmbedding):
        def get_weight(self):
            w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
            return w
    
    opened by vaseline555