A tutorial on "Bayesian Compression for Deep Learning" published at NIPS (2017).

Overview

Code release for "Bayesian Compression for Deep Learning"

In "Bayesian Compression for Deep Learning" we adopt a Bayesian view for the compression of neural networks. By revisiting the connection between the minimum description length principle and variational inference we are able to achieve up to 700x compression and up to 50x speed up (CPU to sparse GPU) for neural networks.

We visualize the learning process in the following figures for a dense network with hidden layers of 300 and 100 units. White represents redundancy, whereas red and blue represent positive and negative weights, respectively.

[Figures: first-layer weights (left) and second-layer weights (right)]

For dense networks it is also simple to reconstruct input feature importance. We show this below for a mask and 5 randomly chosen digits. [Figure: input feature importance mask and 5 randomly chosen digits]

Results

| Model | Method | Error [%] | Compression after pruning | Compression after precision reduction |
|---|---|---|---|---|
| LeNet-5-Caffe | DC | 0.7 | 6* | - |
| | DNS | 0.9 | 55* | - |
| | SWS | 1.0 | 100* | - |
| | Sparse VD | 1.0 | 63* | 228 |
| | BC-GNJ | 1.0 | 108* | 361 |
| | BC-GHS | 1.0 | 156* | 419 |
| VGG | BC-GNJ | 8.6 | 14* | 56 |
| | BC-GHS | 9.0 | 18* | 59 |

Usage

We provide a PyTorch implementation of fully connected and convolutional layers with the group normal-Jeffreys prior (a.k.a. group variational dropout), available via:

    import BayesianLayers

The layers can then be straightforwardly included as follows:

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            # activation
            self.relu = nn.ReLU()
            # layers
            self.fc1 = BayesianLayers.LinearGroupNJ(28 * 28, 300, clip_var=0.04)
            self.fc2 = BayesianLayers.LinearGroupNJ(300, 100)
            self.fc3 = BayesianLayers.LinearGroupNJ(100, 10)
            # layers including kl_divergence
            self.kl_list = [self.fc1, self.fc2, self.fc3]

        def forward(self, x):
            x = x.view(-1, 28 * 28)
            x = self.relu(self.fc1(x))
            x = self.relu(self.fc2(x))
            return self.fc3(x)

        def kl_divergence(self):
            KLD = 0
            for layer in self.kl_list:
                KLD += layer.kl_divergence()
            return KLD

The only additional effort is to include the KL divergence in the objective. This is necessary if we want to optimize the variational lower bound that leads to sparse solutions:

    N = 60000.  # number of MNIST training examples
    discrimination_loss = nn.functional.cross_entropy

    def objective(output, target, kl_divergence):
        discrimination_error = discrimination_loss(output, target)
        return discrimination_error + kl_divergence / N
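
A minimal training-step sketch, assuming the Net class and objective defined above together with a standard PyTorch optimizer (the optimizer choice, learning rate, and data loading are illustrative, not prescribed by this repo):

    import torch.optim as optim

    model = Net()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    def train_step(data, target):
        # data and target would come from a regular MNIST DataLoader
        optimizer.zero_grad()
        output = model(data)
        loss = objective(output, target, model.kl_divergence())
        loss.backward()
        optimizer.step()
        return loss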

Run an example

We provide a simple example, the LeNet-300-100 trained with the group normal-Jeffreys prior:

    python example.py

Retraining a regular neural network

Instead of training a network from scratch, we often need to compress an already existing network. In this case, we can simply initialize the weights with those of the pretrained network:

    BayesianLayers.LinearGroupNJ(28*28, 300, init_weight=pretrained_weight, init_bias=pretrained_bias)
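
For example, a sketch of pulling the weights out of an ordinary pretrained nn.Linear layer (the pretrained_net object and its attribute names are hypothetical):

    # hypothetical pretrained LeNet-300-100 with ordinary nn.Linear layers
    pretrained_fc1 = pretrained_net.fc1
    fc1 = BayesianLayers.LinearGroupNJ(28 * 28, 300,
                                       init_weight=pretrained_fc1.weight.data,
                                       init_bias=pretrained_fc1.bias.data)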

Reference

The paper "Bayesian Compression for Deep Learning" has been accepted to NIPS 2017. Please cite us:

@article{louizos2017bayesian,
  title={Bayesian Compression for Deep Learning},
  author={Louizos, Christos and Ullrich, Karen and Welling, Max},
  journal={Conference on Neural Information Processing Systems (NIPS)},
  year={2017}
}
Comments
  • Missing factor 0.5 in KL-divergence?

    The convolution layer seems to be missing a factor of 0.5 in front of the log-variance term in the KL divergence. The dense layer does have this factor.

    https://github.com/KarenUllrich/Tutorial_BayesianCompressionForDL/blob/a7d3d83410788b2b8ebf76de948af07fbe4922e2/BayesianLayers.py#L264

    opened by jheek 3
  • Is it necessary to prune the bias?

    Hi @KarenUllrich ,

    I noticed that your code does not eliminate the corresponding bias when pruning weights, which I think may lead to reported performance that is higher than what is actually achievable. Maybe I have missed some important part of your code, or perhaps the influence of the bias is not that essential?

    Looking forward to discussing this with you!

    opened by remicongee 1
  • Do you have any tricks for choosing the threshold for each layer?

    After the model has been trained with the group horseshoe prior (half-Cauchy scale priors), I want to prune the parameters, but I find it hard to choose the pruning threshold. Do you have any tricks for picking the threshold? Thanks very much. @KarenUllrich

    (An illustrative thresholding sketch follows this item.)

    opened by larenzhang 1
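
As an illustration of the thresholding question above (a sketch, not the repo's official recipe): the layers expose get_log_dropout_rates(), and units whose log dropout rate exceeds a hand-tuned per-layer threshold can be pruned. The threshold values below are hypothetical and would normally be tuned against validation accuracy, assuming a model like the Net class above with a kl_list of its Bayesian layers:

    import numpy as np

    # hypothetical per-layer thresholds for a LeNet-300-100-style model
    thresholds = [-3.0, -3.0, -5.0]

    for layer, threshold in zip(model.kl_list, thresholds):
        log_alpha = layer.get_log_dropout_rates().cpu().data.numpy()
        keep = log_alpha <= threshold  # small dropout rate -> keep the unit
        print('keeping {} of {} units'.format(int(keep.sum()), keep.size))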
  • KL divergence approx for Linear appears to be wrong

    The Linear layer uses this:

    KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
    

    But the convolutional layers use this:

    KLD_element = -self.weight_logvar + 0.5 * (self.weight_logvar.exp().pow(2) + self.weight_mu.pow(2)) - 0.5
    

    The second appears to match equation 8 of the paper, so is it just a mistake in the Linear layer?

    opened by gngdb 1
  • Small error in the kl-divergence

    Unless I am missing something, there is a slight error in the kl_divergence() definition in the _ConvNdGroupNJ class.

    KLD_element = -self.weight_logvar + 0.5 * (self.weight_logvar.exp().pow(2) + self.weight_mu.pow(2)) - 0.5
    

    treats self.weight_logvar as if it were the log standard deviation instead of the log variance. The correct expression should be (as in LinearGroupNJ; a quick numerical check follows this item):

    KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
    
    opened by manuelhaussmann 1
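
A quick numerical check of the point above (a sketch, not part of the repo's code): the closed-form KL divergence between N(mu, sigma^2) and N(0, 1) is 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2), which is exactly the LinearGroupNJ expression once sigma^2 = exp(weight_logvar):

    import torch

    mu = torch.tensor(0.3)
    logvar = torch.tensor(-1.2)  # arbitrary test values

    # closed-form KL( N(mu, sigma^2) || N(0, 1) ) with sigma^2 = exp(logvar)
    kl_exact = 0.5 * (logvar.exp() + mu.pow(2) - 1.0 - logvar)

    # expression used in LinearGroupNJ
    kl_linear = -0.5 * logvar + 0.5 * (logvar.exp() + mu.pow(2)) - 0.5

    print(torch.allclose(kl_exact, kl_linear))  # True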
  • Anyone encounter loss nan?

    Training is fine for small-scale models like an MLP or LeNet-5. However, for VGG-16 / ResNet-18 the loss always becomes NaN. The model structure configuration is below:

    cfg = {
        'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
        'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
    }

    class VGG_CIFAR10_BAY(nn.Module):
        kl_list = []

        def __init__(self, vgg_name):
            super(VGG_CIFAR10_BAY, self).__init__()
            self.features = self._make_layers(cfg[vgg_name])
            linear_index = BayesianLayer.LinearGroupNJ(512, 10, clip_var=0.04, cuda=True)
            self.classifier = linear_index
            self.kl_list.append(linear_index)

        def forward(self, x):
            out = self.features(x)
            out = out.view(out.size(0), -1)
            out = self.classifier(out)
            return out

        def _make_layers(self, cfg):
            layers = []
            in_channels = 3
            for x in cfg:
                if x == 'M':
                    layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
                else:
                    conv_index = BayesianLayer.Conv2dGroupNJ(in_channels, x, kernel_size=3, padding=1,
                                                             clip_var=0.04, cuda=True)
                    layers += [conv_index, nn.BatchNorm2d(x), nn.ReLU(inplace=True)]
                    self.kl_list.append(conv_index)
                    in_channels = x
            layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
            return nn.Sequential(*layers)

        def get_masks(self, thresholds):
            weight_masks = []
            mask = None
            layers = self.kl_list
            for i, (layer, threshold) in enumerate(zip(layers, thresholds)):
                # compute dropout mask
                if len(layer.weight_mu.shape) > 2:
                    if mask is None:
                        mask = [True] * layer.in_channels
                    else:
                        mask = np.copy(next_mask)

                    log_alpha = layers[i].get_log_dropout_rates().cpu().data.numpy()
                    next_mask = log_alpha <= thresholds[i]

                    weight_mask = np.expand_dims(mask, axis=0) * np.expand_dims(next_mask, axis=1)
                    weight_mask = weight_mask[:, :, None, None]
                else:
                    if mask is None:
                        log_alpha = layer.get_log_dropout_rates().cpu().data.numpy()
                        mask = log_alpha <= threshold
                    elif len(weight_mask.shape) > 2:
                        temp = next_mask.repeat(layer.in_features / next_mask.shape[0])
                        log_alpha = layer.get_log_dropout_rates().cpu().data.numpy()
                        mask = log_alpha <= threshold
                        # mask = mask | temp  # upper bound for number of weights at first fully connected layer
                        mask = mask & temp  # lower bound for number of weights at fully connected layer
                    else:
                        mask = np.copy(next_mask)

                    try:
                        log_alpha = layers[i + 1].get_log_dropout_rates().cpu().data.numpy()
                        next_mask = log_alpha <= thresholds[i + 1]
                    except:
                        # must be the last mask
                        next_mask = np.ones(10)

                    weight_mask = np.expand_dims(mask, axis=0) * np.expand_dims(next_mask, axis=1)
                weight_masks.append(weight_mask.astype(np.float))
            return weight_masks

        def model_kl_div(self):
            KLD = 0
            for layer in self.kl_list:
                KLD += layer.layer_kl_div()
            return KLD

    Is this caused by high variance? I have tried clipping the variance, but it doesn't help...

    opened by kuanzi 2
  • Samples of compression for LeNet

    Hi author

    Thanks for sharing the code; I am very interested in this work. When I test compressing LeNet, it raises a "dimension not match" error. Could you share an example of compressing a neural network with convolutional layers?

    opened by Lyken17 10