mnist-from-scratch

Neural network in C++ from scratch for MNIST dataset classification


Build

The build requires C++14.

Building should be fairly straightforward, provided all dependencies are installed in the expected locations (which, admittedly, makes the build not so straightforward to set up after all).

I recommend checking the CMake files to see where CMake searches for the dependencies. The libraries used by the project include:

All of them are open source, well maintained and well documented.

It is best to install them via conda.

Do not forget to set your CONDA_PREFIX environment variable (if it is not already set) to the path where conda is installed, since CMake will look for it. This is usually done automatically during the conda installation, e.g. $HOME/anaconda3 or $HOME/miniconda3.

Run the following build commands from the project root directory:

cmake -DCMAKE_BUILD_TYPE=Release .  # to configure
cmake --build . --target all -- -j 4

Training

Once the build has completed, there should be a build directory in the project root, along with two links called mnist and mnist-evaluate. mnist is the almighty training and evaluation script: it should take care of downloading the data set, training the model and evaluating it on the test set.
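The MNIST files are distributed in the IDX format: each image file starts with a big-endian header (magic number 2051, then the image count, the number of rows and the number of columns), after which the pixels follow as unsigned bytes in row-major order. The sketch below shows one way to parse it. The struct and function names are hypothetical and not taken from the project, and scaling pixels to [0, 1] is an assumed preprocessing step.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <istream>
#include <sstream>   // std::istringstream can stand in for a file, e.g. in tests
#include <stdexcept>
#include <string>
#include <vector>

// IDX headers store integers as big-endian 32-bit values.
std::uint32_t read_be_u32(std::istream& in) {
    unsigned char b[4];
    if (!in.read(reinterpret_cast<char*>(b), 4))
        throw std::runtime_error("unexpected end of IDX data");
    return (std::uint32_t(b[0]) << 24) | (std::uint32_t(b[1]) << 16) |
           (std::uint32_t(b[2]) << 8) | std::uint32_t(b[3]);
}

// Hypothetical header struct for image files.
struct IdxImageHeader {
    std::uint32_t count, rows, cols;
};

// Image files begin with the magic number 2051 (0x00000803), then count, rows, cols.
IdxImageHeader read_image_header(std::istream& in) {
    if (read_be_u32(in) != 2051) throw std::runtime_error("not an IDX image file");
    IdxImageHeader h{};
    h.count = read_be_u32(in);
    h.rows = read_be_u32(in);
    h.cols = read_be_u32(in);
    return h;
}

// Reads the next image (rows * cols unsigned bytes, row-major), scaled to [0, 1].
std::vector<double> read_image(std::istream& in, const IdxImageHeader& h) {
    assert(h.rows > 0 && h.cols > 0);
    std::vector<double> pixels(static_cast<std::size_t>(h.rows) * h.cols);
    for (double& p : pixels) {
        char c;
        if (!in.get(c)) throw std::runtime_error("truncated IDX image data");
        p = static_cast<unsigned char>(c) / 255.0;
    }
    return pixels;
}
```

To read real data, open an uncompressed file such as train-images-idx3-ubyte with `std::ifstream in(path, std::ios::binary)` and pass `in` to both functions. Label files follow the same scheme with magic number 2049, followed by the label count and one unsigned byte per label.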

Customizing training parameters

Most training parameters can be customized via environment variables. I recommend checking the model.h header file to find most of them.

Some of the parameters include:

  • LEARNING_RATE=3.0
  • TRAIN_EPOCHS=30
  • BATCH_SIZE=10
  • LOSS="mse"
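One way such environment-driven settings can be read in C++ is sketched below, using the values listed above as fallbacks. This is an illustration, not the code from model.h: `TrainingParams` and `params_from_env` are hypothetical names, and the lookup function is injectable only so the sketch can be tested without touching the real environment.

```cpp
#include <cassert>
#include <cstdlib>
#include <functional>
#include <string>

// Hypothetical parameter bundle; the fallbacks are the values listed above.
struct TrainingParams {
    double learning_rate = 3.0;
    int train_epochs = 30;
    int batch_size = 10;
    std::string loss = "mse";
};

// Returns the raw value of an environment variable, or nullptr if it is unset.
using EnvLookup = std::function<const char*(const char*)>;

// Overrides each fallback only when the corresponding variable is set.
// std::stod / std::stoi throw std::invalid_argument on malformed values.
TrainingParams params_from_env(
    const EnvLookup& lookup = [](const char* name) -> const char* { return std::getenv(name); }) {
    TrainingParams p;
    if (const char* v = lookup("LEARNING_RATE")) p.learning_rate = std::stod(v);
    if (const char* v = lookup("TRAIN_EPOCHS")) p.train_epochs = std::stoi(v);
    if (const char* v = lookup("BATCH_SIZE")) p.batch_size = std::stoi(v);
    if (const char* v = lookup("LOSS")) p.loss = v;
    assert(p.train_epochs > 0 && p.batch_size > 0);
    return p;
}
```

From the shell, overrides are passed the usual way, e.g. `LOSS=categorical_cross_entropy BATCH_SIZE=30 ./mnist`.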

You can also set LOSS to "categorical_cross_entropy". In the long run the performance is similar ("mse" appears to be slightly better, although this may just be a matter of hyperparameter tuning), but cross-entropy seems to converge faster, i.e. it reaches good performance earlier. With cross-entropy, I also recommend setting a higher BATCH_SIZE; something like 30 should do.
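To make the comparison concrete, the sketch below computes both losses and their gradients with respect to the output-layer pre-activations z. It assumes a sigmoid output layer for MSE and a softmax output layer for categorical cross-entropy; this is a common pairing, assumed here rather than confirmed for this project, and the function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;

double sigmoid(double z) { return 1.0 / (1.0 + std::exp(-z)); }

// Numerically stable softmax: subtracting max(z) does not change the result.
Vec softmax(const Vec& z) {
    assert(!z.empty());
    const double m = *std::max_element(z.begin(), z.end());
    Vec a(z.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < z.size(); ++i) {
        a[i] = std::exp(z[i] - m);
        sum += a[i];
    }
    for (double& v : a) v /= sum;
    return a;
}

// MSE (quadratic) loss: C = 1/2 * sum_i (a_i - y_i)^2.
double mse_loss(const Vec& a, const Vec& y) {
    assert(a.size() == y.size());
    double c = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) c += 0.5 * (a[i] - y[i]) * (a[i] - y[i]);
    return c;
}

// dC/dz for MSE with sigmoid outputs a = sigmoid(z): (a - y) * a * (1 - a).
Vec mse_delta(const Vec& z, const Vec& y) {
    assert(z.size() == y.size());
    Vec d(z.size());
    for (std::size_t i = 0; i < z.size(); ++i) {
        const double a = sigmoid(z[i]);
        d[i] = (a - y[i]) * a * (1.0 - a);
    }
    return d;
}

// Categorical cross-entropy: C = -sum_i y_i * log(a_i); terms with y_i = 0 are skipped.
double cross_entropy_loss(const Vec& a, const Vec& y) {
    assert(a.size() == y.size());
    double c = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i)
        if (y[i] > 0.0) c -= y[i] * std::log(a[i]);
    return c;
}

// dC/dz for softmax + cross-entropy (y summing to 1, e.g. one-hot) simplifies to a - y.
Vec cross_entropy_delta(const Vec& z, const Vec& y) {
    assert(z.size() == y.size());
    Vec a = softmax(z);
    for (std::size_t i = 0; i < a.size(); ++i) a[i] -= y[i];
    return a;
}
```

For a confidently wrong output (e.g. z = {-10, 0} with the true class at index 0), `mse_delta` returns a gradient close to zero because of the σ'(z) = a(1 - a) factor, while `cross_entropy_delta` returns roughly -1. This learning slowdown of MSE on saturated units is one commonly cited reason why cross-entropy tends to converge faster.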

There are also some additional logging parameters:

  • LOG_STEP_COUNT_STEPS=30000
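For context on how a learning rate, a number of epochs, a batch size and a logging interval typically interact, here is a minimal mini-batch SGD loop on a toy one-parameter model. It is a sketch under stated assumptions, not the project's training loop: the model, the fixed shuffle seed, the function name `train` and the meaning of a "step" (one mini-batch update) are all chosen for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <random>
#include <vector>

struct Sample { double x, y; };

// Mini-batch SGD on a toy 1-D model y_hat = w * x with per-sample loss 1/2 (w*x - y)^2.
// Each step averages the gradient over one mini-batch and moves w against it,
// scaled by learning_rate. Every `log_every` steps (0 disables logging) the
// current weight is printed. A "step" here is one mini-batch update; whether
// LOG_STEP_COUNT_STEPS counts batches or samples is an assumption.
double train(std::vector<Sample> data, double learning_rate, int epochs,
             std::size_t batch_size, long log_every) {
    assert(batch_size > 0 && !data.empty());
    std::mt19937 rng(42);  // fixed seed for reproducible shuffles
    double w = 0.0;
    long step = 0;
    for (int epoch = 0; epoch < epochs; ++epoch) {
        std::shuffle(data.begin(), data.end(), rng);  // new sample order every epoch
        for (std::size_t start = 0; start < data.size(); start += batch_size) {
            const std::size_t end = std::min(start + batch_size, data.size());
            double grad = 0.0;
            for (std::size_t k = start; k < end; ++k)
                grad += (w * data[k].x - data[k].y) * data[k].x;  // dC/dw for one sample
            w -= learning_rate * grad / static_cast<double>(end - start);
            ++step;
            if (log_every > 0 && step % log_every == 0)
                std::cout << "epoch " << epoch << ", step " << step << ": w = " << w << '\n';
        }
    }
    return w;
}
```

With samples satisfying y = 2x for x = 0.1, 0.2, ..., 1.0, a learning rate of 0.5, 50 epochs and a batch size of 2, the returned weight converges to 2.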

It is also possible to continue training from a previous run, since the model automatically creates a checkpoint after each epoch. To do so, set the CONTINUE_TRAINING environment variable.
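Checkpointing boils down to serializing the model parameters after each epoch and reading them back on start-up. The sketch below shows the idea for a flat vector of weights. The binary layout, the function names and the rule "resume whenever CONTINUE_TRAINING is set to any value" are assumptions for illustration, not the project's actual checkpoint format.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical format: a uint64 element count followed by the raw doubles in
// host byte order (portable only between machines with the same layout).
bool save_checkpoint(const std::string& path, const std::vector<double>& weights) {
    std::ofstream out(path, std::ios::binary | std::ios::trunc);
    if (!out) return false;
    const std::uint64_t n = weights.size();
    out.write(reinterpret_cast<const char*>(&n), sizeof n);
    out.write(reinterpret_cast<const char*>(weights.data()),
              static_cast<std::streamsize>(n * sizeof(double)));
    return static_cast<bool>(out);
}

// Leaves `weights` untouched and returns false if the file is missing or truncated.
bool load_checkpoint(const std::string& path, std::vector<double>& weights) {
    std::ifstream in(path, std::ios::binary);
    if (!in) return false;
    std::uint64_t n = 0;
    if (!in.read(reinterpret_cast<char*>(&n), sizeof n)) return false;
    std::vector<double> buf(static_cast<std::size_t>(n));
    if (!in.read(reinterpret_cast<char*>(buf.data()),
                 static_cast<std::streamsize>(n * sizeof(double))))
        return false;
    weights = std::move(buf);
    return true;
}

// Returns `fresh` unless CONTINUE_TRAINING is set and a checkpoint with the
// same number of weights can be loaded from `path`.
std::vector<double> initial_weights(const std::string& path, std::vector<double> fresh) {
    assert(!fresh.empty());
    std::vector<double> loaded;
    if (std::getenv("CONTINUE_TRAINING") != nullptr && load_checkpoint(path, loaded) &&
        loaded.size() == fresh.size())
        return loaded;
    return fresh;
}
```

A training loop would call `save_checkpoint` at the end of every epoch and `initial_weights` once before the first epoch.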

To run the training script, execute from the project root:

./mnist

Evaluation

Evaluation runs automatically after training, but it is also possible to evaluate a pre-trained model on its own. After training, the model should have been exported into the export/ directory along with its checkpoints. To evaluate it, run:

./mnist-evaluate

With the default settings, the model should reach over 95% accuracy. That is not exactly state of the art, but it is quite impressive considering the simplicity of the model.
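The accuracy mentioned above is the fraction of test images classified correctly. Assuming the usual definition, where the predicted digit is the output neuron with the largest activation, it can be computed as in this sketch; the function names are hypothetical and not taken from the project.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Predicted class = index of the largest output activation.
std::size_t argmax(const std::vector<double>& v) {
    assert(!v.empty());
    return static_cast<std::size_t>(
        std::distance(v.begin(), std::max_element(v.begin(), v.end())));
}

// Fraction of samples whose predicted class equals the label (0-9 for MNIST).
double accuracy(const std::vector<std::vector<double>>& outputs,
                const std::vector<std::size_t>& labels) {
    assert(outputs.size() == labels.size() && !outputs.empty());
    std::size_t correct = 0;
    for (std::size_t i = 0; i < outputs.size(); ++i)
        if (argmax(outputs[i]) == labels[i]) ++correct;
    return static_cast<double>(correct) / static_cast<double>(outputs.size());
}
```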

UPDATE: A RUN script has recently been added. It builds the project automatically, makes it easy to set the environment, and runs the training and evaluation.

Prediction

TODO: This has not been implemented yet. Sorry for the inconvenience.


Final words and Thank Yous

My greatest thanks go to the QuantStack project for their work on xtensor and for their incredible support, which I needed quite often while creating this project.

Some parts of the backpropagation algorithm were inspired by the great blog Neural Networks and Deep Learning by Michael Nielsen. The blog has been invaluable to me; thank you very much for it!

An awesome explanation of cross-entropy and softmax, along with their derivatives, was provided by Sefik Ilkin Serengil in this blog.


Author: Marek Cermak
