
capsnet's Introduction

CapsNet

Capsule networks are a novel approach showing promising results on SmallNORB and MNIST. Here we reproduce and build upon the impressive results shown by Sara Sabour et al. We experiment on the Capsule Network architecture by visualizing exactly what the capsules in different layers represent and what information they store about 3D objects in an image, and we try to improve its classification results on CIFAR10 and SmallNORB with various methods, including some tricks with the reconstruction loss. Further, we present a deconvolution-based reconstruction module that reduces the number of learnable parameters by 80% compared to the fully connected module presented by Sara Sabour et al.

Benchmarks

Our baseline model is the same as in the original paper, but it is trained for only 113 epochs on MNIST, and we did not use a 7-model ensemble for CIFAR10 as was done in the paper.

| Model         | MNIST  | SmallNORB | CIFAR10 |
|---------------|--------|-----------|---------|
| Sabour et al. | 99.75% | 97.3%     | 89.40%  |
| Baseline      | 99.73% | 91.5%     | 72.59%  |

Experiments

We introduced a deconvolution-based reconstruction module, and experimented with batch normalization and different network topologies.

Deconvolution-based Reconstruction

The baseline model has 1.4M parameters in the fully connected decoder, while our deconvolution-based reconstruction module reduces the number of learnable parameters by 80%, down to 0.25M.
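A decoder in this style can be sketched in PyTorch as follows. The layer sizes and channel counts here are illustrative assumptions, not the exact architecture used in this repository:

```python
import torch
import torch.nn as nn

class DeconvDecoder(nn.Module):
    """Reconstruct a 28x28 image from a 16-D class capsule.

    Hypothetical layer sizes for illustration; the repository's
    exact architecture may differ.
    """
    def __init__(self, capsule_dim=16):
        super().__init__()
        self.fc = nn.Linear(capsule_dim, 7 * 7 * 32)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1),   # 14x14 -> 28x28
            nn.Sigmoid(),  # pixel intensities in [0, 1]
        )

    def forward(self, capsule):
        x = self.fc(capsule).view(-1, 32, 7, 7)
        return self.deconv(x)
```

Because the transposed convolutions share weights spatially, most of the parameters sit in the small initial linear layer, which is why a decoder in this style needs far fewer parameters than a fully connected stack mapping a 16-D capsule to 784 pixels through wide hidden layers.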

Here is a comparison between the two reconstruction modules after training for 25 epochs on MNIST, where RLoss is the SSE reconstruction loss and MLoss is the margin loss.

| Model        | RLoss | MLoss  | Accuracy |
|--------------|-------|--------|----------|
| FC           | 21.62 | 0.0058 | 99.51%   |
| FC w/ BN     | 13.12 | 0.0054 | 99.54%   |
| DeConv       | 10.87 | 0.0050 | 99.54%   |
| DeConv w/ BN | 9.52  | 0.0044 | 99.55%   |
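The two losses combine as in Sabour et al.: total loss = margin loss + α · (SSE reconstruction loss), where α is the DEFAULT_ALPHA scaling factor from constants.py. A minimal sketch using the paper's margin-loss constants (m⁺ = 0.9, m⁻ = 0.1, λ = 0.5); this repository's implementation may differ in detail:

```python
import torch
import torch.nn.functional as F

def capsnet_loss(lengths, targets, recon, images, alpha=0.0005,
                 m_plus=0.9, m_minus=0.1, lam=0.5):
    """Margin loss plus alpha-scaled SSE reconstruction loss.

    lengths: [B, num_classes] class-capsule lengths
    targets: [B, num_classes] one-hot labels
    recon, images: [B, C, H, W]
    """
    # Margin loss: penalize short capsules for present classes
    # and long capsules for absent classes.
    present = targets * F.relu(m_plus - lengths) ** 2
    absent = lam * (1.0 - targets) * F.relu(lengths - m_minus) ** 2
    margin = (present + absent).sum(dim=1).mean()
    # Sum-of-squared-errors reconstruction loss, scaled down by alpha
    # so it does not dominate the margin loss during training.
    sse = ((recon - images) ** 2).sum(dim=(1, 2, 3)).mean()
    return margin + alpha * sse
```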

Visualization

Reconstructions

Here are the reconstruction results for SmallNORB and CIFAR10, after training for 186 epochs and 86 epochs respectively.

Robustness to Affine Transformations

We visualized how the network recognizes a rotated MNIST image when trained only on unmodified MNIST data, using an image of the digit 2 as an example. The network is confident about the result when the image is only slightly rotated, but as the image is rotated further, it starts to confuse it with other digits. For example, at a certain angle it is very confident that the image is a 7, and reconstructs a 7 that aligns closely with the input. Thanks to the digit's distinctive topological features, the network still recognizes the input 2 when it is rotated by 180°.

Primary Capsules Reconstructions

We used a pre-trained network to train a reconstruction module for Primary Capsules. By scaling these capsules by their routing coefficients to the classified object, we were able to visualize reconstructions from Primary Capsules. Each row is reconstructed from a single capsule, and the routing coefficient is increased from left to right.
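The scaling step can be sketched as follows; this is a hypothetical helper, and the tensor shapes are assumptions about the implementation rather than the repository's exact code:

```python
import torch

def scale_primary_capsules(primary_caps, coupling, target_class):
    """Scale each primary capsule by its routing (coupling) coefficient
    to the classified object's class capsule, so a decoder can
    reconstruct from the scaled capsules.

    primary_caps: [B, num_primary, caps_dim]
    coupling:     [B, num_primary, num_classes] coefficients c_ij
    target_class: index of the classified object
    """
    c = coupling[:, :, target_class].unsqueeze(-1)  # [B, num_primary, 1]
    return primary_caps * c
```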

Usage

Step 1. Install requirements

  • Python 3
  • PyTorch 1.0.1
  • Torchvision 0.2.1
  • TQDM

Step 2. Adjust hyperparameters

In constants.py:

```python
DEFAULT_LEARNING_RATE = 0.001
DEFAULT_ALPHA = 0.0005          # Scaling factor for reconstruction loss
DEFAULT_DATASET = "small_norb"  # 'mnist' or 'small_norb'
DEFAULT_DECODER = "FC"          # 'FC' or 'Conv'
DEFAULT_BATCH_SIZE = 128
DEFAULT_EPOCHS = 300
DEFAULT_USE_GPU = True
DEFAULT_ROUTING_ITERATIONS = 3
```

Step 3. Start training

Training with default settings:

$ python train.py

Training flags example:

$ python train.py --decoder=Conv --file=model32.pt --dataset=mnist

Further help with training flags:

$ python train.py -h

Step 4. Get your results

Trained models are saved in the saved_models/ directory. TensorBoard logs are saved to logs/. You can launch TensorBoard with

tensorboard --logdir logs

Future work

  • Fully develop notebooks for visualization and plotting.
  • Implement EM routing.

capsnet's People

Contributors

ethanleet, hukkelas


capsnet's Issues

Segmentation Fault

I ran your code with default parameters and got a segmentation fault message.

Differences between implementation and original paper

Hi. I've found some differences between your implementation and the original model in the paper.
Apart from the reconstruction network being different (which I think is a cool experiment), I've noticed that the 'b' (in dynamic routing) is implemented differently.
The paper seems to suggest a single 'b' that is independent of the image, whereas you use one 'b' for the entire minibatch.
Also, the bias term added to 'c' isn't present in the original paper. Was this part of an experiment? I'm still going through the code and haven't tested it beyond MNIST, but since yours is one of the few repositories that test on SmallNORB, may I ask what results you've obtained on it?
Thanks!

Primary Capsules reconstruction

Hi, I have a few questions about the reconstruction of primary capsules. The idea is really interesting, and I would like to better understand the logic behind it.

You said the following:

"We used a pre-trained network to train a reconstruction module for Primary Capsules. By scaling these capsules by its routing coefficients to the classified object, we were able to visualize reconstructions from Primary Capsules. Each row is reconstructed from a single capsule, and the routing coefficient is increased from left to right."

Could you explain and give more details on how this was done?
Could you also point out the code section that does the primary capsules reconstruction? I looked over the code but I can't figure it out.

Thank you very much for your time.
I'm looking forward to reading your answer.
