

MATRIX CAPSULES EM-Tensorflow


A TensorFlow implementation of CapsNet, based on the paper Matrix Capsules with EM Routing.

Status:

  1. With configuration A=32, B=8, C=16, D=16, batch_size=128, the code runs on a Tesla P40 GPU at about 8 s/iteration. A-D are defined in the paper.
  2. With configuration A=B=C=D=32, batch_size=64, the code runs on a Tesla P40 GPU at about 25 s/iteration. The structure of the implementation needs further optimization.
  3. Some modifications and optimizations have been made to improve the numerical stability of the GMM. Specific explanations can be found in the code.
  4. With configuration A=32, B=4, C=4, D=4, batch_size=128, each training iteration takes around 0.6 s on a Tesla P40 GPU.
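
Item 3 does not say which changes were made, so the code itself remains the authority. As a rough illustration only, one widely used way to keep GMM responsibilities finite is to normalise them in log space (the log-sum-exp trick). The sketch below is plain Python with no dependencies; the names `log_sum_exp` and `responsibilities` are hypothetical and are not taken from this repository.

```python
import math


def log_sum_exp(values):
    """Compute log(sum(exp(v))) without overflow or underflow."""
    peak = max(values)
    # Subtracting the maximum keeps every exponent <= 0.
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def responsibilities(log_weights):
    """Turn unnormalised log-probabilities into probabilities summing to 1."""
    norm = log_sum_exp(log_weights)
    return [math.exp(v - norm) for v in log_weights]


# With inputs this negative, a naive exp() would underflow to 0 and the
# normalisation would divide by zero; the shifted version stays finite.
r = responsibilities([-1000.0, -1001.0, -1002.0])
```

The same shift-by-maximum idea applies to the E-step of any mixture model, whichever tensor library is used.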

Current Results on smallNORB:

  • Configuration: A=32, B=8, C=16, D=16, batch_size=50, iteration number of EM routing: 2, with Coordinate Addition, spread loss, batch normalization

  • Training loss. Variation of the loss is suppressed by batch normalization. However, there is still a gap between our best results and the results reported in the original paper. [Figure: spread loss]

  • Test accuracy (current best result is 87.66%). [Figure: test_acc]
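
For readers unfamiliar with the spread loss named in the configuration, the paper defines it per example as the sum over wrong classes i of max(0, m - (a_t - a_i))^2, where a_t is the activation of the target class and m is the margin. A minimal sketch in plain Python follows; the function name, the argument layout, and the default margin are illustrative assumptions, and the repository's own implementation (batching, margin schedule) may differ.

```python
def spread_loss(activations, target, margin=0.2):
    """Spread loss for a single example.

    activations: list of output-capsule activations, one per class.
    target: index of the true class.
    margin: the margin m; the default here is only illustrative.
    """
    a_t = activations[target]
    total = 0.0
    for i, a_i in enumerate(activations):
        if i == target:
            continue  # the target class contributes no penalty
        # Penalise wrong classes whose activation is within `margin` of a_t.
        total += max(0.0, margin - (a_t - a_i)) ** 2
    return total


# Class 1 is within the margin of the target (class 0); class 2 is not.
loss = spread_loss([0.9, 0.8, 0.1], target=0, margin=0.2)
```

Only class 1 contributes in this example, because its activation is closer than the margin to the target activation.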

Ablation Study on smallNORB:

  • Configuration: A=32, B=8, C=16, D=16, batch_size=32, iteration number of EM routing: 2, with Coordinate Addition, spread loss, test accuracy is 79.8%.

Current Results on MNIST:

  • Configuration: A=32, B=8, C=16, D=16, batch_size=50, iteration number of EM routing: 2, with Coordinate Addition, spread loss, batch normalization, reconstruction loss.

  • Training loss. [Figure: spread loss]

  • Test accuracy (current best result is 99.3%; only 10% of the samples are used in the test). [Figure: test_acc]

Ablation Study on MNIST:

  • Configuration: A=32, B=4, C=4, D=4, batch_size=128, iteration number of EM routing: 2, no Coordinate Addition, cross entropy loss, test accuracy is 96.4%.
  • Configuration: A=32, B=4, C=4, D=4, batch_size=128, iteration number of EM routing: 2, with Coordinate Addition, cross entropy loss, test accuracy is 96.8%.
  • Configuration: A=32, B=8, C=16, D=16, batch_size=32, iteration number of EM routing: 2, with Coordinate Addition, spread loss, test accuracy is 99.1%.
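
The ablations above toggle Coordinate Addition. In the paper this means adding the scaled (row, column) position of each capsule's receptive field to the first two elements of the right-hand column of its 4x4 vote matrix. The sketch below shows that idea on a plain nested list; the function name, the one-vote-per-position layout, and the cell-centre scaling are simplifying assumptions and do not describe this repository's code.

```python
def coordinate_addition(votes):
    """Add scaled (row, col) positions to each 4x4 vote matrix.

    votes[r][c] is the 4x4 vote (a list of 4 rows) cast by the capsule at
    grid position (r, c). Returns a new nested list; the input is unchanged.
    """
    height, width = len(votes), len(votes[0])
    result = []
    for r, grid_row in enumerate(votes):
        out_row = []
        for c, vote in enumerate(grid_row):
            new_vote = [list(v) for v in vote]  # copy the 4x4 matrix
            # Assumed scaling: centre of the grid cell mapped into (0, 1).
            new_vote[0][3] += (r + 0.5) / height
            new_vote[1][3] += (c + 0.5) / width
            out_row.append(new_vote)
        result.append(out_row)
    return result


# A 2x2 grid of all-zero votes, so the added offsets are easy to read off.
zero_votes = [[[[0.0] * 4 for _ in range(4)] for _ in range(2)] for _ in range(2)]
shifted = coordinate_addition(zero_votes)
```

Only two entries of each vote change, so the remaining entries still encode the learned pose transformation.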

To Do List:

  1. Experiments on smallNORB, as described in the paper, are about to be carried out.

Questions and comments on the code and the original algorithm are welcome! My email: zhangsuofei at njupt.edu.cn

Requirements

  • Python >= 3.4
  • Numpy
  • Tensorflow >= 1.2.0

pip install -r requirement.txt

Usage

Step 1. Clone this repository with git.

$ git clone https://github.com/www0wwwjs1/Matrix-Capsules-EM-Tensorflow.git
$ cd Matrix-Capsules-EM-Tensorflow

Step 2. Download the MNIST dataset, then move and extract it into the data/mnist directory. (If backslashes appear around the curly braces when you copy the wget command into your terminal, remove them.)

$ mkdir -p data/mnist
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/{train-images-idx3-ubyte.gz,train-labels-idx1-ubyte.gz,t10k-images-idx3-ubyte.gz,t10k-labels-idx1-ubyte.gz}
$ gunzip data/mnist/*.gz
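
To check that the extracted files are intact before training, their IDX headers can be inspected. The following is a small sketch using only the Python standard library; it relies on the documented IDX layout (big-endian 32-bit header fields, magic number 2051 for image files and 2049 for label files). The helper names `read_idx_images` and `read_idx_labels` are hypothetical and are not part of this repository.

```python
import struct


def read_idx_images(path):
    """Return (count, rows, cols, pixels) from an uncompressed idx3-ubyte file."""
    with open(path, "rb") as f:
        magic, count, rows, cols = struct.unpack(">IIII", f.read(16))
        if magic != 2051:
            raise ValueError("not an idx3-ubyte image file (magic=%d)" % magic)
        pixels = f.read(count * rows * cols)
    if len(pixels) != count * rows * cols:
        raise ValueError("image file is truncated")
    return count, rows, cols, pixels


def read_idx_labels(path):
    """Return the labels of an uncompressed idx1-ubyte file as a list of ints."""
    with open(path, "rb") as f:
        magic, count = struct.unpack(">II", f.read(8))
        if magic != 2049:
            raise ValueError("not an idx1-ubyte label file (magic=%d)" % magic)
        labels = f.read(count)
    if len(labels) != count:
        raise ValueError("label file is truncated")
    return list(labels)


# Example, assuming Step 2 has been completed:
#   count, rows, cols, _ = read_idx_images("data/mnist/train-images-idx3-ubyte")
#   labels = read_idx_labels("data/mnist/train-labels-idx1-ubyte")
```

A wrong magic number or a short read usually indicates an interrupted download, in which case re-running the wget command with -c resumes it.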

To install smallNORB, follow instructions in ./data/README.md

Step 3. Start the training (MNIST):

$ python3 train.py "mnist"

Start the training (smallNORB):

$ python3 train.py "smallNORB"

Step 4. View the status of training:

$ tensorboard --logdir=./logdir/train_log/

Open the URL that TensorBoard prints.

Step 5. Start the test on MNIST:

$ python3 eval.py "mnist"

Start the test on smallNORB:

$ python3 eval.py "smallNORB"

Step 6. View the status of test:

$ tensorboard --logdir=./test_logdir

Open the URL that TensorBoard prints.
