Spiking Neural Network training with EventProp
This is an unofficial PyTorch implementation of EventProp, a method to compute exact gradients for Spiking Neural Networks. The repo currently contains code to train a 1-layer Spiking Neural Network with leaky integrate-and-fire (LIF) neurons for 10-way digit classification on MNIST.
Implementation Details
The implementation of EventProp itself is in models.py, in the form of the manual_forward and manual_backward methods of the SpikingLinear module, which simulate the dynamics of a spiking layer and of its adjoint layer, respectively.
In particular, the manual_forward method computes the discretized dynamics of a spiking layer: the LIF equations are Euler-integrated on a fixed time grid, and a neuron emits a spike (followed by a reset of its membrane potential) whenever the potential crosses the firing threshold.
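As a rough, self-contained sketch of such a forward pass (the constants, shapes, and dummy input raster below are illustrative assumptions, not the repo's actual code):

```python
import torch

# All constants and shapes below are illustrative assumptions.
dt, T = 1e-3, 40e-3               # 1 ms steps over a 40 ms window
tau_mem, tau_syn = 20e-3, 5e-3    # membrane/synaptic time constants (assumed)
threshold = 1.0                   # firing threshold (assumed)
steps = round(T / dt)             # 40 simulation steps

batch, n_in, n_out = 32, 784, 10
W = 0.1 * torch.randn(n_out, n_in)
input_spikes = (torch.rand(steps, batch, n_in) < 0.05).float()  # dummy raster

V = torch.zeros(batch, n_out)     # membrane potentials
I = torch.zeros(batch, n_out)     # synaptic currents
for t in range(steps):
    # Euler step of the LIF dynamics:
    #   tau_mem * dV/dt = I - V,   tau_syn * dI/dt = -I (+ weighted input spikes)
    V = V + (dt / tau_mem) * (I - V)
    I = I + (dt / tau_syn) * (-I) + input_spikes[t] @ W.t()
    out = (V >= threshold).float()  # spike wherever the threshold is crossed
    V = V * (1.0 - out)             # reset the membrane potential to zero
```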
The manual_backward method then computes the discretized dynamics of the adjoint model, which are used to obtain exact gradients for the weight parameters.
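Again as a hedged sketch rather than the repo's code: the adjoint variables lam_V and lam_I follow linear dynamics integrated backwards in time, and the weight gradient accumulates lam_I at presynaptic spike times. The loss-dependent jumps that EventProp applies to lam_V at recorded spike times are omitted here for brevity (see the paper and models.py):

```python
import torch

# Illustrative constants; the jumps of lam_V at spike times are omitted.
dt, steps = 1e-3, 40
tau_mem, tau_syn = 20e-3, 5e-3
batch, n_in, n_out = 32, 784, 10
input_spikes = (torch.rand(steps, batch, n_in) < 0.05).float()  # dummy raster

lam_V = torch.zeros(batch, n_out)  # adjoint of the membrane potential
lam_I = torch.zeros(batch, n_out)  # adjoint of the synaptic current
grad_W = torch.zeros(n_out, n_in)
for t in reversed(range(steps)):
    # free adjoint dynamics, Euler-integrated backwards in time:
    #   tau_mem * dlam_V/dt' = -lam_V,   tau_syn * dlam_I/dt' = lam_V - lam_I
    lam_V = lam_V + (dt / tau_mem) * (-lam_V)
    lam_I = lam_I + (dt / tau_syn) * (lam_V - lam_I)
    # the weight gradient accumulates lam_I at presynaptic spike times:
    #   dL/dw_ji = -tau_syn * sum over spikes t_k of input neuron i of lam_I_j(t_k)
    grad_W -= tau_syn * (lam_I.t() @ input_spikes[t])
```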
The network is simulated for a fixed time window, with discrete time steps used to approximate the continuous dynamics. Both can be set through the T and dt arguments when running main.py (the default values are T=40ms and dt=1ms, so 40 simulation steps are executed for each forward pass).
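For example, assuming the command-line flags simply match the argument names above (the exact flag spelling is an assumption, not checked against main.py), a finer time grid could be requested with

python main.py --T 40 --dt 0.5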
To encode the MNIST dataset as spikes, images are first binarized, and black/white pixels are encoded as spikes at times 10ms and 20ms, respectively.
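A minimal sketch of this latency encoding (the 0.5 binarization threshold and the raster construction are assumptions, not the repo's actual encoder):

```python
import torch
from torchvision import datasets, transforms

# Load one MNIST digit as a flat vector of pixel intensities in [0, 1].
mnist = datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.ToTensor())
image, label = mnist[0]
pixels = image.view(-1)  # shape: (784,)

# Binarize (the 0.5 threshold is an assumption), then map pixels to spike
# times: black pixels spike at 10 ms, white pixels at 20 ms.
is_black = pixels <= 0.5
spike_times = torch.where(is_black,
                          torch.full_like(pixels, 10e-3),
                          torch.full_like(pixels, 20e-3))

# Build a binary spike raster of shape (steps, 784) on a 1 ms grid.
dt, steps = 1e-3, 40
raster = torch.zeros(steps, pixels.numel())
raster[(spike_times / dt).round().long(), torch.arange(pixels.numel())] = 1.0
```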
For a randomly-initialized network, the dynamics of one of the 10 output neurons then look as follows, where vertical black lines indicate spike times.
Usage
The code was tested with Python 2.7 + PyTorch 0.4 and Python 3.8 + PyTorch 1.4, producing similar results.
To train the SNN with default settings, just run
python main.py
which will automatically download MNIST and train an SNN for 40 epochs with the Adam optimizer, on GPU.
Check out the available args in main.py to change training settings such as the learning rate and batch size, as well as SNN-specific parameters such as the membrane/synaptic time constants and the time discretization.
The default hyperparameters result in stable training, reaching around 85% train/test accuracy in under 10 epochs.
Extensions
If there is enough interest, I can try to extend the EventProp implementation to handle hidden layers / convolutions. If you'd like to extend it yourself, feel free to submit a pull request.