

Parametric-Leaky-Integrate-and-Fire-Spiking-Neuron

中文README (Chinese README)

This repository contains the original code and TensorBoard logs for the paper Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks. The trained models are too large to upload to this repository. However, we used an identical seed during training, so users should obtain almost the same accuracy when training with our code.

Accuracy

This table shows the accuracy obtained with PLIF neurons, tau_0=2, and max pooling:

| | MNIST | Fashion-MNIST | CIFAR10 | N-MNIST | CIFAR10-DVS | DVS128 Gesture |
|---|---|---|---|---|---|---|
| accuracy-A | 99.72% | 94.38% | 93.50% | 99.61% | 74.80% | 97.57% |
| accuracy-B | 99.63% | 93.85% | 92.58% | 99.57% | 69.00% | 96.53% |

This table shows accuracy-A for PLIF/LIF neurons with different tau/tau_0 values and average/max pooling:

| neuron | pooling | MNIST | Fashion-MNIST | CIFAR10 | N-MNIST | CIFAR10-DVS | DVS128 Gesture |
|---|---|---|---|---|---|---|---|
| PLIF, tau_0=2 | max | 99.72% | 94.38% | 93.5% | 99.61% | 74.8% | 97.57% |
| PLIF, tau_0=16 | max | 99.73% | 94.65% | 93.23% | 99.53% | 70.5% | 92.01% |
| LIF, tau=2 | max | 99.69% | 94.17% | 93.03% | 99.64% | 73.6% | 96.88% |
| LIF, tau=16 | max | 99.49% | 94.47% | 47.5% | 99.15% | 62.4% | 76.74% |
| PLIF, tau_0=2 | avg | 99.71% | 94.74% | 93.3% | 99.66% | 72.7% | 97.22% |
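The tables compare LIF neurons with a fixed tau against PLIF neurons whose tau is learned, starting from tau_0. Below is a minimal pure-Python sketch of one PLIF neuron's forward pass, assuming the formulation of the paper as we understand it: a learnable parameter w with 1/tau = sigmoid(w), initialized so that sigmoid(w) = 1/tau_0, a threshold of 1, a reset potential of 0 and a hard reset. The names `init_w`, `plif_forward`, `v_threshold` and `v_reset` are illustrative, not the SpikingJelly API, and training (surrogate gradients for w and the synaptic weights) is not shown.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def init_w(tau_0: float) -> float:
    """Initial value of the learnable parameter w, chosen so that sigmoid(w) == 1 / tau_0."""
    if tau_0 <= 1.0:
        raise ValueError("tau_0 must be greater than 1 so that 1 / tau_0 lies in (0, 1)")
    return -math.log(tau_0 - 1.0)


def plif_forward(xs: Sequence[float], w: float,
                 v_threshold: float = 1.0, v_reset: float = 0.0) -> List[float]:
    """Run one PLIF neuron over the input currents xs and return its spike train."""
    k = sigmoid(w)  # k = 1 / tau; during training, w would be a learnable parameter
    v = v_reset
    spikes = []
    for x in xs:
        h = v + k * (x - (v - v_reset))       # charge: leaky integration of the input
        s = 1.0 if h >= v_threshold else 0.0  # fire: Heaviside step of h - v_threshold
        v = h * (1.0 - s) + v_reset * s       # hard reset to v_reset after a spike
        spikes.append(s)
    return spikes
```

With tau_0 = 2, `plif_forward([1.5] * 4, init_w(2.0))` returns `[0.0, 1.0, 0.0, 1.0]`: the membrane potential reaches 0.75, then 1.125 (a spike and reset), and the cycle repeats. Parameterizing 1/tau through a sigmoid keeps it in (0, 1) for any value of w.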

Directory Structure

codes contains the original code:

models.py defines the networks.

train.py alternates between training on the training set and testing on the test set, and records the maximum test accuracy, which is accuracy-A in the paper.

train_val.py splits the original training set into a new training set and a validation set. It alternates between training on the new training set and evaluating on the validation set, then evaluates the model with the maximum validation accuracy on the test set exactly once. This test accuracy is accuracy-B in the paper.
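The difference between the two protocols can be sketched as follows. This is a hedged illustration, not the code in train.py or train_val.py: the helper names, the validation ratio and the use of Python's `random` module are hypothetical, since this README does not state how the split is made. In a PyTorch script, the resulting index lists could be wrapped with `torch.utils.data.Subset`.

```python
import random
from typing import Callable, List, Sequence, Tuple


def split_train_val(num_samples: int, val_ratio: float, seed: int) -> Tuple[List[int], List[int]]:
    """Shuffle sample indices with a fixed seed and split them into (train, val) index lists."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # fixed seed: the same split on every run
    num_val = int(num_samples * val_ratio)
    return indices[num_val:], indices[:num_val]


def accuracy_a(test_acc_per_epoch: Sequence[float]) -> float:
    """accuracy-A: the maximum test accuracy observed while alternating training and testing."""
    return max(test_acc_per_epoch)


def accuracy_b(val_acc_per_epoch: Sequence[float],
               evaluate_test: Callable[[int], float]) -> float:
    """accuracy-B: evaluate on the test set once, using the epoch with the best validation accuracy."""
    best_epoch = max(range(len(val_acc_per_epoch)), key=lambda e: val_acc_per_epoch[e])
    return evaluate_test(best_epoch)  # the only test-set evaluation
```

accuracy-A uses the test set for model selection, while accuracy-B keeps the test set untouched until the final evaluation, which is why accuracy-B is usually slightly lower in the tables above.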

logs contains the A and B directories, which hold the TensorBoard logs for accuracy-A and accuracy-B, respectively.

Dependency

The original code uses an old version of SpikingJelly. To maximize reproducibility, clone the latest SpikingJelly and roll back to the commit we used for training:

git clone https://github.com/fangwei123456/spikingjelly.git
cd spikingjelly
git reset --hard 73f94ab983d0167623015537f7d4460b064cfca1
python setup.py install

Here is the commit information:

commit 73f94ab983d0167623015537f7d4460b064cfca1
Author: fangwei123456 <[email protected]>
Date:   Wed Sep 30 16:42:25 2020 +0800

    增加detach reset的选项 (Add an option to detach the reset)

Datasets

Line 64 of train.py and line 84 of train_val.py define the dataset path:

dataset_dir = '/userhome/datasets/' + dataset_name

where /userhome/datasets/ is the root path of all datasets.

The root path of all datasets should have the following directory structure:

|-- CIFAR10
|   |-- cifar-10-batches-py
|   `-- cifar-10-python.tar.gz
|-- CIFAR10DVS
|   |-- airplane.zip
|   |-- automobile.zip
|   |-- bird.zip
|   |-- cat.zip
|   |-- deer.zip
|   |-- dog.zip
|   |-- events
|   |-- frames_num_20_split_by_number_normalization_None
|   |-- frog.zip
|   |-- horse.zip
|   |-- ship.zip
|   `-- truck.zip
|-- DVS128Gesture
|   |-- DvsGesture.tar.gz
|   |-- LICENSE.txt
|   |-- README.txt
|   |-- events_npy
|   |-- extracted
|   |-- frames_num_20_split_by_number_normalization_None
|   `-- gesture_mapping.csv
|-- FashionMNIST
|   `-- FashionMNIST
|-- MNIST
|   `-- MNIST
`-- NMNIST
    |-- Test.zip
    |-- Train.zip
    |-- events
    `-- frames_num_10_split_by_number_normalization_None
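Before a run, it can help to check that the dataset root matches the tree above. The following is a minimal sketch; `EXPECTED_DIRS` and `missing_dirs` are hypothetical helpers transcribed from the tree (only directories, not archives or text files), and the README does not say which of these folders must exist before the first run and which SpikingJelly creates during preprocessing.

```python
from pathlib import Path
from typing import Dict, List

# Directories listed under each dataset in the tree above (archives and text files omitted).
EXPECTED_DIRS: Dict[str, List[str]] = {
    "CIFAR10": ["cifar-10-batches-py"],
    "CIFAR10DVS": ["events", "frames_num_20_split_by_number_normalization_None"],
    "DVS128Gesture": ["events_npy", "extracted", "frames_num_20_split_by_number_normalization_None"],
    "FashionMNIST": ["FashionMNIST"],
    "MNIST": ["MNIST"],
    "NMNIST": ["events", "frames_num_10_split_by_number_normalization_None"],
}


def missing_dirs(datasets_root: str, dataset_name: str) -> List[str]:
    """Return the expected directories that do not exist under datasets_root/dataset_name."""
    dataset_dir = Path(datasets_root) / dataset_name  # mirrors '/userhome/datasets/' + dataset_name
    return [str(dataset_dir / d) for d in EXPECTED_DIRS[dataset_name]
            if not (dataset_dir / d).is_dir()]


if __name__ == "__main__":
    for name in EXPECTED_DIRS:
        for path in missing_dirs("/userhome/datasets", name):
            print("missing:", path)
```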

The MNIST, Fashion-MNIST, and CIFAR10 datasets are available from torchvision. For installing the neuromorphic datasets, see

https://spikingjelly.readthedocs.io/zh_CN/0.0.0.0.4/spikingjelly.datasets.html
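A hedged sketch for fetching the three static datasets with torchvision into the layout shown above. The helpers `torchvision_root` and `download_static_datasets` are hypothetical, and train.py may load the datasets differently. With `root` set to `<datasets root>/MNIST`, torchvision's `MNIST` stores its files in a nested folder named after the class, which matches the `MNIST/MNIST` and `FashionMNIST/FashionMNIST` entries in the tree; `CIFAR10` places `cifar-10-batches-py` and `cifar-10-python.tar.gz` directly under its root.

```python
import os


def torchvision_root(datasets_root: str, dataset_name: str) -> str:
    """Directory passed to torchvision as `root`; the same as the scripts' dataset_dir."""
    return os.path.join(datasets_root, dataset_name)


def download_static_datasets(datasets_root: str = "/userhome/datasets") -> None:
    """Download the training and test splits of MNIST, FashionMNIST and CIFAR10."""
    from torchvision import datasets  # imported lazily; requires torchvision

    for name, dataset_cls in (("MNIST", datasets.MNIST),
                              ("FashionMNIST", datasets.FashionMNIST),
                              ("CIFAR10", datasets.CIFAR10)):
        root = torchvision_root(datasets_root, name)
        for train in (True, False):
            dataset_cls(root=root, train=train, download=True)
```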

Running the Code

Here are the original commands used to obtain accuracy-B:

| Dataset | Command |
|---|---|
| MNIST | python ./codes/train_val.py -init_tau 2.0 -use_max_pool -use_plif -device cuda:0 -dataset_name MNIST -log_dir_prefix /userhome/plif_test/logsd -T 8 -max_epoch 1024 -detach_reset |
| Fashion-MNIST | python ./codes/train_val.py -init_tau 2.0 -use_max_pool -use_plif -device cuda:0 -dataset_name FashionMNIST -log_dir_prefix /userhome/plif_test/logsd -T 8 -max_epoch 1024 -detach_reset |
| CIFAR10 | python ./codes/train_val.py -init_tau 2.0 -use_max_pool -use_plif -device cuda:0 -dataset_name CIFAR10 -log_dir_prefix /userhome/plif_test/logsd -T 8 -max_epoch 1024 -detach_reset |
| N-MNIST | python ./codes/train_val.py -init_tau 2.0 -use_max_pool -device cuda:0 -dataset_name NMNIST -log_dir_prefix /userhome/plif_test/logsd -T 10 -max_epoch 1024 -detach_reset -channels 128 -number_layer 2 -split_by number -normalization None -use_plif |
| CIFAR10-DVS | python ./codes/train_val.py -init_tau 2.0 -use_max_pool -device cuda:0 -dataset_name CIFAR10DVS -log_dir_prefix /userhome/plif_test/logsd -T 20 -max_epoch 1024 -detach_reset -channels 128 -number_layer 4 -split_by number -normalization None -use_plif |
| DVS128 Gesture | python ./codes/train_val.py -init_tau 2.0 -use_max_pool -device cuda:0 -dataset_name DVS128Gesture -log_dir_prefix /userhome/plif_test/logsd -T 20 -max_epoch 1024 -detach_reset -channels 128 -number_layer 5 -split_by number -normalization None -use_plif |

The code can resume training after an interruption: it loads the existing model and continues training from the last epoch.
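This README does not show how resumption is implemented, so the following is only a generic sketch of the pattern. It uses `pickle` in place of `torch.save`/`torch.load` so that it runs without PyTorch; in a PyTorch script the saved state would typically hold `model.state_dict()` and `optimizer.state_dict()`. The checkpoint path and the helpers `save_checkpoint`, `load_start_epoch` and `train` are hypothetical.

```python
import os
import pickle
from typing import Any, Callable, Dict, Optional, Tuple

State = Dict[str, Any]


def save_checkpoint(path: str, epoch: int, state: State) -> None:
    """Save the state after `epoch`, writing to a temporary file first so that an
    interruption mid-write leaves the previous checkpoint intact."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump({"epoch": epoch, "state": state}, f)
    os.replace(tmp_path, path)


def load_start_epoch(path: str) -> Tuple[int, Optional[State]]:
    """Return (first epoch to run, saved state), or (0, None) if no checkpoint exists."""
    if not os.path.exists(path):
        return 0, None
    with open(path, "rb") as f:
        ckpt = pickle.load(f)
    return ckpt["epoch"] + 1, ckpt["state"]


def train(path: str, max_epoch: int,
          run_one_epoch: Callable[[int, Optional[State]], State]) -> Optional[State]:
    start_epoch, state = load_start_epoch(path)  # continue after the last saved epoch
    for epoch in range(start_epoch, max_epoch):
        state = run_one_epoch(epoch, state)
        save_checkpoint(path, epoch, state)
    return state
```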

Argument Definitions

This table shows the definition of all arguments:

| argument | meaning | type | default |
|---|---|---|---|
| init_tau | tau of all LIF neurons, or tau_0 of PLIF neurons | float | - |
| batch_size | training batch size | int | 16 |
| learning_rate | learning rate | float | 1e-3 |
| T_max | period of the learning rate schedule | int | 64 |
| use_plif | if given, use PLIF neurons | action='store_true' | False |
| alpha_learnable | if given, alpha in the surrogate function is learnable | action='store_true' | False |
| use_max_pool | if given, the network uses max pooling; otherwise it uses average pooling | action='store_true' | False |
| device | device used for training | str | - |
| dataset_name | dataset to use | str (MNIST, FashionMNIST, CIFAR10, NMNIST, CIFAR10DVS or DVS128Gesture) | - |
| log_dir_prefix | path where TensorBoard saves the logs | str | - |
| T | number of simulation time steps | int | - |
| channels | number of output channels of the Conv2d layers for neuromorphic datasets | int | - |
| number_layer | number of Conv2d layers for neuromorphic datasets | int | - |
| split_by | how to split the events when integrating them into frames | str (time or number) | - |
| normalization | normalization applied to the frames during integration | str (frequency, max, norm, sum or None) | - |
| max_epoch | maximum number of training epochs | int | - |
| detach_reset | if given, detach the voltage reset in the backward pass | action='store_true' | False |
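The table can be read as an `argparse` specification. The following sketch is a hypothetical reconstruction, not the parser in train_val.py: the option names, types and defaults come from the table and the commands above, while leaving the "-" entries as optional with a value of `None`, and the `choices` lists, are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the argument table (single-dash options, as in the commands)."""
    p = argparse.ArgumentParser(description="PLIF training arguments (sketch)")
    p.add_argument("-init_tau", type=float, help="tau of LIF neurons, or tau_0 of PLIF neurons")
    p.add_argument("-batch_size", type=int, default=16)
    p.add_argument("-learning_rate", type=float, default=1e-3)
    p.add_argument("-T_max", type=int, default=64, help="period of the learning rate schedule")
    p.add_argument("-use_plif", action="store_true")
    p.add_argument("-alpha_learnable", action="store_true")
    p.add_argument("-use_max_pool", action="store_true")  # otherwise average pooling
    p.add_argument("-device", type=str)
    p.add_argument("-dataset_name", type=str,
                   choices=["MNIST", "FashionMNIST", "CIFAR10", "NMNIST", "CIFAR10DVS", "DVS128Gesture"])
    p.add_argument("-log_dir_prefix", type=str)
    p.add_argument("-T", type=int, help="number of simulation time steps")
    p.add_argument("-channels", type=int)
    p.add_argument("-number_layer", type=int)
    p.add_argument("-split_by", type=str, choices=["time", "number"])
    # "None" is passed on the command line as a string, as in the commands above
    p.add_argument("-normalization", type=str, choices=["frequency", "max", "norm", "sum", "None"])
    p.add_argument("-max_epoch", type=int)
    p.add_argument("-detach_reset", action="store_true")
    return p
```

For example, parsing the MNIST command's options yields `T == 8`, `use_plif == True` and the default `batch_size == 16`.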

For more details about split_by and normalization, see

https://spikingjelly.readthedocs.io/zh_CN/0.0.0.0.4/spikingjelly.datasets.html#integrate-events-to-frames-init-en
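With split_by number, the events of a sample are divided into frames_num segments containing (nearly) the same number of events, and each segment is integrated into one frame; with split_by time, the segments instead span equal time intervals. The following is a minimal pure-Python sketch of the number-based split with normalization None. The event tuple `(t, x, y, p)`, the `[frame][polarity][y][x]` layout and the helper name `integrate_by_number` are assumptions, and the exact segment boundaries may differ from SpikingJelly's implementation.

```python
from typing import List, Sequence, Tuple

Event = Tuple[int, int, int, int]  # (t, x, y, p) with polarity p in {0, 1}


def integrate_by_number(events: Sequence[Event], frames_num: int,
                        height: int, width: int) -> List[List[List[List[int]]]]:
    """Split events into frames_num equal-count segments (ordered by timestamp) and count
    the events per polarity and pixel in each segment (normalization None)."""
    frames = [[[[0] * width for _ in range(height)] for _ in range(2)] for _ in range(frames_num)]
    n = len(events)
    for i, (t, x, y, p) in enumerate(sorted(events)):  # sort by timestamp first
        frame_index = i * frames_num // n  # segment sizes differ by at most one event
        frames[frame_index][p][y][x] += 1
    return frames
```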

New Implementation

SpikingJelly (0.0.0.0.12 or the latest version) includes the network with LIF neurons and max pooling as an example:

0.0.0.0.12: https://spikingjelly.readthedocs.io/zh_CN/0.0.0.0.12/clock_driven_en/14_classify_dvsg.html

latest: https://spikingjelly.readthedocs.io/zh_CN/latest/activation_based_en/classify_dvsg.html

This example is written with the new version of SpikingJelly and runs faster than the code in this repository.

All networks in this paper are available in SpikingJelly:

0.0.0.0.12: https://github.com/fangwei123456/spikingjelly/blob/0.0.0.0.12/spikingjelly/clock_driven/model/parametric_lif_net.py

latest: https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/activation_based/model/sew_resnet.py

Cite

@InProceedings{Fang_2021_ICCV,
    author    = {Fang, Wei and Yu, Zhaofei and Chen, Yanqi and Masquelier, Timothee and Huang, Tiejun and Tian, Yonghong},
    title     = {Incorporating Learnable Membrane Time Constant To Enhance Learning of Spiking Neural Networks},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {2661-2671}
}
