STBP is a method for training SNNs on standard datasets via backpropagation.

Overview

Train SNNs with STBP in fp32 and at low bit-width (quantized).

Compared with deep neural networks (DNNs), spiking neural networks (SNNs) offer faster processing, lower energy consumption, and greater biological interpretability, and are seen as a promising route toward strong AI.

STBP (spatio-temporal backpropagation) trains SNNs on standard datasets via backpropagation. This repository lets you train SNNs with STBP and quantize them with quantization-aware training (QAT) for deployment on neuromorphic chips such as Loihi and Tianjic.

Usage

Download via GitHub:

git clone https://github.com/ZLkanyo009/STBP-train-and-compression.git

Example: defining SNN layers like ANN layers

Convert a layer to a spatiotemporal layer:

conv = nn.Conv2d(...)
conv_s = tdLayer(conv)
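
Conceptually, tdLayer runs the wrapped layer at every simulation time step, so tensors carry an extra trailing time dimension. A rough sketch of the idea (not the repository's exact implementation; assumes input shaped (batch, C, H, W, steps)):

import torch

def apply_per_step(layer, x):
    # apply the wrapped ANN layer independently at each time step
    return torch.stack([layer(x[..., t]) for t in range(x.shape[-1])], dim=-1)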

Define the LIF activation function just like ReLU:

spike = LIFSpike()
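
For reference, the iterative LIF update behind this activation looks roughly like the following (a sketch with hypothetical tau and v_th values; during training, STBP replaces the non-differentiable threshold with a surrogate gradient in the backward pass):

import torch

def lif_forward(x, tau=0.5, v_th=0.5):
    # x: (..., steps) input current; returns binary spike trains of the same shape
    u = torch.zeros_like(x[..., 0])            # membrane potential
    spike = torch.zeros_like(u)
    out = torch.zeros_like(x)
    for t in range(x.shape[-1]):
        u = tau * u * (1 - spike) + x[..., t]  # leak, reset on spike, integrate
        spike = (u > v_th).float()             # fire when the threshold is crossed
        out[..., t] = spike
    return out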

In the forward function, replace each layer's activation with the LIF activation and replace calls such as conv() with conv_s(); the SNN layer definition is then complete. Finally, decode the SNN's output with frequency (rate) coding, e.g. out = torch.sum(x, dim=2) / steps:

def forward(self, x):
    # x: (batch, channels, H, W, steps)
    x = self.conv1_s(x)
    x = self.spike(x)
    x = self.pool1_s(x)
    x = self.spike(x)
    x = x.view(x.shape[0], -1, x.shape[4])  # flatten, keeping the time dimension
    x = self.fc1_s(x)
    x = self.spike(x)
    out = torch.sum(x, dim=2) / steps  # frequency (rate) decoding
    return out
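
For completeness, the layers referenced in this forward pass would be wrapped in the module's __init__; a minimal sketch (the channel and feature sizes are hypothetical, and tdLayer / LIFSpike come from this repository):

import torch.nn as nn

class SimpleSNN(nn.Module):
    def __init__(self):
        super().__init__()
        # wrap ordinary ANN layers so they run over the time dimension
        self.conv1_s = tdLayer(nn.Conv2d(1, 16, 3, padding=1))
        self.pool1_s = tdLayer(nn.AvgPool2d(2))
        self.fc1_s = tdLayer(nn.Linear(16 * 14 * 14, 10))
        self.spike = LIFSpike()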

If a BN layer is required, use tdBatchNorm in place of nn.BatchNorm2d and pass it to tdLayer:

bn = tdBatchNorm(...)  # replaces nn.BatchNorm2d(...)
conv_s = tdLayer(conv, bn)
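
tdBatchNorm extends batch normalization to the temporal setting by computing statistics over the time dimension as well as the batch. Conceptually (a sketch of the normalization step only, omitting the learnable affine parameters and any threshold-dependent scaling):

import torch

def td_batchnorm(x, eps=1e-5):
    # x: (batch, C, H, W, steps); normalize per channel over batch, space and time
    mean = x.mean(dim=(0, 2, 3, 4), keepdim=True)
    var = x.var(dim=(0, 2, 3, 4), keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps)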

Training the FP32 Model

# Start training the fp32 model with:
# model_name can be ResNet18, CifarNet, ...
python main.py ResNet18 --dataset CIFAR10

# training with DDP:
python -m torch.distributed.launch main.py ResNet18 --local_rank 0 --dataset CIFAR10 --p DDP

# You can manually configure the training with:
python main.py ResNet18 --resume --lr 0.01

Training a Quantized Model

# Start training the quantized model with:
# model_name can be ResNet18, CifarNet, ...
python main.py ResNet18 --dataset CIFAR10 -q

# training with DDP:
python -m torch.distributed.launch main.py ResNet18 --local_rank 0 --dataset CIFAR10 -q --p DDP

# You can manually configure the training with:
python main.py ResNet18 -q --resume --bit 4 --lr 0.01

Accuracy

Model     Acc. (fp32)   Acc. (8-bit quantized)
MNISTNet  97.96%        97.57%
ResNet18  84.40%        84.23%


Comments
  • quantisation aware training flow

    Hi!

    As far as I can tell, you are doing the following during training, for each batch: forward pass, backward pass, weight update, then weight quantization.

    This means that you always have the forward pass on quantized weights, and can calculate the backward pass and weight update as usual. I think this is possibly problematic, though, because small updates on the weights always get reset to the closest quantized value and have no chance to accumulate over multiple passes to reach the next value. Is this the case? Or am I missing something?

    opened by widmerl 5
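
For context, the usual QAT answer to this concern is to keep full-precision "shadow" weights and fake-quantize them only in the forward pass, passing gradients straight through so that small updates can accumulate. A generic sketch of that pattern (not necessarily how this repository implements it):

import torch

class FakeQuantSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w, bit=8):
        # symmetric uniform quantization of the full-precision weights
        scale = w.abs().max() / (2 ** (bit - 1) - 1)
        return torch.round(w / scale) * scale

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: the gradient flows to the shadow weights
        return grad_output, None

# The forward pass uses w_q = FakeQuantSTE.apply(layer.weight, 8),
# while the optimizer keeps updating the full-precision layer.weight.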