
ems-yolo's Introduction

Deep Directly-Trained Spiking Neural Networks for Object Detection (ICCV2023)

Requirements

The code has been tested with pytorch=1.10.1, py=3.8, cuda=11.3, cudnn=8.2.0_0. The conda environment can be copied directly via environment.yml. Some additional dependencies are listed in environment.txt.

Install
$ git clone https://github.com/BICLab/EMS-YOLO.git
$ cd EMS-YOLO
$ pip install -r requirements.txt

Pretrained Checkpoints

We provide the best and the last trained model based on EMS-Res34 on the COCO dataset.

detect.py runs inference on a variety of sources and automatically downloads models from COCO_EMS-ResNet34.

The relevant parameter files are in runs/train.

Training & Addition

Train

The code for the Gen1 dataset is in /g1-resnet. Its files need to be copied into the root folder, either replacing the corresponding files there or added alongside them.

For the Gen1 dataset:

python path/to/train_g1.py --weights ***.pt --img 640

For the COCO dataset:

python train.py

Calculating the spiking rate:

Dependencies can be downloaded from Visualizer.

python calculate_fr.py
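A layer's spiking (firing) rate is usually defined as the fraction of its neuron-timestep entries that emit a spike. The minimal NumPy sketch below computes per-layer rates and a size-weighted overall rate from binary spike arrays. It only illustrates the definition and is not calculate_fr.py itself. The function names and example shapes are hypothetical.

```python
import numpy as np

def layer_firing_rates(spike_maps):
    """Per-layer firing rate: fraction of entries in each binary spike map equal to 1."""
    return np.array([s.sum() / s.size for s in spike_maps])

def overall_firing_rate(spike_maps):
    """Size-weighted average over layers: total spikes / total neuron-timestep entries."""
    total_spikes = sum(s.sum() for s in spike_maps)
    total_size = sum(s.size for s in spike_maps)
    return total_spikes / total_size

# Two hypothetical layers' spike outputs, shaped (T, C, H, W).
a = np.zeros((2, 1, 2, 2))
a[0, 0, 0, 0] = 1            # 1 spike out of 8 entries
b = np.ones((2, 1, 1, 2))    # 4 spikes out of 4 entries

per_layer = layer_firing_rates([a, b])   # rates 0.125 and 1.0
overall = overall_firing_rate([a, b])    # 5 spikes / 12 entries
```

A plain mean of the per-layer rates weights a small layer as heavily as a large one. The size-weighted version weights each layer by how many entries it contributes.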

Contact Information

@inproceedings{su2023deep,
  title={Deep Directly-Trained Spiking Neural Networks for Object Detection},
  author={Su, Qiaoyi and Chou, Yuhong and Hu, Yifan and Li, Jianing and Mei, Shijie and Zhang, Ziyang and Li, Guoqi},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={6555--6565},
  year={2023}
}

YOLOv3 is a family of object detection architectures and models pretrained on the COCO dataset, and represents Ultralytics open-source research into future vision AI methods, incorporating lessons learned and best practices evolved over thousands of hours of research and development.

Our code is built on this framework, so please remember to cite their work as well.

ems-yolo's People

Contributors

biclab, qiaoyi-su


ems-yolo's Issues

Training Time

Dear Authors,

How long does your method take to train?

.yaml files in the models folder

I'm confused about the .yaml files in the models folder.
I don't know which model each .yaml file corresponds to.
Could anyone help me? Thanks.

DetectMultiBackend is not found

When I run train.py, I get an error:
"ImportError: cannot import name 'DetectMultiBackend' from 'models.common'."
The class DetectMultiBackend is not defined in models/common.py.
Looking forward to a reply!

The provided code doesn't align with the paper

Hi, I found some discrepancies between your code and the paper.

  1. What is the input to the Detect head?
    The paper states that the last membrane potential of the neurons is fed into each detector. However, in your code the input to the Detect layer comes from BasicBlock_ms, whose output is a convolution of spikes rather than a membrane potential.

Could you please explain where the membrane potential is used as the input for detection?

  2. Which model corresponds to Figure 2 in your paper?

In the README.md, you use the ResNet34 model, which I assume corresponds to resnet34.yaml. However, resnet34.yaml seems to use only BasicBlock_2 in the backbone and never uses MS_block, which is not consistent with the EMS-Module2 shown in Figure 2.

  3. Where should the MaxPool be?
    In the paper, the MaxPool operation is applied first, followed by the LCB block. However, in the Concat_res2 and BasicBlock_ms code, the LCB block comes first and MaxPool is applied after concatenation.

  4. The parameter count is not consistent with the figure given in the paper.
    In Table 2, the authors report 6.20M parameters for EMS-Res10. However, the provided trained weights are for ResNet34, which has 33.94M parameters. If ResNet34 is necessary to achieve good results, why does the paper present Res10? Could you please provide the trained weights for EMS-Res10?

# Load the trained weights and print the model summary
from models.experimental import attempt_load
from utils.torch_utils import model_info


w_path = './best.pt'
model = attempt_load(w_path)
print(model)
model_info(model)

# output is : Model Summary: 325 layers, 33940542 parameters, 0 gradients, 0.0 GFLOPs
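Item 1 of the issue above turns on the difference between a neuron's spike output and its membrane potential. The minimal NumPy sketch below runs a generic leaky integrate-and-fire (LIF) neuron of the kind used in MS-ResNet-style SNNs and records both signals at every time step. It is not the repository's mem_update. The decay of 0.25, the threshold of 0.5 and the input currents are assumed illustration values.

```python
import numpy as np

def lif_forward(x, decay=0.25, v_th=0.5):
    """Run a simple LIF neuron over the leading (time) axis of x.

    Update rule (the (1 - s) factor resets the membrane after a spike):
        u[t] = decay * u[t-1] * (1 - s[t-1]) + x[t]
        s[t] = 1 if u[t] > v_th else 0
    Returns (spikes, membrane), both with the same shape as x.
    """
    u = np.zeros_like(x[0], dtype=float)   # membrane potential
    s = np.zeros_like(x[0], dtype=float)   # spike emitted at the previous step
    spikes = np.zeros_like(x, dtype=float)
    membrane = np.zeros_like(x, dtype=float)
    for t in range(x.shape[0]):
        u = decay * u * (1.0 - s) + x[t]
        s = (u > v_th).astype(float)
        spikes[t] = s
        membrane[t] = u
    return spikes, membrane

# One neuron over T = 3 steps (hypothetical input currents).
x = np.array([0.45, 0.45, 0.1])
spikes, membrane = lif_forward(x)
# spikes -> 0, 1, 0; membrane -> 0.45, 0.5625, 0.1
# A spike-fed head would see spikes[-1] (binary), while a
# membrane-potential head would see membrane[-1] (real-valued).
```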

Pretrained model on the Gen1 dataset

In your paper, you use the EMS-Res10 model and achieve 0.267 mAP on the Gen1 dataset. However, when I used the framework you provided to train on Gen1, I couldn't get good results.
I don't know whether something went wrong in my training, so could you provide the model trained on the Gen1 dataset?

This repo lacks a lot of important things

I've noticed that several blocks seem to be missing from the repository, and I'm having difficulty locating the EMS-ResNet structure referenced in your paper. Could you possibly guide me on where to find it or provide any additional information? Thank you for your assistance.

train_g1.py cannot find the dataset .txt file

My Gen1 dataset only has .dat and .npy files, but when I run train_g1.py it reports No such file or directory: 'data\train_a\detection_dataset_duration_60s_ratio_1.0\train.txt'. Could anyone help me? Thanks.

Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/EMS-YOLO/g1-resnet/models/yolo.py", line 131, in forward
    input[i] = x
RuntimeError: expand(torch.cuda.FloatTensor{[8, 5, 3, 320, 320]}, size=[8, 5, 3, 320]): the number of sizes provided (4) must be greater or equal to the number of dimensions in the tensor (5)

I created gen1_data by replacing import dataset_g1T with import give_g1_data in train_g1, and that produced train.txt and val.txt. But when I run python train_g1.py with best.pt as the pre-trained weights and models/resnet34.yaml as the cfg file, I get the error above.
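The missing train.txt in this issue is a split list file. YOLO-style loaders typically read such a file as plain text with one sample path per line. The sketch below writes such a list from a folder of .npy samples. It assumes that train_g1.py expects this one-path-per-line format, which has not been confirmed here. The function name and the train/ and val/ sub-folder layout are hypothetical. The give_g1_data route described above is the repository's own way to produce these files.

```python
from pathlib import Path

def write_split_list(sample_dir, list_path, pattern="*.npy"):
    """Write the sorted, absolute paths of all files in sample_dir matching
    pattern to list_path, one per line. Returns the number of entries.
    (Hypothetical helper; the list format expected by train_g1.py is assumed.)"""
    paths = sorted(p.resolve() for p in Path(sample_dir).glob(pattern))
    list_path = Path(list_path)
    list_path.parent.mkdir(parents=True, exist_ok=True)
    list_path.write_text("".join(f"{p}\n" for p in paths))
    return len(paths)
```

For example, set root to Path("data/train_a/detection_dataset_duration_60s_ratio_1.0"), the folder named in the error message. Then call write_split_list(root / "train", root / "train.txt") and write_split_list(root / "val", root / "val.txt"), assuming the samples are stored in train/ and val/ sub-folders.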

Missing 'DetectMultiBackend' class in models/common.py

ImportError: cannot import name 'DetectMultiBackend' from 'models.common'
I found a DetectMultiBackend class in the YOLOv5 repository, where it is used for multi-backend inference, but it appears to be incompatible with this version. Could you re-push common.py?

The Problem of Computing SyOPs in the SNN Model

In your paper, I saw that you compared the energy efficiency of EMS-ResNet with ANN-Res, MS-Res, and SEW-Res. You report the firing rates in the paper, but I could not find the SyOPs. According to your formula, $E_b = T \times (fr \times E_{AC} \times OP_{AC} + E_{MAC} \times OP_{MAC})$ (Eq. 6), SyOPs should be an important indicator of the model's energy consumption. I saw potentially related code in your calculate_fr.py:

attention_name = "mem_update.forward"
fr=np.zeros(len(cache[attention_name]))
at_size=np.zeros(len(cache[attention_name]))
for att_index in range(len(cache[attention_name])):
    # visualize_grid_to_grid(save_dir,att_index,attention_name,att_map,image)
    fr[att_index]=cache[attention_name][att_index].sum()/cache[attention_name][att_index].size
for att_index in range(len(cache[attention_name])):
    at_size[att_index]=cache[attention_name][att_index].size
# calculate firing rate
FR.append(fr)
SZ.append(at_size)

SZ seems to be an attempt to count the number of operations in mem_update, but I get the following error when running the code:

Traceback (most recent call last):
File "/home/ps/EMS-YOLO/calculate_fr.py", line 355, in <module>
  main(opt)
File "/home/ps/EMS-YOLO/calculate_fr.py", line 361, in main
  run(**vars(opt))
File "/home/ps/miniconda3/envs/emsyolo/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
  return func(*args, **kwargs)
File "/home/ps/EMS-YOLO/calculate_fr.py", line 235, in run
  fr=np.zeros(len(cache[attention_name]))
KeyError: 'mem_update.forward'

How can I solve this problem?
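Once the operation counts are known, Eq. (6) quoted in this issue can be evaluated directly. The sketch below implements the formula as written, with a single average firing rate. The syops helper is a hypothetical reading of the formula's AC term (T × fr × OP_AC) and is not taken from the paper or the repository. The numbers in the usage line are made up. The per-operation energies of 0.9 pJ (AC) and 4.6 pJ (MAC) are values commonly cited for 45 nm CMOS, and they are an assumption here rather than values from this repository.

```python
def snn_energy(T, fr, op_ac, op_mac, e_ac, e_mac):
    """Energy estimate from Eq. (6): E_b = T * (fr * E_AC * OP_AC + E_MAC * OP_MAC).

    T: number of time steps; fr: average firing rate in [0, 1];
    op_ac: accumulate operations per time step (spike-driven layers);
    op_mac: multiply-accumulate operations per time step (non-spiking layers);
    e_ac, e_mac: energy per AC / MAC operation, in one consistent unit (e.g. joules).
    """
    return T * (fr * e_ac * op_ac + e_mac * op_mac)

def syops(T, fr, op_ac):
    """Spike-driven synaptic operations over T steps (hypothetical: T * fr * OP_AC)."""
    return T * fr * op_ac

# Made-up counts: 1e9 AC ops and 1e8 MAC ops per step, T = 4, fr = 0.2.
energy_j = snn_energy(T=4, fr=0.2, op_ac=1e9, op_mac=1e8, e_ac=0.9e-12, e_mac=4.6e-12)
# energy_j = 4 * (0.2 * 0.9e-12 * 1e9 + 4.6e-12 * 1e8) = 2.56e-3 J
```

With per-layer firing rates, a finer variant replaces fr × OP_AC with the sum over layers of fr_l × OP_AC,l.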

Gen1 dataset

Hi @Qiaoyi-Su, I'd like to know how I should use the Gen1 dataset. Have you preprocessed it? It seems that this repo does not provide full training support. Looking forward to a reply!

Event Visualization as Point Cloud

Hi,
I'm a novice researcher working on semantic segmentation and I found your paper "Asynchronous Spatio-Temporal Memory Network for Continuous Event-Based Object Detection" very insightful. The visualizations in the paper are excellent and I'd like to include event visualization as point clouds in my own research.

Could you provide any guidance or resources on how to achieve this?
Thanks a lot!
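A common way to display events as a point cloud is to treat each event (x, y, t, p) as a 3-D point (x, y, t) and colour it by polarity. The minimal sketch below follows that approach. It assumes the events are a NumPy structured array with fields 't', 'x', 'y' and 'p', as Prophesee-style Gen1 loaders typically return. It also assumes timestamps are in microseconds, so a time_scale of 1e-3 gives milliseconds. The function names are hypothetical. matplotlib is imported only inside the plotting function, so the conversion step works without it.

```python
import numpy as np

def events_to_point_cloud(events, time_scale=1e-3):
    """Convert events with fields 't', 'x', 'y', 'p' into (N, 3) points and (N, 3) RGB colours.

    Points are (x, y, t') with t' = (t - t_min) * time_scale, so time is the third axis.
    Colours: red for positive polarity (p > 0), blue otherwise.
    """
    t = events["t"].astype(np.float64)
    points = np.column_stack(
        [events["x"], events["y"], (t - t.min()) * time_scale]
    ).astype(np.float64)
    colors = np.where((events["p"] > 0)[:, None], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0])
    return points, colors

def plot_point_cloud(points, colors, max_points=50_000):
    """3-D scatter plot of the cloud (x and y in the image plane, time along the depth axis)."""
    import matplotlib.pyplot as plt
    step = max(1, len(points) // max_points)  # subsample large recordings for speed
    ax = plt.figure().add_subplot(projection="3d")
    ax.scatter(points[::step, 0], points[::step, 2], points[::step, 1],
               c=colors[::step], s=0.5)
    ax.set_xlabel("x")
    ax.set_ylabel("time")
    ax.set_zlabel("y")
    plt.show()
```

For long recordings it is usually clearer to plot a short time window, for example by first selecting events whose t falls in a 50 ms slice.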
