ResidualAttentionNetwork-pytorch

A PyTorch implementation of the Residual Attention Network.

This code is based on two projects:

https://github.com/liudaizong/Residual-Attention-Network and https://github.com/fwang91/residual-attention-network/blob/master/imagenet_model/Attention-92-deploy.prototxt

The first project is a PyTorch implementation, but I think some of its network details are not right, so I modified it to follow the architecture in Attention-92-deploy.prototxt.

I also added ResidualAttentionModel_92 for training on ImageNet, ResidualAttentionModel_448input for larger input images, and ResidualAttentionModel_92_32input_update for training on CIFAR-10.

paper referenced

Residual Attention Network for Image Classification (CVPR-2017 Spotlight) By Fei Wang, Mengqing Jiang, Chen Qian, Shuo Yang, Chen Li, Honggang Zhang, Xiaogang Wang, Xiaoou Tang

how to train?

First, download the data from http://www.cs.toronto.edu/~kriz/cifar.html. Then make sure the variable is set to

is_train = True

and run one of:

CUDA_VISIBLE_DEVICES=0 python train.py

CUDA_VISIBLE_DEVICES=0 python train_mixup.py   (with mixup)

You can also train ResidualAttentionModel_56 or ResidualAttentionModel_448input. Only the import in train.py needs to change, from "from model.residual_attention_network import ResidualAttentionModel_92 as ResidualAttentionModel" to, for example, "from model.residual_attention_network import ResidualAttentionModel_56 as ResidualAttentionModel".
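As an alternative to editing the import line by hand, the model class could be looked up by name. The sketch below is only an illustration: `load_model_class` is a hypothetical helper that is not part of this repository, and the default module path is simply the one quoted in the import line above.

```python
import importlib

# Assumption: module path copied from the import statement quoted in this README.
DEFAULT_MODULE = "model.residual_attention_network"


def load_model_class(class_name, module_name=DEFAULT_MODULE):
    """Return the class called `class_name` from `module_name` (hypothetical helper)."""
    module = importlib.import_module(module_name)
    try:
        return getattr(module, class_name)
    except AttributeError:
        raise ValueError(
            f"{module_name} has no class named {class_name!r}"
        ) from None


# Inside the repository this would stand in for the aliased import, e.g.:
#   ResidualAttentionModel = load_model_class("ResidualAttentionModel_56")
```

The effect is the same as the aliased import: the rest of train.py keeps using the name ResidualAttentionModel, and only the string changes.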

how to test?

Make sure the variable is set to

is_train = False

and run one of:

CUDA_VISIBLE_DEVICES=0 python train.py

CUDA_VISIBLE_DEVICES=0 python train_mixup.py   (with mixup)
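The same script is used for training and testing, with is_train choosing the mode. The actual train.py is not reproduced here, so the following is only a hypothetical sketch of how one flag can gate both modes. The names `run`, `train_one_epoch`, and `evaluate` are assumptions, as is the choice to evaluate after training.

```python
def run(is_train, train_one_epoch, evaluate, epochs=1):
    """Train first when is_train is True, then evaluate in either case.

    `train_one_epoch` and `evaluate` are placeholder callables standing in
    for the real training and test loops (assumed, not taken from train.py).
    """
    if is_train:
        for epoch in range(epochs):
            train_one_epoch(epoch)
    # With is_train = False, only this evaluation step runs.
    return evaluate()
```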

result

  1. CIFAR-10: accuracy 95.4% (top-1 error 4.6%) with ResidualAttentionModel_92_32input_update (better than the paper's top-1 error of 4.99%).

  2. CIFAR-10: accuracy 96.65% (top-1 error 3.35%) with ResidualAttentionModel_92_32input_update (with mixup).

  3. CIFAR-10: accuracy 96.84% (top-1 error 3.16%) with ResidualAttentionModel_92_32input_update (with mixup, with a simpler attention module).
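The accuracy and top-1 error figures above are two views of the same number: top-1 error is 100 minus the accuracy in percent. A small sketch of that arithmetic, using only the figures listed here (the dictionary keys are descriptive labels chosen for this example):

```python
PAPER_TOP1_ERROR = 4.99  # paper figure as quoted in the list above


def top1_error(accuracy_percent):
    """Top-1 error in percent, rounded to two decimals."""
    return round(100.0 - accuracy_percent, 2)


# Accuracies reported in this README for ResidualAttentionModel_92_32input_update.
results = {
    "plain training": 95.4,
    "mixup": 96.65,
    "mixup + simpler attention module": 96.84,
}

for name, accuracy in results.items():
    error = top1_error(accuracy)
    gap = PAPER_TOP1_ERROR - error
    print(f"{name}: top-1 error {error:.2f}, {gap:.2f} points below the paper")
```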

Thanks to @PistonY, who gave me the advice to use mixup. For more details on mixup, see the project https://github.com/facebookresearch/mixup-cifar10

The paper only gives the architecture details of attention_92 for ImageNet with 224 input, not for CIFAR-10, so I built the network following my own understanding. I have not put much effort into optimizing the code, so you may be able to do better based on it.
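For readers new to mixup, the idea is to train on convex combinations of pairs of inputs and to weight the loss against both original labels by the same coefficient. The sketch below is a framework-free illustration in plain Python, not the code in train_mixup.py. The function names and the default alpha=1.0 are assumptions, and in PyTorch the same arithmetic would be applied to batched tensors.

```python
import random


def mixup_pair(x_a, x_b, alpha=1.0, rng=random):
    """Blend two inputs (flat lists of floats) with lam ~ Beta(alpha, alpha)."""
    # alpha <= 0 disables mixing by keeping x_a unchanged.
    lam = rng.betavariate(alpha, alpha) if alpha > 0 else 1.0
    mixed = [lam * a + (1.0 - lam) * b for a, b in zip(x_a, x_b)]
    return mixed, lam


def mixup_loss(loss_fn, prediction, y_a, y_b, lam):
    """Combine the losses against both labels with the weight used for the inputs."""
    return lam * loss_fn(prediction, y_a) + (1.0 - lam) * loss_fn(prediction, y_b)
```

Here `loss_fn` is any callable taking a prediction and a label. In a training loop the mixed input is fed to the network and `mixup_loss` replaces the usual single-label loss.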

model file:

model_92_sgd.pkl is the trained model file, with an accuracy of 0.954.

contributors

tengshaofeng
