# Neural Architecture Search for Spiking Neural Networks
PyTorch implementation of [Neural Architecture Search for Spiking Neural Networks](https://arxiv.org/abs/2201.10355).

For background on training-free NAS in the ANN domain, see https://github.com/BayesWatch/nas-without-training.
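The training-free scoring idea behind that repository (Mellor et al., "Neural Architecture Search without Training") ranks candidate architectures by how distinguishable their binary activation patterns are on a single minibatch. A minimal NumPy sketch of the Hamming-kernel score under that assumption (`naswot_score` is an illustrative name, not a function from this repo):

```python
import numpy as np

def naswot_score(binary_codes: np.ndarray) -> float:
    """Score a (batch, units) matrix of 0/1 activation patterns.

    K[i, j] = number of units on which samples i and j agree
    (num_units minus their Hamming distance); the score is log|K|.
    Higher scores mean more distinguishable activation patterns.
    """
    c = binary_codes.astype(np.float64)
    n_units = c.shape[1]
    # Pairwise Hamming distance: positions where one code is 1 and the other is 0.
    hamming = c @ (1.0 - c.T) + (1.0 - c) @ c.T
    k = n_units - hamming
    _, logdet = np.linalg.slogdet(k)
    return float(logdet)

# Toy example: random binary codes for a batch of 8 samples, 64 units.
rng = np.random.default_rng(0)
codes = (rng.random((8, 64)) > 0.5).astype(int)
score = naswot_score(codes)
```

At search time, one would compute such a score for each sampled architecture and keep the highest-scoring one, avoiding any gradient-based training during the search.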
## Prerequisites
- Python 3.9
- PyTorch 1.10.0
- NVIDIA GPU (>= 12GB)
- CUDA 10.2 (optional)
## Getting Started

### Conda Environment Setting
```shell
conda create -n SNASNet
conda activate SNASNet
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
pip install scipy
```
### Spikingjelly Installation (ref: https://github.com/fangwei123456/spikingjelly)

```shell
git clone https://github.com/fangwei123456/spikingjelly.git
cd spikingjelly
python setup.py install
```
## Training and Testing
- Arguments required for training and testing are contained in `config.py`.
- Here is an example of running an experiment on CIFAR100.
- (If you want to skip the search process and use a predefined architecture) An architecture can be passed in the format
  `--cnt_mat 0302 0030 3003 0000`

Example) Architecture and the corresponding connection matrix
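The four 4-digit tokens can be read as the rows of a 4x4 connection matrix, where entry `[i][j]` encodes the connection type on the edge from cell node `i` to node `j` (the row/column semantics here are an assumption from the format; `parse_cnt_mat` is an illustrative helper, not part of the repo):

```python
def parse_cnt_mat(tokens):
    """Parse connection-matrix tokens such as ['0302', '0030', '3003', '0000']
    into a 4x4 list of ints. Entry [i][j] is assumed to encode the
    connection type on the edge from node i to node j."""
    mat = [[int(ch) for ch in row] for row in tokens]
    assert all(len(row) == len(mat) for row in mat), "matrix must be square"
    return mat

# The architecture string from the example above:
mat = parse_cnt_mat(["0302", "0030", "3003", "0000"])
```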
### Training

- Run the following command:

```shell
python search_snn.py --exp_name 'cifar100_backward' --dataset 'cifar100' --celltype 'backward' --batch_size 32 --num_search 5000
```
Argument overview:

- `--exp_name`: savefile name
- `--dataset`: dataset for the experiment
- `--celltype`: search for backward connections or forward connections
- `--num_search`: number of architecture candidates for the search
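An `argparse` sketch consistent with the flags used in the commands above (the defaults here are assumptions for illustration; the repo's actual defaults live in `config.py`):

```python
import argparse

def build_parser():
    # Flags mirror the example commands; defaults are assumptions,
    # not the repo's actual values (see config.py in the repo).
    p = argparse.ArgumentParser(description="SNASNet search / training")
    p.add_argument('--exp_name', type=str, default='exp', help='savefile name')
    p.add_argument('--dataset', type=str, default='cifar100',
                   help='dataset for the experiment')
    p.add_argument('--celltype', type=str, choices=['forward', 'backward'],
                   default='backward',
                   help='search for backward or forward connections')
    p.add_argument('--batch_size', type=int, default=32)
    p.add_argument('--num_search', type=int, default=5000,
                   help='number of architecture candidates for the search')
    p.add_argument('--cnt_mat', type=str, nargs='+', default=None,
                   help='predefined connection matrix, e.g. 0302 0030 3003 0000')
    p.add_argument('--savemodel_pth', type=str, default=None,
                   help='path to pretrained parameters for testing')
    return p

args = build_parser().parse_args(
    ['--dataset', 'cifar100', '--celltype', 'backward', '--num_search', '5000'])
```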
### Testing (on a pretrained model)
- As a first step, download the pretrained parameters (link) to `./savemodel/save_cifar100_bw.pth.tar`.
- The above pretrained model is for CIFAR100 with the architecture `--cnt_mat 0302 0030 3003 0000`.
- Run the following command:

```shell
python search_snn.py --dataset 'cifar100' --cnt_mat 0302 0030 3003 0000 --savemodel_pth './savemodel/save_cifar100_bw.pth.tar' --celltype 'backward'
```