Pytorch Implementation of Spiking Neural Networks Calibration, ICML 2021




Feature Comparison of SNN calibration:

Features | SNN Direct Training | ANN-SNN Conversion | SNN Calibration
Accuracy (T < 100) | High | Low | High
Scalability to ImageNet | Tiny | Large | Large
Training Speed | Slow | Fast | Fast
Required Data | Full set (1.2M for ImageNet) | ~1000 | ~1000
Inference Speed | Fast | Slow | Fast


Requirements: PyTorch 1.8

For ImageNet experiments, please make sure that you can initialize a distributed environment.

For CIFAR experiments, one GPU would suffice.

Pre-training the ANN on CIFAR10 & CIFAR100

Train an ANN model with

python CIFAR/ --dataset CIFAR10 --arch VGG16 --dpath PATH/TO/DATA --usebn

Pre-trained results:

Dataset | Model | Random Seed | Accuracy (%)
CIFAR10 | VGG16 | 1000 | 95.76
CIFAR10 | ResNet-20 | 1000 | 95.68
CIFAR100 | VGG16 | 1000 | 77.98
CIFAR100 | ResNet-20 | 1000 | 76.52

SNN Calibration on CIFAR10 & CIFAR100

Calibrate an SNN with

python CIFAR/ --dataset CIFAR10 --arch VGG16 --T 16 --usebn --calib advanced --dpath PATH/TO/DATA

--T is the number of time steps and --calib selects the calibration method. Use none, light, or advanced for the experiments.
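A single integrate-and-fire (IF) neuron shows what --T controls. The sketch below is not code from this repository. It assumes a constant input, a threshold of 1.0, and reset-by-subtraction, and `if_neuron_rate` is a hypothetical helper. Across T steps, the spike count times the threshold divided by T approximates a ReLU clipped at the threshold. For inputs between 0 and the threshold, the quantization error stays below threshold/T.

```python
def if_neuron_rate(x: float, T: int, threshold: float = 1.0) -> float:
    """Simulate an integrate-and-fire neuron for T steps with constant input x.

    Uses reset-by-subtraction ("soft reset") and returns the spike rate scaled
    by the threshold, i.e. the neuron's approximation of clip(x, 0, threshold).
    """
    v = 0.0       # membrane potential
    spikes = 0
    for _ in range(T):
        v += x                # integrate the input current
        if v >= threshold:    # at most one spike per time step
            spikes += 1
            v -= threshold    # soft reset keeps the residual charge
    return spikes * threshold / T


# Same input, increasing number of time steps.
for T in (4, 16, 64):
    rate = if_neuron_rate(0.3, T)
    print(f"T={T:2d}  rate={rate:.4f}  |error|={abs(rate - 0.3):.4f}")
```

A larger T lowers the error for the same input but makes inference slower. Because of this trade-off, results are reported at a fixed T.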

The calibration runs 5 times and reports the mean accuracy together with the standard deviation.
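Python's `statistics` module can produce a mean and standard deviation over repeated runs. `summarize_runs` is a hypothetical helper, and the accuracies below are placeholders rather than results. This page does not say which estimator the scripts use, so the helper returns both the sample (n - 1) and the population (n) standard deviation.

```python
import statistics
from typing import Sequence, Tuple


def summarize_runs(accuracies: Sequence[float]) -> Tuple[float, float, float]:
    """Return (mean, sample std, population std) of per-run accuracies."""
    if len(accuracies) < 2:
        raise ValueError("need at least two runs to measure spread")
    mean = statistics.mean(accuracies)
    sample_std = statistics.stdev(accuracies)       # divides by n - 1
    population_std = statistics.pstdev(accuracies)  # divides by n
    return mean, sample_std, population_std


# Placeholder accuracies for five runs, not real results.
runs = [90.0, 91.0, 92.0, 93.0, 94.0]
mean, s_std, p_std = summarize_runs(runs)
print(f"mean={mean:.2f}  std(n-1)={s_std:.2f}  std(n)={p_std:.2f}")
```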

Example results:

Architecture | Dataset | T | Random Seed | Calibration | Mean Acc. (%) | Std. (%)
VGG16 | CIFAR10 | 16 | 1000 | None | 64.52 | 4.12
VGG16 | CIFAR10 | 16 | 1000 | Light | 93.30 | 0.08
VGG16 | CIFAR10 | 16 | 1000 | Advanced | 93.65 | 0.25
ResNet-20 | CIFAR10 | 16 | 1000 | None | 67.88 | 3.63
ResNet-20 | CIFAR10 | 16 | 1000 | Light | 93.89 | 0.20
ResNet-20 | CIFAR10 | 16 | 1000 | Advanced | 94.33 | 0.12
VGG16 | CIFAR100 | 16 | 1000 | None | 2.69 | 0.76
VGG16 | CIFAR100 | 16 | 1000 | Light | 65.26 | 0.99
VGG16 | CIFAR100 | 16 | 1000 | Advanced | 70.91 | 0.65
ResNet-20 | CIFAR100 | 16 | 1000 | None | 39.27 | 2.85
ResNet-20 | CIFAR100 | 16 | 1000 | Light | 73.89 | 0.15
ResNet-20 | CIFAR100 | 16 | 1000 | Advanced | 74.48 | 0.16
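Most of the gap between the None rows and the calibrated rows comes from error introduced when spike rates replace ANN activations. The sketch below illustrates one way to shrink that error, a simple per-channel bias correction. It measures the mean gap between ReLU outputs and simulated IF firing rates over a small calibration set, then adds that gap to the bias. This is a simplified stand-in with hypothetical helpers (`if_rate`, `mean_gap`, `calibrate_bias`), one channel, and constant inputs. It is not the light or advanced calibration implemented in this repository.

```python
from typing import List


def relu(x: float) -> float:
    return max(x, 0.0)


def if_rate(x: float, T: int, threshold: float = 1.0) -> float:
    """Threshold-scaled firing rate of an IF neuron with soft reset."""
    v, spikes = 0.0, 0
    for _ in range(T):
        v += x
        if v >= threshold:
            spikes += 1
            v -= threshold
    return spikes * threshold / T


def mean_gap(pre_acts: List[float], ann_bias: float, snn_bias: float, T: int) -> float:
    """Mean of ANN output minus SNN rate over a calibration set."""
    gaps = [relu(x + ann_bias) - if_rate(x + snn_bias, T) for x in pre_acts]
    return sum(gaps) / len(gaps)


def calibrate_bias(pre_acts: List[float], bias: float, T: int) -> float:
    """Fold the mean ANN-SNN gap, measured at the original bias, into the bias."""
    return bias + mean_gap(pre_acts, bias, bias, T)


# Toy calibration set: pre-activations spread evenly over [0, 1).
calib_set = [i / 100 for i in range(100)]
T = 8
new_bias = calibrate_bias(calib_set, bias=0.0, T=T)
print(f"mean gap before: {mean_gap(calib_set, 0.0, 0.0, T):+.4f}")
print(f"mean gap after:  {mean_gap(calib_set, 0.0, new_bias, T):+.4f}")
```

At T=8, the uncalibrated rates fall short by roughly half a quantization step on average. The shifted bias removes most of that offset.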

Pre-training ANN on ImageNet

To be updated.

You might also like...
An implementation demo of the ICLR 2021 paper Neural Attention Distillation: Erasing Backdoor Triggers from Deep Neural Networks in PyTorch.

Neural Attention Distillation This is an implementation demo of the ICLR 2021 paper Neural Attention Distillation: Erasing Backdoor Triggers from Deep

Official implementation of "SinIR: Efficient General Image Manipulation with Single Image Reconstruction" (ICML 2021)

SinIR (Official Implementation) Requirements To install requirements: pip install -r requirements.txt We used Python 3.7.4 and f-strings which are in

Implementation of Learning Gradient Fields for Molecular Conformation Generation (ICML 2021).

[PDF] | [Slides] The official implementation of Learning Gradient Fields for Molecular Conformation Generation (ICML 2021 Long talk) Installation Inst

Implementation of Self-supervised Graph-level Representation Learning with Local and Global Structure (ICML 2021).

Self-supervised Graph-level Representation Learning with Local and Global Structure Introduction This project is an implementation of ``Self-supervise

This repo contains the implementation of the algorithm proposed in Off-Belief Learning, ICML 2021.

Off-Belief Learning Introduction This repo contains the implementation of the algorithm proposed in Off-Belief Learning, ICML 2021. Environment Setup

PyTorch implementation of SCAFFOLD (Stochastic Controlled Averaging for Federated Learning, ICML 2020).

Scaffold-Federated-Learning PyTorch implementation of SCAFFOLD (Stochastic Controlled Averaging for Federated Learning, ICML 2020). Environment numpy=

TensorFlow code for the neural network presented in the paper: "Structural Language Models of Code" (ICML'2020)

SLM: Structural Language Models of Code This is an official implementation of the model described in: "Structural Language Models of Code" [PDF] To ap

Code for the ICML 2021 paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision"

ViLT Code for the paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision" Install pip install -r requirements.txt pip


  • SNN Calibration of Original ResNet20

    Dear Yuhang,

    I have noticed that you are using a ResNet20 for CIFAR10 with 11.3 million parameters. In the original ResNet publication by He et al. [1], ResNet20 for CIFAR10 is defined with 0.27 million parameters. I know it is somewhat "conventional" to use the ResNet20 implementation you are using; the problem is that I am really interested in the one with the smaller number of parameters :P

    I have defined the "original" ResNet20 for CIFAR10 with 0.27M parameters as shown below. I added the file under models in your repository and first ran the ANN training and then the SNN calibration on it:

    python -m SNN_Calibration.CIFAR.main_train --dataset CIFAR10 --arch orgres20 --dpath 'datasets/CIFAR10/' --usebn
    python -m SNN_Calibration.CIFAR.main_calibration --dataset CIFAR10 --arch orgres20 --T 16 --usebn --calib advanced --dpath 'datasets/CIFAR10/'

    The ANN training works well and reaches 93.5% accuracy, but for some reason SNN calibration does not work on the network below and yields only 20% accuracy. Could you help me get SNN calibration working on it? :) I would very much appreciate understanding the issue here.

    [1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, pages 770–778, 2016.

    '''
    ResNet20 on CIFAR10 with the correct number of parameters (0.27M) as in the original publication [1].
    [1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.
    [2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016.
    '''
    import torch
    import torch.nn as nn
    import math
    # @anna: I fixed the following relative imports
    from ...CIFAR.models.utils import AvgPoolConv, StraightThrough
    from ...CIFAR.models.spiking_layer import SpikeModel, SpikeModule, Union
    import torch.nn.functional as F
    from .resnet import SpikeBasicBlock
    def conv3x3(in_planes, out_planes, stride=1):
        " 3x3 convolution with padding"
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
    class BasicBlock(nn.Module):
        expansion = 1
        def __init__(self, inplanes, planes, stride=1, downsample=None):
            super(BasicBlock, self).__init__()
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.bn1 = BN(planes)
            self.relu1 = ReLU(inplace=True)
            self.conv2 = conv3x3(planes, planes)
            self.bn2 = BN(planes)
            self.downsample = downsample
            self.stride = stride
            self.relu2 = ReLU(inplace=True)
        def forward(self, x):
            residual = x
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu1(out)
            out = self.conv2(out)
            out = self.bn2(out)
            if self.downsample is not None:
                residual = self.downsample(x)
            out += residual
            out = self.relu2(out)
            return out
    class Org_ResNet_Cifar_Modified(nn.Module):
        def __init__(self, block, layers, num_classes=10, use_bn=True):
            super(Org_ResNet_Cifar_Modified, self).__init__()
            global BN
            BN = nn.BatchNorm2d if use_bn else StraightThrough
            global ReLU
            ReLU = nn.ReLU
            self.inplanes = 64
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
                nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False),
            )
            self.layer1 = self._make_layer(block, 16, layers[0], stride=1)
            self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
            self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
            #self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
            self.avgpool = AvgPoolConv(kernel_size=4, stride=1, input_channel=64)
            self.fc_save = nn.Linear(64, num_classes)
            #self.fc = nn.Linear(64, num_classes)
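            # Weight init: He-normal for convs (AvgPoolConv is skipped),
            # BN scale 1 / shift 0, and N(0, (1/fan_in)^2) for the linear layer.
            for m in self.modules():
                if isinstance(m, nn.Conv2d) and not isinstance(m, AvgPoolConv):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                elif isinstance(m, nn.Linear):
                    n = m.weight.size(1)
                    m.weight.data.normal_(0, 1.0 / float(n))
                    m.bias.data.zero_()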
        def _make_layer(self, block, planes, blocks, stride=1):
            downsample = None
            if stride != 1 or self.inplanes != planes * block.expansion:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                    BN(planes * block.expansion)
                )
            layers = []
            layers.append(block(self.inplanes, planes, stride, downsample))
            self.inplanes = planes * block.expansion
            for _ in range(1, blocks):
                layers.append(block(self.inplanes, planes))
            return nn.Sequential(*layers)
        def forward(self, x):
            x = self.conv1(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            #x = F.avg_pool2d(x, x.size()[3])
            #x = self.layer4(x)
            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            x = self.fc_save(x)
            return x
    def org_resnet20(**kwargs):
        model = Org_ResNet_Cifar_Modified(BasicBlock, [3, 3, 3], **kwargs)
        return model
    res_specials = {BasicBlock: SpikeBasicBlock}
    opened by annahambi 6
  • Pre-training ANN model

    Hello, I tested the VGG16 ANN model without usebn that you provided on the ImageNet dataset, but the accuracy is only 55.074%. Could you please update the pre-trained ANN model?

    opened by lanyx7 5
  • SNN simulation length parameter seems unused during calibration

    Dear @yhhhli

    This is an awesome repository and thank you so much for publishing it.

    My question/issue concerns the --T parameter of the SNN calibration. In the CIFAR SNN calibration code, the SNN simulation length appears to be hardcoded as sim_length = 32 (line 71), and args.T is never used.

    Can you comment on this? Your table of results reports T=16. How do you make that value take effect during testing (not necessarily during calibration itself)?

    opened by annahambi 1
  • Training Code

    Hello, when I run the program, the error "AttributeError: Can't pickle local object 'SubPolicy.__init__.<locals>.<lambda>'" occurs. Could you help me solve it? Thank you very much!

    opened by JominWink 10
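The "SNN Calibration of Original ResNet20" issue cites 0.27M parameters for the original ResNet20, and a layer-by-layer count confirms that figure. The hypothetical helper below is not part of this repository. It assumes the CIFAR design from He et al.: a single 16-channel 3×3 stem, three stages of 3 basic blocks with 16, 32, and 64 filters, bias-free convolutions, an affine BatchNorm after every convolution, and a linear classifier with bias. With parameter-free zero-padded shortcuts (option A) the count is 269,722. Using 1×1 projection shortcuts (option B) adds 2,752.

```python
from typing import Sequence


def conv_params(k: int, c_in: int, c_out: int) -> int:
    """Weight count of a bias-free k x k convolution."""
    return k * k * c_in * c_out


def bn_params(channels: int) -> int:
    """Affine BatchNorm parameters (gamma and beta)."""
    return 2 * channels


def resnet_cifar_params(blocks_per_stage: int = 3,
                        widths: Sequence[int] = (16, 32, 64),
                        num_classes: int = 10,
                        projection_shortcut: bool = False) -> int:
    """Count trainable parameters of a CIFAR-style ResNet (depth 6n + 2).

    projection_shortcut=False: zero-padded identity shortcuts (option A, no
    parameters). True: 1x1 conv + BN wherever the width changes (option B).
    """
    total = conv_params(3, 3, widths[0]) + bn_params(widths[0])  # stem
    c_in = widths[0]
    for c_out in widths:
        for b in range(blocks_per_stage):
            block_in = c_in if b == 0 else c_out
            total += conv_params(3, block_in, c_out) + bn_params(c_out)  # conv1 + BN
            total += conv_params(3, c_out, c_out) + bn_params(c_out)     # conv2 + BN
            if projection_shortcut and block_in != c_out:
                total += conv_params(1, block_in, c_out) + bn_params(c_out)
        c_in = c_out
    total += widths[-1] * num_classes + num_classes  # final linear layer
    return total


print(resnet_cifar_params())                          # 269722 (about 0.27M)
print(resnet_cifar_params(projection_shortcut=True))  # 272474
```

The model posted in that issue has a different stem: three 64-channel convolutions, the last with stride 2, where the paper uses one 16-channel convolution. Its parameter count therefore comes out above this figure.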
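In the "SNN simulation length parameter seems unused during calibration" issue, the general fix is to read the simulation length from the parsed arguments rather than a constant. The parser below is a hypothetical subset of a calibration script's options. Only the `--T` and `--calib` flags and the `sim_length` name appear on this page; the rest is assumed.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical subset of a calibration script's command-line options."""
    parser = argparse.ArgumentParser(description="SNN calibration (sketch)")
    parser.add_argument("--T", type=int, required=True,
                        help="number of SNN simulation time steps")
    parser.add_argument("--calib", default="none",
                        choices=["none", "light", "advanced"],
                        help="calibration method")
    return parser


def resolve_sim_length(argv: Optional[List[str]] = None) -> int:
    """Parse argv and return the simulation length taken from --T."""
    args = build_parser().parse_args(argv)
    sim_length = args.T  # instead of a hardcoded value such as 32
    if sim_length <= 0:
        raise ValueError("--T must be a positive integer")
    return sim_length


print(resolve_sim_length(["--T", "16", "--calib", "advanced"]))  # 16
```

A script would then presumably pass the resulting `sim_length` to every place that runs the SNN forward pass, covering both calibration and testing.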
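The "Training Code" issue reports a pickling error for a lambda local to `SubPolicy`. When DataLoader workers start with the spawn method (the default on Windows and macOS), the dataset and its transforms have to be pickled, and a lambda created inside a method cannot be pickled. The sketch below reproduces the pattern with hypothetical classes. It then shows one workaround: swap the lambda for a module-level function. Setting `num_workers=0` on the DataLoader also avoids pickling, though it slows data loading.

```python
import pickle


class LambdaPolicy:
    """Reproduces the failing pattern: a lambda created inside __init__."""

    def __init__(self, magnitude: float):
        # A lambda defined here is a local object; pickle cannot find it by name.
        self.op = lambda img: img * magnitude

    def __call__(self, img: float) -> float:
        return self.op(img)


def scale(img: float, magnitude: float) -> float:
    """Module-level function: pickled by reference, so workers can re-import it."""
    return img * magnitude


class PicklablePolicy:
    """Same behavior, but stores a module-level function plus its argument."""

    def __init__(self, magnitude: float):
        self.magnitude = magnitude
        self.fn = scale

    def __call__(self, img: float) -> float:
        return self.fn(img, self.magnitude)


def can_pickle(obj: object) -> bool:
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        return False


print(can_pickle(LambdaPolicy(2.0)))     # False
print(can_pickle(PicklablePolicy(2.0)))  # True
```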
Yuhang Li
Research Intern at @SenseTime Group Limited
PyTorch implementation of Spiking Neural Networks trained on surrogate gradient & BPTT using snntorch.

snn-localization repo PyTorch implementation of Spiking Neural Networks trained on surrogate gradient & BPTT using snntorch. Install Dependencies Orig

Sami BARCHID 1 Jan 6, 2022
Deep learning for spiking neural networks

A deep learning library for spiking neural networks. Norse aims to exploit the advantages of bio-inspired neural components, which are sparse and even

Electronic Vision(s) Group — BrainScaleS Neuromorphic Hardware 59 Nov 28, 2022
A machine learning library for spiking neural networks. Supports training with both torch and jax pipelines, and deployment to neuromorphic hardware.

Rockpool Rockpool is a Python package for developing signal processing applications with spiking neural networks. Rockpool allows you to build network

SynSense 21 Dec 14, 2022
Deep and online learning with spiking neural networks in Python

Introduction The brain is the perfect place to look for inspiration to develop more efficient neural networks. One of the main differences with modern

Jason Eshraghian 447 Jan 3, 2023
S2-BNN: Bridging the Gap Between Self-Supervised Real and 1-bit Neural Networks via Guided Distribution Calibration (CVPR 2021)

S2-BNN (Self-supervised Binary Neural Networks Using Distillation Loss) This is the official pytorch implementation of our paper: "S2-BNN: Bridging th

Zhiqiang Shen 52 Dec 24, 2022
Source code of NeurIPS 2021 Paper ''Be Confident! Towards Trustworthy Graph Neural Networks via Confidence Calibration''

CaGCN This repo is for source code of NeurIPS 2021 paper "Be Confident! Towards Trustworthy Graph Neural Networks via Confidence Calibration". Paper L

null 6 Dec 19, 2022
Spiking Neural Network for Computer Vision using SpikingJelly framework and Pytorch-Lightning

Sami BARCHID 2 Oct 20, 2022
Code for ICML 2021 paper: How could Neural Networks understand Programs?

OSCAR This repository contains the source code of our ICML 2021 paper How could Neural Networks understand Programs?. Environment Run following comman

Dinglan Peng 115 Dec 17, 2022
Neural-Pull: Learning Signed Distance Functions from Point Clouds by Learning to Pull Space onto Surfaces(ICML 2021)

Neural-Pull: Learning Signed Distance Functions from Point Clouds by Learning to Pull Space onto Surfaces(ICML 2021) This repository contains the code

null 149 Dec 15, 2022
[CVPR 2022] Official code for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration"

MDCA Calibration This is the official PyTorch implementation for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved

MDCA Calibration 21 Dec 22, 2022