Deep and online learning with spiking neural networks in Python

Overview

Introduction

https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_scaled.png?raw=true

The brain is the perfect place to look for inspiration to develop more efficient neural networks. One of the main differences with modern deep learning is that the brain encodes information in spikes rather than continuous activations. snnTorch is a Python package for performing gradient-based learning with spiking neural networks. It extends the capabilities of PyTorch, taking advantage of its GPU accelerated tensor computation and applying it to networks of spiking neurons. Pre-designed spiking neuron models are seamlessly integrated within the PyTorch framework and can be treated as recurrent activation units.

https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/spike_excite_alpha_ps2.gif?raw=true

snnTorch Structure

snnTorch contains the following components:

Component             Description
snntorch              a spiking neuron library like torch.nn, deeply integrated with autograd
snntorch.backprop     variations of backpropagation commonly used with SNNs
snntorch.functional   common arithmetic operations on spikes, e.g., loss, regularization, etc.
snntorch.spikegen     a library for spike generation and data conversion
snntorch.spikeplot    visualization tools for spike-based data using matplotlib and celluloid
snntorch.spikevision  contains popular neuromorphic datasets
snntorch.surrogate    optional surrogate gradient functions
snntorch.utils        dataset utility functions

snnTorch is designed to be intuitively used with PyTorch, as though each spiking neuron were simply another activation in a sequence of layers. It is therefore agnostic to fully-connected layers, convolutional layers, residual connections, etc.

At present, the neuron models are represented by recursive functions, which removes the need to store membrane potential traces for all neurons in a system in order to calculate the gradient. These lean requirements enable small and large networks to be viably trained on CPU, where needed. Provided that the network models and tensors are loaded onto CUDA, snnTorch takes advantage of GPU acceleration in the same way as PyTorch.
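
The recursive form can be illustrated in plain Python. This is a minimal sketch, assuming a leaky integrate-and-fire update with reset by subtraction and a threshold of 1.0; the name `lif_step` and the input values are illustrative and are not part of the snnTorch API.

```python
def lif_step(mem, x, beta=0.5, threshold=1.0):
    """Advance a leaky integrate-and-fire neuron by one time step (illustrative).

    Only the previous membrane potential `mem` is needed to compute the next one.
    """
    # The reset signal depends on whether the previous potential crossed threshold
    reset = 1.0 if mem > threshold else 0.0
    # Decay the old potential, add the new input, subtract the reset
    mem = beta * mem + x - reset * threshold
    spk = 1 if mem > threshold else 0
    return spk, mem


mem = 0.0
spikes = []
for x in [0.6, 0.6, 0.6, 0.6, 0.0, 0.0]:
    spk, mem = lif_step(mem, x)
    spikes.append(spk)
```

In this sketch the only state carried from one step to the next is the single value `mem`.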

Citation

If you find snnTorch useful in your work, please consider citing the following source:

Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu, “Training Spiking Neural Networks Using Lessons From Deep Learning”. arXiv preprint arXiv:2109.12894, September 2021.

@article{eshraghian2021training,
title={Training spiking neural networks using lessons from deep learning},
author={Eshraghian, Jason K and Ward, Max and Neftci, Emre and Wang, Xinxin
and Lenz, Gregor and Dwivedi, Girish and Bennamoun, Mohammed and Jeong, Doo Seok
and Lu, Wei D},
journal={arXiv preprint arXiv:2109.12894},
year={2021}
}

Requirements

The following packages need to be installed to use snnTorch:

  • torch >= 1.1.0
  • numpy >= 1.17
  • pandas
  • matplotlib
  • math (part of the Python standard library, so no separate installation is needed)

They are automatically installed if snnTorch is installed using the pip command. Ensure the correct version of torch is installed for your system to enable CUDA compatibility.

Installation

Run the following to install:

$ pip install snntorch

To install snnTorch from source instead:

$ git clone https://github.com/jeshraghian/snnTorch
$ cd snnTorch
$ python setup.py install

API & Examples

A complete API is available here. Examples, tutorials and Colab notebooks are provided.

Quickstart

Here are a few ways you can get started with snnTorch:

Open In Colab

For a quick example to run snnTorch, see the following snippet, or test the quickstart notebook above:

import torch, torch.nn as nn
import snntorch as snn
from snntorch import surrogate

num_steps = 25 # number of time steps
batch_size = 1
beta = 0.5  # neuron decay rate
spike_grad = surrogate.fast_sigmoid()

net = nn.Sequential(
      nn.Conv2d(1, 8, 5),
      nn.MaxPool2d(2),
      snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad),
      nn.Conv2d(8, 16, 5),
      nn.MaxPool2d(2),
      snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad),
      nn.Flatten(),
      nn.Linear(16 * 4 * 4, 10),
      snn.Leaky(beta=beta, init_hidden=True, spike_grad=spike_grad, output=True)
      )

# random input data
data_in = torch.rand(num_steps, batch_size, 1, 28, 28)

spike_recording = []

for step in range(num_steps):
    spike, state = net(data_in[step])
    spike_recording.append(spike)

If you're feeling lazy and want the training process to be taken care of:

import snntorch.functional as SF
from snntorch import backprop

# correct class should fire 80% of the time
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999))

# train for one epoch using the backprop through time algorithm
# assume train_loader is a DataLoader with time-varying input
avg_loss = backprop.BPTT(net, train_loader, num_steps=num_steps,
                        optimizer=optimizer, criterion=loss_fn)
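
The comment about firing rates can be made concrete with a little arithmetic. The sketch below is plain Python and only illustrates the idea behind a spike-count target. It assumes the target count is the rate multiplied by the number of time steps; `spike_count_targets` and `mse` are hypothetical helper names, not snnTorch functions, and the observed counts are made up.

```python
def spike_count_targets(num_classes, label, num_steps, correct_rate, incorrect_rate):
    """Target number of spikes per output neuron (illustrative helper)."""
    return [
        num_steps * (correct_rate if cls == label else incorrect_rate)
        for cls in range(num_classes)
    ]


def mse(counts, targets):
    """Mean squared error between observed and target spike counts."""
    return sum((c - t) ** 2 for c, t in zip(counts, targets)) / len(counts)


num_steps = 25
targets = spike_count_targets(num_classes=3, label=1, num_steps=num_steps,
                              correct_rate=0.8, incorrect_rate=0.2)
# With 25 steps, the correct class is asked to spike about 20 times,
# and each incorrect class about 5 times.

observed = [5, 18, 7]          # made-up spike counts for one sample
loss = mse(observed, targets)  # penalizes deviation from the target counts
```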

A Deep Dive into SNNs

If you wish to learn all the fundamentals of training spiking neural networks, from neuron models, to the neural code, up to backpropagation, the snnTorch tutorial series is a great place to begin. It consists of interactive notebooks with complete explanations that can get you up to speed.

Tutorial    Title                                              Colab Link
Tutorial 1  Spike Encoding with snnTorch                       Open In Colab
Tutorial 2  The Leaky Integrate and Fire Neuron                Open In Colab
Tutorial 3  A Feedforward Spiking Neural Network               Open In Colab
Tutorial 4  2nd Order Spiking Neuron Models (Optional)         Open In Colab
Tutorial 5  Training Spiking Neural Networks with snnTorch     Open In Colab
Tutorial 6  Surrogate Gradient Descent in a Convolutional SNN  Open In Colab
Tutorial 7  Neuromorphic Datasets with Tonic + snnTorch        Open In Colab

Contributing

If you're ready to contribute to snnTorch, instructions to do so can be found here.

Acknowledgments

snnTorch was initially developed by Jason K. Eshraghian in the Lu Group (University of Michigan).

Additional contributions were made by Xinxin Wang, Vincent Sun, and Emre Neftci.

Several features in snnTorch were inspired by the work of Friedemann Zenke, Emre Neftci, Doo Seok Jeong, Sumit Bam Shrestha and Garrick Orchard.

License & Copyright

snnTorch is licensed under the GNU General Public License v3.0: https://www.gnu.org/licenses/gpl-3.0.en.html.

Comments
  • Examples of regression?

    I was wondering if anyone had used snnTorch for regression, and perhaps how you set your networks up. Just looking for simple, general examples! MSELoss would likely be the type of loss used as I see it.

    opened by shilpakancharla 29
  • snntorch-ipu crashing

    • snntorch version: 0.5.3
    • snntorch-ipu version: 0.5.18
    • PopTorch version: 2.6.0
    • PyTorch version: 1.10.0
    • Python version: 3.8.10
    • Operating System: Ubuntu 20.04

    Description

    I've been trying to train a model in an IPU environment using PopTorch and snntorch-ipu. Unfortunately, I always get a crash. It is unclear to me what exactly is going on, so hopefully someone knows.

    What I Did

    If I try to train my model with only snntorch-ipu installed, as recommended, I will always get an error message when importing/working with surrogates about "Missing Straight Through Estimator Custom Operation file".

    /notebooks/dvsclf/network/net.py in <module>
          1 import torch
    ----> 2 from snntorch import surrogate
          3 from snntorch import utils
          4 import torch.nn as nn
          5 import numpy as np
    
    /usr/local/lib/python3.8/dist-packages/snntorch/surrogate.py in <module>
         26 
         27 
    ---> 28 class StraightThroughEstimator:
         29     """
         30     Straight Through Estimator.
    
    /usr/local/lib/python3.8/dist-packages/snntorch/surrogate.py in StraightThroughEstimator()
         53         print("Missing Straight Through Estimator Custom Operation file!")
         54         print(so_path_ste)
    ---> 55         exit(1)
         56     ctypes.cdll.LoadLibrary(so_path_ste)
         57 
    
    NameError: name 'exit' is not defined
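
The `NameError` itself looks like a secondary problem. `exit` is a convenience helper that the `site` module provides for interactive use, and it may not be defined in every environment, which appears to be the case in the notebook traceback above. A hedged sketch of a more robust pattern follows; `load_custom_op` and the path are hypothetical, and this is not the snntorch-ipu source.

```python
import os


def load_custom_op(so_path):
    """Return the path if the shared library exists, else raise a clear error.

    Hypothetical helper: raising an exception avoids relying on the
    interactive `exit` helper, which may be undefined in some environments.
    """
    if not os.path.exists(so_path):
        raise FileNotFoundError(f"Missing custom operation file: {so_path}")
    return so_path


try:
    load_custom_op("/nonexistent/path/custom_ops.so")
    message = "loaded"
except FileNotFoundError as err:
    # The caller sees the real cause instead of a NameError
    message = str(err)
```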
    

    If I install snntorch (with or without snntorch-ipu beside it), I will not get the above error. Instead, something in PopTorch throws an error when the model is being trained.

    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    <ipython-input-4-9ee252bfabe9> in <module>
          2     # Performs forward pass, loss function evaluation,
          3     # backward pass and weight update in one go on the device.
    ----> 4     _, loss = poptorch_model(batch, target)
    
    [....]
    
    /notebooks/dvsclf/network/snn.py in forward(self, x)
         20 
         21         x = self.conv(x)
    ---> 22         x = self.lif(x)
         23         return x
         24 
    
    /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
       1118             input = bw_hook.setup_input_hook(input)
       1119 
    -> 1120         result = forward_call(*input, **kwargs)
       1121         if _global_forward_hooks or self._forward_hooks:
       1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
    
    /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
       1088                 recording_scopes = False
       1089         try:
    -> 1090             result = self.forward(*input, **kwargs)
       1091         finally:
       1092             if recording_scopes:
    
    /usr/local/lib/python3.8/dist-packages/snntorch/_neurons/leaky.py in forward(self, input_, mem)
        159         if self.init_hidden:
        160             self._leaky_forward_cases(mem)
    --> 161             self.reset = self.mem_reset(self.mem)
        162             self.mem = self.state_fn(input_)
        163 
    
    /usr/local/lib/python3.8/dist-packages/snntorch/_neurons/neurons.py in mem_reset(self, mem)
         86         """Generates detached reset signal if mem > threshold.
         87         Returns reset."""
    ---> 88         mem_shift = mem - self.threshold
         89         reset = self.spike_grad(mem_shift).clone().detach()
         90 
    
    /usr/local/lib/python3.8/dist-packages/poptorch/_poplar_executor.py in __torch_function__(cls, func, types, args, kwargs)
        279                     if kwargs is None:
        280                         kwargs = {}
    --> 281                     return super().__torch_function__(func, types, args,
        282                                                       kwargs)
        283 
    
    RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
    Tensor:
    (1,1,.,.) = 
     Columns 1 to 9  0.4529  0.4529  0.4529  0.4529  0.4529  0.4529  0.4529  0.4529  0.4529 ....
    

    With self.lif = snn.Leaky(beta=0.95, spike_grad=surrogate.fast_sigmoid(), init_hidden=True, learn_beta=True, learn_threshold=True)

    opened by RoelMK 11
  • Networks Learns Nothing!

    • snntorch version: 0.2.8
    • Python version: 3.7.4
    • Operating System: Windows 10

    Description

    Hi,

    I tried to use the snnTorch to do exactly as you are doing in Tutorial 3 (without spike_grad), and also in the upcoming Tutorial 4 (applied spike_grad to stein neuron) for Spiking CNNs. Moreover, I also converted my dataset from static to spike version as in Tutorial 4 using rate encoding.
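
For readers unfamiliar with the rate encoding mentioned above, the idea can be sketched in plain Python. This is an illustration only, assuming each input intensity in [0, 1] is treated as the probability of a spike at every time step; `rate_encode` is a hypothetical name and not the snnTorch `spikegen` API.

```python
import random


def rate_encode(values, num_steps, seed=0):
    """Turn intensities in [0, 1] into Bernoulli spike trains (illustrative).

    Returns a list of `num_steps` frames; each frame holds one 0/1 spike
    per input value.
    """
    rng = random.Random(seed)
    return [
        [1 if rng.random() < v else 0 for v in values]
        for _ in range(num_steps)
    ]


spikes = rate_encode([0.0, 0.5, 1.0], num_steps=100)
# Spike count per input: brighter inputs spike more often
counts = [sum(frame[i] for frame in spikes) for i in range(3)]
```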

    However, I usually find that the network learns nothing: the training loss goes up and down slightly, but the output numbers and accuracy are completely unintuitive.

    For example, a conventional CNN reaches 85% accuracy on my dataset, but the spiking version with the same architecture and hyperparameters learns nothing.

    I would be glad to hear any direct suggestions about what the problem may be and why the network learns nothing in the spiking domain.

    Looking forward to your response

    opened by Dola47 9
  • Minor Update to Dependencies list in README.rst

    • snntorch version: 0.5.3
    • Python version: 3.10.4
    • Operating System: Windows 10

    Description

    I think that ffmpeg is missing from the dependencies listed in README.rst.

    What I Did

    Using conda, I installed the dependencies and then received the following error on the video conversion step of the first tutorial:

    RuntimeError: Requested MovieWriter (ffmpeg) not available
    

    This was easily sorted with: conda install -c conda-forge ffmpeg.
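
A quick way to check for this before running the video conversion step is to look for the executable on the PATH. The sketch below uses only the standard library; `ffmpeg_available` is a hypothetical helper name.

```python
import shutil


def ffmpeg_available():
    """Return True if an `ffmpeg` executable is found on the PATH."""
    return shutil.which("ffmpeg") is not None


if not ffmpeg_available():
    print("ffmpeg not found; install it before saving animations as video")
```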

    opened by katywarr 6
  • latency() got an unexpected keyword argument 'num_outputs' and latency() got multiple values for argument 'num_steps'

    • snntorch version: 0.2.11
    • Python version: 3.7.10

    Description

    In the latest updates, I see that num_outputs has been removed from spikegen. No clue why!

    Moreover, when I remove the num_outputs parameter, I get another error from num_steps when I set it to any int value.

    BTW, the same happens if I try rate encoding instead of latency encoding.

    What I Did

    Here: you will find a quick ipynb file that shows the errors.

    opened by Dola47 5
  • Neurons can fire multiple time steps in a row.

    It is possible to have neurons firing continuously using the default reset mechanism or if using RLeaky and reset to zero. The latter is due to only resetting the input but not the recurrent connections. This is undesirable behavior as it allows the neurons to essentially not be spiking neurons given the right weight values.
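
The behaviour with a subtractive reset can be reproduced in a few lines of plain Python. This is a sketch under assumptions: a single leaky neuron with threshold 1.0, decay 0.9, and a constant input; `simulate` is a hypothetical helper that mirrors the recursive update and is not the snnTorch implementation.

```python
def simulate(inputs, beta=0.9, threshold=1.0):
    """Simulate one leaky neuron with reset by subtraction; return its spikes."""
    mem, spikes = 0.0, []
    for x in inputs:
        # Subtract one threshold's worth of potential after a spike
        reset = threshold if mem > threshold else 0.0
        mem = beta * mem + x - reset
        spikes.append(1 if mem > threshold else 0)
    return spikes


strong = simulate([1.5] * 6)  # input alone exceeds the threshold
weak = simulate([0.4] * 6)    # input needs several steps to accumulate
```

With the strong input the potential never drops back below the threshold, so the neuron spikes on every step; with the weak input it spikes only intermittently.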

    opened by EvilxFish 4
  • enhance(tutorial6): net definition and link

    • make the Net class forward method single step
    • insert link to the tutorial 5 where cited

    Motivation

    There are several points in this very useful tutorial that this change is meant to improve. For clarity, I will call "v1" the current version of the Net definition and "v2" the version proposed in this PR.

    • An error when launching "Run all". This is a shape incompatibility error in the loss calculation, because in v1 there is an internal loop in .forward, which adds an extra axis. With v2, all cells run with no errors.
    • .forward in v1 is apparently inconsistent with .forward in v2: the former has an inner loop (so multi-step), while the latter does not (so single-step).
    • (am I overthinking?) Whether there was an intention behind defining the network with an inner loop to define a smaller dt, characteristic of an inner frequency higher than the outer frequency (receptors), is not very clear, and it might be beneficial to make it explicit.

    It is clear that the tutorial 5 example was being cited, but in this context it comes across as a bit inconsistent. If I have misunderstood, please tell me how I should best understand the tutorial.

    Thank you for the effort of creating this project :)

    opened by gianfa 4
  • Detach and Reset Spikes in RLeaky

    I noticed that for the RLeaky neurons the spike acts as an internal state but is neither reset nor detached. I believe the spikes should be reset and detached in the same way as the membrane potential, in reset_hidden and detach_hidden respectively. This merge request adds support for resetting and detaching the spikes.

    opened by manuelbre 4
  • Issue with inputting custom weights for Rate based SNN

    • snntorch version: 0.4.4
    • Python version: 3.9.6
    • Operating System: Ubuntu 20.04.3 LTS

    Description

    Hi Jason, First of all, I appreciate your wonderful effort in developing this package and its detailed documentation. I have recently started using snntorch for rate-based SNN coding. Although I am getting good performance for a purely software-based run, I am facing issues with inputting custom weights extracted from a synaptic device. My accuracy is stuck at around 10%, which is the same as the untrained accuracy.

    What I Did

    I used a custom function to input the weights from a text file as shown in the screenshot. Please let me know how to solve this issue. NB: I am pretty new to programming, so please excuse me if my code is too cumbersome :) Here is the full file and the text file for data input: rate_SNN_dev_weights.zip

    Thanks, Kannan

    opened by kannanum 4
  • RSynaptic Neuron Model NOT in snnTorch

    Hello, I am trying to use the RSynaptic neuron model, but even though it is documented on the snnTorch website, the class RSynaptic(LIF) is not included in the snntorch module. This is the error I get when I try to use it:

    AttributeError: module 'snntorch' has no attribute 'RSynaptic'

    opened by msbouanane 3
  • Tonic example

    Added example notebook that shows how to download data, ready to feed to network. Tonic does support batching for tensor representations, if you want to do that please let me know and I'll add it.

    opened by biphasic 3
  • TBPTT mode gives error about K parameter on time varying data

    • snntorch version: 0.5.3
    • Python version: 3.8
    • Operating System: windows

    When I try to use TBPTT or RTRL from backprop on a time-varying signal, I get this error:

    Java Printing.pdf

        if K_flag is False:
    UnboundLocalError: local variable 'K_flag' referenced before assignment
    
    opened by alisam1992 0
  • Add power profiling capabilities

    This seems to be a super popular feature request. Making accurate estimates seems near impossible, but we can probably generate an order of magnitude guess here.

    The user would construct a model, pass data in, and the power profiling function returns the number of Synaptic operations in the forward-pass (this could be averaged across batches).

    Each synaptic op would be scaled by the energy cost for all selected devices; e.g., various GPUs & neuromorphic hardware. The same number would be given for non-spiking networks too. This could be achieved by just removing the spiking modules.

    SpikingKeras has a similar function that does it really nicely. However, it overstates the improvement given with spikes because it does not account for overhead (i.e., moving data to/from memory, or between multiple chips).

    Including an argument that factors in overhead would be tricky, but useful. The model would be parsed for number of neurons/synapses, and if either exceeds the bandwidth of a single chip, then we need to estimate how frequently data needs to be moved between chips & add that to the overall energy consumption.

    A lot of coarse estimates would be made, but I think it could be helpful.
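
As a rough illustration of the proposal, the estimate could look like the following sketch. Everything here is hypothetical: the function name, the idea of passing per-layer spike counts together with fan-out, and the energy-per-operation figure are placeholders chosen for the example, not measured values or an existing snnTorch API.

```python
def estimate_energy(layers, energy_per_synop_joules):
    """Order-of-magnitude energy estimate from synaptic operation counts.

    `layers` is a list of (input_spike_count, fan_out) pairs; each input
    spike is assumed to trigger one synaptic operation per outgoing
    connection. Memory and chip-to-chip overhead are ignored.
    """
    synops = sum(spikes * fan_out for spikes, fan_out in layers)
    return synops, synops * energy_per_synop_joules


# Placeholder numbers for a two-layer network and one forward pass
layers = [(120, 256), (40, 10)]
synops, energy = estimate_energy(layers, energy_per_synop_joules=1e-12)
```

An overhead term, as discussed above, would have to be added on top of this count, which is where most of the uncertainty would come from.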

    enhancement 
    opened by jeshraghian 0
  • snntorch multi GPU training issue

    • snntorch version: snntorch: 0.5.3
    • Python version: 3.9
    • Operating System: linux
    • nvidia-smi
    Every 0.5s: nvidia-smi                                                                                                                                    neuro: Fri Dec  2 11:16:53 2022
    
    Fri Dec  2 11:16:53 2022
    +-----------------------------------------------------------------------------+
    | NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.4     |
    |-------------------------------+----------------------+----------------------+
    | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
    |                               |                      |               MIG M. |
    |===============================+======================+======================|
    |   0  NVIDIA RTX A6000    On   | 00000000:1B:00.0 Off |                  Off |
    | 30%   30C    P8    29W / 300W |      1MiB / 48682MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   1  NVIDIA RTX A6000    On   | 00000000:1C:00.0 Off |                  Off |
    | 30%   27C    P8    22W / 300W |      1MiB / 48685MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   2  NVIDIA RTX A6000    On   | 00000000:1D:00.0 Off |                  Off |
    | 30%   32C    P8    23W / 300W |      1MiB / 48685MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   3  NVIDIA RTX A6000    On   | 00000000:1E:00.0 Off |                  Off |
    | 30%   31C    P8    23W / 300W |      1MiB / 48685MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   4  NVIDIA RTX A6000    On   | 00000000:3D:00.0 Off |                  Off |
    | 30%   27C    P8    22W / 300W |      1MiB / 48685MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   5  NVIDIA RTX A6000    On   | 00000000:3F:00.0 Off |                  Off |
    | 30%   29C    P8    23W / 300W |      1MiB / 48685MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   6  NVIDIA RTX A6000    On   | 00000000:40:00.0 Off |                  Off |
    | 30%   27C    P8    22W / 300W |      1MiB / 48685MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   7  NVIDIA RTX A6000    On   | 00000000:41:00.0 Off |                  Off |
    | 30%   30C    P8    22W / 300W |      1MiB / 48685MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    
    +-----------------------------------------------------------------------------+
    | Processes:                                                                  |
    |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
    |        ID   ID                                                   Usage      |
    |=============================================================================|
    |  No running processes found                                                 |
    +-----------------------------------------------------------------------------+
    
    

    Description

    I'm trying to train NMNIST with snntorch using multiple GPUs. Since snntorch is based on the torch package, I thought DataParallel from torch.nn should work. Here's the whole code.

    import torch
    import torchvision.datasets as dsets
    import torchvision.transforms as transforms
    import torch.nn.init
    import os
    import torch.nn as nn
    import time
    import matplotlib.pyplot as plt
    import tonic.transforms as transforms
    import tonic
    import numpy as np
    import snntorch as snn
    from snntorch import surrogate
    from snntorch import functional as SF
    from snntorch import spikeplot as splt
    from snntorch import utils
    import torch.nn as nn
    import os
    from torch.utils.data import DataLoader, random_split
    import torch
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sensor_size = tonic.datasets.NMNIST.sensor_size
    
    # Denoise removes isolated, one-off events
    # time_window
    frame_transform = transforms.ToFrame(sensor_size=sensor_size, time_window=1)
    
    
    frame_transform = transforms.Compose([transforms.Denoise(filter_time=10000),
                                          transforms.ToFrame(sensor_size=sensor_size,
                                                             time_window=50000)
                                         ])
    
    trainset = tonic.datasets.NMNIST(save_to='/home/hubo1024/PycharmProjects/snntorch/data/NMNIST', transform=frame_transform, train=True)
    testset = tonic.datasets.NMNIST(save_to='./home/hubo1024/PycharmProjects/snntorch/data/NMNIST', transform=frame_transform, train=False)
    
    # seed fix
    torch.manual_seed(777)
    
    # seed fix if gpu is available
    if device == 'cuda':
        torch.cuda.manual_seed_all(777)
    
    #batch_size = 100
    
    batch_size = 32
    dataset_size = len(trainset)
    train_size = int(dataset_size * 0.9)
    validation_size = int(dataset_size * 0.1)
    
    
    trainset, valset = random_split(trainset, [train_size, validation_size])
    print(len(valset))
    print(len(trainset))
    trainloader = DataLoader(trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(), shuffle=True)
    valloader = DataLoader(valset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(), shuffle=True)
    testloader = DataLoader(testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors())
    
    
    spike_grad = surrogate.fast_sigmoid(slope=75)
    beta = 0.5
    
    class CNN(torch.nn.Module):
    
        def __init__(self):
            super(CNN, self).__init__()
            self.keep_prob = 0.5
            self.layer1 = torch.nn.Sequential(
                nn.Conv2d(2, 12, 5),
                nn.MaxPool2d(2),
                snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
            )
    
            self.layer2 = torch.nn.Sequential(
                nn.Conv2d(12, 32, 5),
                nn.MaxPool2d(2),
                snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
            )
    
            self.layer4 = torch.nn.Sequential(
                nn.Flatten(),
                nn.Linear(32 * 5 * 5, 10),
                snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
            )
    
        def forward(self, data):
            spk_rec = []
            layer1_rec = []
            layer2_rec = []
            utils.reset(self.layer1)  # resets hidden states for all LIF neurons in net
            utils.reset(self.layer2)
            utils.reset(self.layer4)
    
            for step in range(data.size(1)):  # data.size(1) = number of time steps (batch-first layout)
                input_torch = data[:, step, :, :, :]
                input_torch = input_torch.cuda()
                #print(input_torch)
                out = self.layer1(input_torch)
                #out1 = out
    
                out = self.layer2(out)
                #out2 = out
                out, mem = self.layer4(out)
                #out = self.layer4(out)
    
                spk_rec.append(out)
    
                #layer1_rec.append(out1)
                #layer2_rec.append(out2)
    
            return torch.stack(spk_rec)#, torch.stack(layer1_rec), torch.stack(layer2_rec)
    
    
    model = CNN().to(device)
    device_ids = [0, 1] #your GPU index
    model = torch.nn.DataParallel(model, device_ids=device_ids)
    #model = nn.DataParallel(model).to(device)
    optimizer = torch.optim.NAdam(model.parameters(), lr=0.005,betas=(0.9, 0.999))
    loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
    #model = nn.DataParallel(model)
    
    total_batch = len(trainloader)
    print('Total number of batches : {}'.format(total_batch))
    loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
    num_epochs = 15
    loss_hist = []
    acc_hist = []
    v_acc_hist = []
    t_spk_rec_sum = []
    start = time.time()
    val_cnt = 0
    v_acc_sum= 0
    avg_loss = 0
    index = 0
    #################################################
    
    
    for epoch in range(num_epochs):
        torch.save(model.state_dict(), '/home/hubo1024/PycharmProjects/snntorch/model_pt/Radam_15epoch-50000.pt')
        for i, (data, targets) in enumerate(iter(trainloader)):
            data = data.cuda()
            targets = targets.cuda()
            model.train()
    
            spk_rec = model(data)
    
            #print(spk_rec.shape)
            loss_val = loss_fn(spk_rec, targets)
            avg_loss += loss_val.item()
            optimizer.zero_grad()
    
            loss_val.backward()
    
            optimizer.step()
    
            # Store loss history for future plotting
            loss_hist.append(loss_val.item())
            val_cnt = val_cnt+1
            #del loss_val
    
    
            if val_cnt == len(trainloader)/2-1:
                val_cnt=0
    
                for ii, (v_data, v_targets) in enumerate(iter(valloader)):
                    v_data = v_data.to(device)
                    v_targets = v_targets.to(device)
    
                    v_spk_rec = model(v_data)
                    #
                    # print(t_spk_rec.shape)
                    v_acc = SF.accuracy_rate(v_spk_rec, v_targets)
                    del v_spk_rec
                    if ii == 0:
                        v_acc_sum = v_acc
                        cnt = 1
    
                    else:
                        v_acc_sum += v_acc
                        cnt += 1
                    #del v_acc
    
    
                plt.plot(acc_hist)
                plt.plot(v_acc_hist)
                plt.legend(['train accuracy', 'validation accuracy'])
                plt.title("Train, Validation Accuracy-Radam 15epoch-50000")
                plt.xlabel("Iteration")
                plt.ylabel("Accuracy")
                # plt.show()
                plt.savefig('Radam_15epoch-50000.png')
                plt.clf()
                v_acc_sum = v_acc_sum/cnt
    
    
                # avg_loss = avg_loss / (len(trainloader) / 2)
                # print('average loss while half epoch', avg_loss)
                # if avg_loss <= 0.5:
                #     index = 1
                #     break
                # else:
                #     avg_loss = 0
                #     index = 0
    
            print('Radam-15epoch-50000')
            print("time :", time.time() - start,"sec")
            print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_val.item():.2f}")
    
            acc = SF.accuracy_rate(spk_rec, targets)
            acc_hist.append(acc)
            v_acc_hist.append(v_acc_sum)
            print(f"Train Accuracy: {acc * 100:.2f}%")
            print(f"Validation Accuracy: {v_acc_sum * 100:.2f}%\n")
    
        #     if index == 1:
        #         break
        # if index == 1:
        #     break
    # No training is performed here, so use torch.no_grad()
    '''
    with torch.no_grad():
        X_test = mnist_test.test_data.view(len(mnist_test), 1, 28, 28).float().to(device)
        Y_test = mnist_test.test_labels.to(device)
    
        prediction = model(X_test)
        correct_prediction = torch.argmax(prediction, 1) == Y_test
        accuracy = correct_prediction.float().mean()
        print('Accuracy:', accuracy.item())
    '''
    
    

    And here's the error:

    (snn_torch) hubo1024@neuro:~/PycharmProjects/snntorch$ CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 python gpu_6_run.py
    6000
    54000
    Total number of batches : 13500
    Traceback (most recent call last):
      File "/home/hubo1024/PycharmProjects/snntorch/gpu_6_run.py", line 146, in <module>
        spk_rec = model(data)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
        output.reraise()
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/_utils.py", line 461, in reraise
        raise exception
    RuntimeError: Caught RuntimeError in replica 0 on device 0.
    Original Traceback (most recent call last):
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
        output = module(*input, **kwargs)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hubo1024/PycharmProjects/snntorch/gpu_6_run.py", line 102, in forward
        out = self.layerconv1(input_torch)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
        input = module(input)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/snntorch/_neurons/leaky.py", line 162, in forward
        self.mem = self.state_fn(input_)
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/snntorch/_neurons/leaky.py", line 201, in _build_state_function_hidden
        self._base_state_function_hidden(input_) - self.reset * self.threshold
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/snntorch/_neurons/leaky.py", line 195, in _base_state_function_hidden
        base_fn = self.beta.clamp(0, 1) * self.mem + input_
      File "/home/hubo1024/anaconda3/envs/snn_torch/lib/python3.9/site-packages/torch/_tensor.py", line 1121, in __torch_function__
        ret = func(*args, **kwargs)
    RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
    
    

    I reran this code after removing the snn.Leaky layers from the CNN and it ran fine (of course the cost doesn't converge and the accuracy was 0%, but it still runs). So I assume the cause of this error is the snn.Leaky layer. I think changing

    opened by rkdgmlqja 5
  • Clean .flake8

    The .flake8 config currently excludes many hints and errors. The code base should be cleaned up so that it passes the standard Flake8 config.

    I would be glad to help out. Reporting as a reminder.

    opened by ahenkes1 1
  • Add class imbalance weighting to loss functions

    Apply on/off target weighting to snntorch.functional losses in the same way that PyTorch enables weighting.

    Cross Entropy-based losses should be straightforward; Mean Square Error Losses are less trivial.
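To make the request concrete, here is a minimal pure-Python sketch of class weighting applied to a count-based cross-entropy loss. It assumes that spikes are summed over time and the counts are used as logits, and it normalises by the total weight, which is how PyTorch's CrossEntropyLoss behaves with a `weight` argument and mean reduction. The name `weighted_ce_count_loss` and its arguments are hypothetical and are not part of snntorch.functional.

```python
import math


def weighted_ce_count_loss(spk_rec, targets, class_weights):
    """Hypothetical helper: spk_rec[t][b][c] is a 0/1 spike at time step t,
    sample b, output neuron (class) c."""
    num_classes = len(class_weights)
    total_loss = 0.0
    total_weight = 0.0
    for b, target in enumerate(targets):
        # Accumulate spikes over time; the counts act as logits (assumption).
        counts = [sum(step[b][c] for step in spk_rec) for c in range(num_classes)]
        # Numerically stable log-softmax evaluated at the target class.
        peak = max(counts)
        log_norm = peak + math.log(sum(math.exp(x - peak) for x in counts))
        log_prob = counts[target] - log_norm
        weight = class_weights[target]
        total_loss += -weight * log_prob
        total_weight += weight
    # Weighted mean over the batch.
    return total_loss / total_weight


# Two samples with identical spike counts [2, 1]; only the first is labelled 0.
spk_rec = [
    [[1, 0], [1, 0]],
    [[1, 0], [1, 0]],
    [[0, 1], [0, 1]],
]
targets = [0, 1]
uniform = weighted_ce_count_loss(spk_rec, targets, [1.0, 1.0])
skewed = weighted_ce_count_loss(spk_rec, targets, [1.0, 3.0])
print(uniform, skewed)
```

Up-weighting class 1 makes the misclassified second sample dominate the mean, so `skewed` comes out larger than `uniform`.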

    enhancement 
    opened by jeshraghian 0
Releases(v0.5.2)
  • v0.5.2(Aug 4, 2022)

    What's Changed

    • leaky and rleaky state function substract function fix by @pengzhouzp in https://github.com/jeshraghian/snntorch/pull/95
    • Detach and Reset Spikes in RLeaky by @manuelbre in https://github.com/jeshraghian/snntorch/pull/108
    • Integrate ATan Surrogate function. by @ridgerchu in https://github.com/jeshraghian/snntorch/pull/111
    • bptt bug may trigger device inconsistency. by @ridgerchu in https://github.com/jeshraghian/snntorch/pull/115
    • Add a new feature 'probe' by @ridgerchu in https://github.com/jeshraghian/snntorch/pull/117

    New Contributors

    • @pengzhouzp made their first contribution in https://github.com/jeshraghian/snntorch/pull/95
    • @manuelbre made their first contribution in https://github.com/jeshraghian/snntorch/pull/108
    • @ridgerchu made their first contribution in https://github.com/jeshraghian/snntorch/pull/111
    • @MegaYEye made their first contribution in https://github.com/jeshraghian/snntorch/pull/118

    Full Changelog: https://github.com/jeshraghian/snntorch/compare/v0.5.1...v0.5.2

  • v0.5.0(Feb 10, 2022)

    What's new?

    • refactored structure of neuron models to make it easier to integrate custom neurons
    • added recurrent Leaky neuron RLeaky
    • added recurrent Synaptic neuron RSynaptic
    • added spiking LSTM neuron SLSTM
    • added spiking convolutional 2D LSTM SConv2dLSTM
    • learnable thresholds for all neurons
    • learnable explicit recurrence
    • Reset mechanism now includes 'none' as an option
    • update unit tests

    snntorch.surrogate

    • Triangular surrogate
    • Straight through estimator

    snntorch.functional

    • mse_temporal_loss added: applies mean square error to the first F spikes. Includes an option for tolerance, and labels can be passed in to be converted into spike-time targets.

    • ce_temporal_loss added: applies cross entropy loss to an inversion of the first spike time. Inversion options include -1 * x and 1/x, so that maximizing the logit of the correct class corresponds to minimizing the correct neuron's firing time.

    • accuracy_temporal added: measures accuracy based on the occurrence of the first spike.
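The idea of measuring accuracy from the first spike can be sketched in plain Python as follows. This is only an illustration: the function names are hypothetical, the data layout here is `[sample][time][neuron]`, and the handling of ties and of neurons that never fire are assumptions rather than documented snnTorch behaviour.

```python
def first_spike_times(spk_rec, num_neurons):
    """spk_rec[t][n] is a 0/1 spike. Neurons that never fire are assigned
    the time len(spk_rec), i.e. one step after the end (assumption)."""
    num_steps = len(spk_rec)
    times = [num_steps] * num_neurons
    for t, step in enumerate(spk_rec):
        for n in range(num_neurons):
            if step[n] and times[n] == num_steps:
                times[n] = t  # record only the first spike
    return times


def accuracy_first_spike(batch_spk_rec, targets, num_neurons):
    """Fraction of samples whose earliest-firing neuron matches the target."""
    correct = 0
    for spk_rec, target in zip(batch_spk_rec, targets):
        times = first_spike_times(spk_rec, num_neurons)
        predicted = times.index(min(times))  # ties go to the lowest index
        correct += int(predicted == target)
    return correct / len(targets)


sample_a = [[0, 0, 0], [0, 1, 0], [1, 1, 0]]  # neuron 1 fires first
sample_b = [[0, 0, 1], [1, 0, 0], [0, 0, 0]]  # neuron 2 fires first
print(first_spike_times(sample_a, 3))
print(accuracy_first_spike([sample_a, sample_b], targets=[1, 0], num_neurons=3))
```

Negating or inverting these first-spike times gives values that are largest for the earliest neuron, which is the role the -1 * x and 1/x inversions play before cross entropy is applied.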

    Full Changelog: https://github.com/jeshraghian/snntorch/compare/v0.4.11...v0.5.0

  • v0.2.11(May 17, 2021)

    Some of the bugs from the previous versions have now been fixed w.r.t. sizes of tensors in spike encoding.

    What's new?

    snntorch.spikegen

    • Data & target conversion have been separated out
    • Conversion sizes have been fixed
    • Time dimension is only created if tensor is time-varying (i.e., latency will always have time-dimension; rate might not)
    • Latency & rate target conversion
    • interpolation, on/off spike vals, time to first spike, on/off rate options included

    snntorch.surrogate

    • Parameterization of surrogate gradients has been moved from a global variable to local variables within closures
    • Spike operator (1/u)
    • Leaky Local spike operator (leaky relu shifted equivalent)
    • Local stochastic spike operator
  • v0.2.1(Feb 27, 2021)

    Some of the bugs from the previous versions have now been fixed.

    What's new?

    snntorch

    • SRM0 neuron model fix
    • Reset now applies the threshold rather than '1'
    • Reset by subtraction and reset to zero methods applied to both Stein and SRM0 neurons
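The two reset methods can be compared with a single leaky integrate-and-fire neuron in plain Python. The sketch below follows the form of the update visible in the traceback earlier on this page (decay the membrane by beta, add the input, then apply a reset triggered by the previous step's spike). The function name, the default values, and the exact form of the reset-to-zero branch are illustrative assumptions, not the library's implementation.

```python
def simulate_lif(inputs, beta=0.9, threshold=1.0, reset_mechanism="subtract"):
    """Return (spikes, membrane trace) for one leaky integrate-and-fire neuron."""
    mem = 0.0
    spikes, trace = [], []
    for x in inputs:
        # A spike on the previous step triggers the reset on this step.
        reset = 1.0 if mem > threshold else 0.0
        decayed = beta * mem + x
        if reset_mechanism == "subtract":
            # Reset by subtraction: remove one threshold's worth of potential.
            mem = decayed - reset * threshold
        elif reset_mechanism == "zero":
            # Reset to zero: cancel the whole membrane potential.
            mem = decayed - reset * decayed
        else:
            raise ValueError(f"unknown reset mechanism: {reset_mechanism}")
        spikes.append(1 if mem > threshold else 0)
        trace.append(mem)
    return spikes, trace


inputs = [0.6, 0.6, 0.6, 0.6]
print(simulate_lif(inputs, reset_mechanism="subtract")[0])  # [0, 1, 0, 1]
print(simulate_lif(inputs, reset_mechanism="zero")[0])      # [0, 1, 0, 0]
```

With reset by subtraction the residual potential above threshold carries over, so the neuron fires again sooner than it does with reset to zero.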

    snntorch.spikegen

    • Delta modulation
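As a rough illustration of delta modulation, the sketch below emits a spike whenever a signal changes by at least a threshold between consecutive samples. The function name, the `off_spike` option for negative changes, the use of `>=`, and the choice to emit no spike for the first sample are assumptions made for this example and may differ from snntorch.spikegen.

```python
def delta_modulation(signal, threshold=1.0, off_spike=False):
    """Emit 1 when the signal rises by at least `threshold` between two
    consecutive samples; with off_spike=True, emit -1 for an equal drop."""
    if not signal:
        return []
    spikes = [0]  # the first sample has no predecessor to compare against
    for prev, curr in zip(signal, signal[1:]):
        diff = curr - prev
        if diff >= threshold:
            spikes.append(1)
        elif off_spike and -diff >= threshold:
            spikes.append(-1)
        else:
            spikes.append(0)
    return spikes


signal = [0.0, 0.2, 1.5, 1.6, 0.1]
print(delta_modulation(signal))                  # [0, 0, 1, 0, 0]
print(delta_modulation(signal, off_spike=True))  # [0, 0, 1, 0, -1]
```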

    snntorch.surrogate

    • Optimized grad calculation

    dev notes

    • Travis-CI is no longer free. Replaced travis.yml with GH actions integration + tox
  • v0.1.2(Feb 11, 2021)

    The first functional iteration of snnTorch!

    What's new?

    snntorch

    The workhorse of the package. All neuron models are integrated here, and a default Heaviside gradient is used to override the non-differentiability with conventional autograd methods in PyTorch.

    • Stein's neuron model
    • SRM0 neuron model
    • firing inhibition, thanks to @xxwang1
    • hidden states can optionally be initialized as instance variables if the user wants to just use a built-in backprop method

    snntorch.backprop

    • Backprop through time (BPTT)
    • Truncated backprop through time (TBPTT)
    • Real-time recurrent learning (RTRL)

    snntorch.spikegen

    • Poisson spike train generator
    • Rate coding
    • Latency coding
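The difference between rate and latency coding can be sketched in plain Python. Both functions below are hypothetical helpers that expect intensities normalised to [0, 1]. Rate coding draws an independent Bernoulli sample per time step with the intensity as the spike probability. Latency coding emits a single spike whose time decreases linearly with intensity, which is a simplifying assumption chosen for this example.

```python
import random


def rate_encode(intensities, num_steps, seed=0):
    """Return spikes[t][n]: each step fires with probability intensities[n]."""
    rng = random.Random(seed)  # seeded so the example is reproducible
    return [
        [1 if rng.random() < p else 0 for p in intensities]
        for _ in range(num_steps)
    ]


def latency_encode(intensities, num_steps):
    """Return spikes[t][n] with one spike per input; brighter fires earlier."""
    spikes = [[0] * len(intensities) for _ in range(num_steps)]
    for n, p in enumerate(intensities):
        if not 0.0 <= p <= 1.0:
            raise ValueError("intensities must lie in [0, 1]")
        t = round((1.0 - p) * (num_steps - 1))  # linear intensity-to-time map
        spikes[t][n] = 1
    return spikes


rate = rate_encode([0.0, 0.5, 1.0], num_steps=5)
latency = latency_encode([1.0, 0.5, 0.0], num_steps=5)
print([sum(col) for col in zip(*rate)])  # spike counts grow with intensity
print(latency)
```

Rate coding spreads information across many spikes, while latency coding carries it in the timing of a single spike per input.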

    snntorch.surrogate

    • FastSigmoid
    • Sigmoid
    • Spike Rate Escape
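For readers unfamiliar with surrogate gradients, the sketch below shows the idea behind a fast-sigmoid surrogate in plain Python. The forward pass keeps the Heaviside step, while the backward pass substitutes the derivative of the smooth function x / (1 + k|x|). The function names and the slope value k = 25 are illustrative assumptions, and no autograd machinery is involved.

```python
def heaviside(mem_shifted):
    """Forward pass: spike if the membrane potential exceeds the threshold.
    mem_shifted is the membrane potential minus the threshold."""
    return 1.0 if mem_shifted > 0 else 0.0


def fast_sigmoid(mem_shifted, slope=25.0):
    """Smooth stand-in for the step function: x / (1 + k|x|)."""
    return mem_shifted / (1.0 + slope * abs(mem_shifted))


def fast_sigmoid_grad(mem_shifted, slope=25.0):
    """Backward pass: derivative of fast_sigmoid, 1 / (1 + k|x|)^2."""
    return 1.0 / (1.0 + slope * abs(mem_shifted)) ** 2


for u in (-0.2, -0.04, 0.0, 0.04, 0.2):
    print(u, heaviside(u), round(fast_sigmoid_grad(u), 4))
```

The surrogate gradient peaks at the threshold and decays symmetrically on both sides, so neurons close to firing receive the largest weight updates.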

    snntorch.spikeplot

    • Raster plots
    • Feature map animator
    • Spike count animator

    snntorch.utils

    • Data split
    • Data reduction

    Plans for alpha-2

    • delta & delta-sigma spike generators for snntorch.spikegen
    • Simplified Stein's model (reduce hidden states from 2 to 1)
    • More surrogate and backprop methods
    • add more tests
Owner
Jason Eshraghian
neuromorphic engineer
A machine learning library for spiking neural networks. Supports training with both torch and jax pipelines, and deployment to neuromorphic hardware.

Rockpool Rockpool is a Python package for developing signal processing applications with spiking neural networks. Rockpool allows you to build network

SynSense 21 Dec 14, 2022
A PyTorch implementation of EventProp [https://arxiv.org/abs/2009.08378], a method to train Spiking Neural Networks

Spiking Neural Network training with EventProp This is an unofficial PyTorch implementation of EventProp, a method to compute exact gradients for Spiki

Pedro Savarese 35 Jul 29, 2022
Pytorch Implementation of Spiking Neural Networks Calibration, ICML 2021

SNN_Calibration Pytorch Implementation of Spiking Neural Networks Calibration, ICML 2021 Feature Comparison of SNN calibration: Features SNN Direct Tr

Yuhang Li 60 Dec 27, 2022
PyTorch implementation of Spiking Neural Networks trained on surrogate gradient & BPTT using snntorch.

snn-localization repo PyTorch implementation of Spiking Neural Networks trained on surrogate gradient & BPTT using snntorch. Install Dependencies Orig

Sami BARCHID 1 Jan 6, 2022
Spiking Neural Network for Computer Vision using SpikingJelly framework and Pytorch-Lightning

Spiking Neural Network for Computer Vision using SpikingJelly framework and Pytorch-Lightning

Sami BARCHID 2 Oct 20, 2022
Lyapunov-guided Deep Reinforcement Learning for Stable Online Computation Offloading in Mobile-Edge Computing Networks

PyTorch code to reproduce LyDROO algorithm [1], which is an online computation offloading algorithm to maximize the network data processing capability subject to the long-term data queue stability and average power constraints. It applies Lyapunov optimization to decouple the multi-stage stochastic MINLP into deterministic per-frame MINLP subproblems and solves each subproblem via DROO algorithm. It includes:

Liang HUANG 87 Dec 28, 2022
A framework that constructs deep neural networks, autoencoders, logistic regressors, and linear networks

A framework that constructs deep neural networks, autoencoders, logistic regressors, and linear networks without the use of any outside machine learning libraries - all from scratch.

Kordel K. France 2 Nov 14, 2022
Complex-Valued Neural Networks (CVNN)

Complex-Valued Neural Networks (CVNN) Done by @NEGU93 - J. Agustin Barrachina Using this library, the only difference with a Tensorflow code is that y

youceF 1 Nov 12, 2021
Bayesian-Torch is a library of neural network layers and utilities extending the core of PyTorch to enable the user to perform stochastic variational inference in Bayesian deep neural networks

Bayesian-Torch is a library of neural network layers and utilities extending the core of PyTorch to enable the user to perform stochastic variational inference in Bayesian deep neural networks. Bayesian-Torch is designed to be flexible and seamless in extending a deterministic deep neural network architecture to corresponding Bayesian form by simply replacing the deterministic layers with Bayesian layers.

Intel Labs 210 Jan 4, 2023
DeepHyper: Scalable Asynchronous Neural Architecture and Hyperparameter Search for Deep Neural Networks

What is DeepHyper? DeepHyper is a software package that uses learning, optimization, and parallel computing to automate the design and development of

DeepHyper Team 214 Jan 8, 2023
An implementation demo of the ICLR 2021 paper Neural Attention Distillation: Erasing Backdoor Triggers from Deep Neural Networks in PyTorch.

Neural Attention Distillation This is an implementation demo of the ICLR 2021 paper Neural Attention Distillation: Erasing Backdoor Triggers from Deep

Yige-Li 84 Jan 4, 2023
Vowpal Wabbit is a machine learning system which pushes the frontier of machine learning with techniques such as online, hashing, allreduce, reductions, learning2search, active, and interactive learning.

This is the Vowpal Wabbit fast online learning code. Why Vowpal Wabbit? Vowpal Wabbit is a machine learning system which pushes the frontier of machin

Vowpal Wabbit 8.1k Jan 6, 2023
Try out deep learning models online on Google Colab

Try out deep learning models online on Google Colab

Erdene-Ochir Tuguldur 1.5k Dec 27, 2022
Lighting the Darkness in the Deep Learning Era: A Survey, An Online Platform, A New Dataset

Lighting the Darkness in the Deep Learning Era: A Survey, An Online Platform, A New Dataset This repository provides a unified online platform, LoLi-P

Chongyi Li 457 Jan 3, 2023
Deep Reinforcement Learning for Multiplayer Online Battle Arena

MOBA_RL Deep Reinforcement Learning for Multiplayer Online Battle Arena Prerequisite Python 3 gym-derk Tensorflow 2.4.1 Dotaservice of TimZaman Seed R

Dohyeong Kim 32 Dec 18, 2022
This repository contains notebook implementations of the following Neural Process variants: Conditional Neural Processes (CNPs), Neural Processes (NPs), Attentive Neural Processes (ANPs).

The Neural Process Family This repository contains notebook implementations of the following Neural Process variants: Conditional Neural Processes (CN

DeepMind 892 Dec 28, 2022
Deep learning (neural network) based remote photoplethysmography: how to extract pulse signal from video using deep learning tools

Deep-rPPG: Camera-based pulse estimation using deep learning tools Deep learning (neural network) based remote photoplethysmography: how to extract pu

Terbe Dániel 138 Dec 17, 2022
Code samples for my book "Neural Networks and Deep Learning"

Code samples for "Neural Networks and Deep Learning" This repository contains code samples for my book on "Neural Networks and Deep Learning". The cod

Michael Nielsen 13.9k Dec 26, 2022