Auto-Encoder for Keras

This project provides a lightweight, easy to use and flexible auto-encoder module for use with the Keras framework.

Auto-encoders are used to generate embeddings that describe inter- and intra-class relationships. This makes auto-encoders, like many other similarity-learning algorithms, suitable as a pre-training step for many classification problems.
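To make the idea of an embedding concrete, here is a minimal NumPy sketch of an auto-encoder that is independent of this module and of Keras. A linear encoder maps each sample to a lower-dimensional embedding, a linear decoder maps it back, and both weight matrices are fitted by plain gradient descent on the mean squared reconstruction error. The function name `train_linear_autoencoder`, the 2-dimensional bottleneck, the learning rate and the epoch count are illustrative assumptions; the module itself wraps Keras encoder and decoder models instead.

```python
import numpy as np


def train_linear_autoencoder(x, embedding_dim, epochs=500, lr=0.1, seed=0):
    """Hypothetical helper: fit x -> z = x @ w_enc -> x_hat = z @ w_dec by gradient descent on MSE."""
    rng = np.random.default_rng(seed)
    n_features = x.shape[1]
    w_enc = rng.normal(scale=0.1, size=(n_features, embedding_dim))
    w_dec = rng.normal(scale=0.1, size=(embedding_dim, n_features))
    losses = []
    for _ in range(epochs):
        z = x @ w_enc        # embeddings (encoder output)
        x_hat = z @ w_dec    # reconstruction (decoder output)
        err = x_hat - x
        losses.append(float(np.mean(err ** 2)))
        # Back-propagate the mean squared error through both linear maps.
        grad_x_hat = 2 * err / err.size
        grad_w_dec = z.T @ grad_x_hat
        grad_w_enc = x.T @ (grad_x_hat @ w_dec.T)
        w_enc -= lr * grad_w_enc
        w_dec -= lr * grad_w_dec
    return w_enc, w_dec, losses


x_train = np.random.default_rng(1).random((100, 3))
w_enc, w_dec, losses = train_linear_autoencoder(x_train, embedding_dim=2)
embeddings = x_train @ w_enc  # one 2-dimensional embedding per sample
print(f"reconstruction MSE: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

At its optimum, a linear, bias-free auto-encoder like this one spans the same subspace as the leading principal directions of the data.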

An example is provided in which the auto-encoder module, applied with minimal modification to the Keras MNIST example and without data augmentation, reaches 99.84% validation accuracy on the MNIST dataset.

Installation

Create and activate a virtual environment for the project.

$ virtualenv env
$ source env/bin/activate

To install the module directly from GitHub:

$ pip install git+https://github.com/aspamers/autoencoder

The module installs Keras and NumPy but no back-end (such as TensorFlow). This is deliberate: it keeps the module decoupled from any back-end and lets you install whichever version you prefer.

To install tensorflow:

$ pip install tensorflow

To install tensorflow with gpu support:

$ pip install tensorflow-gpu
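Because no back-end is installed automatically, it can help to confirm that one is present before running anything. The sketch below uses only the standard library. The helper `installed_backends` is hypothetical (not part of this module), the candidate list (TensorFlow, Theano and CNTK, the back-ends supported by multi-backend Keras 2.x) is an assumption, and the check only reports whether a package is importable, not which back-end Keras is configured to use.

```python
import importlib.util


def installed_backends(candidates=("tensorflow", "theano", "cntk")):
    """Hypothetical helper: map each candidate package name to whether it is importable."""
    # find_spec returns None for a top-level package that cannot be found,
    # and it does not import the package itself.
    return {name: importlib.util.find_spec(name) is not None for name in candidates}


print(installed_backends())
```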

Running the examples

With the virtual environment activated and the package installed, run the following commands from the root of a cloned copy of the repository.

To run the MNIST baseline example:

$ python examples/mnist_example.py

To run the MNIST auto-encoder pre-trained example:

$ python examples/mnist_autoencoder_example.py

Usage

For detailed usage examples, please refer to the examples and unit test modules. If the instructions are not sufficient, feel free to request improvements.

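  • Import the module, together with the Keras and NumPy names used below
import keras
import numpy as np
from keras.layers import Activation, BatchNormalization, Dense, Input
from keras.models import Model

from autoencoder import AutoEncoder
  • Load or generate some data.
x_train = np.random.rand(100, 3)
x_test = np.random.rand(30, 3)
input_shape = x_train.shape[1:]  # (3,): per-sample shape, without the batch axis
  • Design an encoder model
def create_encoder_model(input_shape):
    model_input = Input(shape=input_shape)

    # Map each 3-feature sample to a 4-dimensional embedding.
    encoder = Dense(4)(model_input)
    encoder = BatchNormalization()(encoder)
    encoder = Activation(activation='relu')(encoder)

    return Model(model_input, encoder)
  • Design a decoder model
def create_decoder_model(embedding_shape):
    embedding_a = Input(shape=embedding_shape)

    # Map the embedding back to the original 3 features.
    decoder = Dense(3)(embedding_a)
    decoder = BatchNormalization()(decoder)
    decoder = Activation(activation='relu')(decoder)

    return Model(embedding_a, decoder)
  • Create an instance of the AutoEncoder class
encoder_model = create_encoder_model(input_shape)
# output_shape includes the batch axis, e.g. (None, 4); Input() expects the shape without it.
decoder_model = create_decoder_model(encoder_model.output_shape[1:])
autoencoder = AutoEncoder(encoder_model, decoder_model)
  • Compile the model
autoencoder.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam())
  • Train the model
epochs = 10  # illustrative value
# The inputs double as the targets: the model learns to reconstruct x.
autoencoder.fit(x_train, x_train,
                validation_data=(x_test, x_test),
                epochs=epochs)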
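The training call passes x_train as both the input and the target, so the auto-encoder is scored on how well the decoder reconstructs its own input. With loss='binary_crossentropy', that score is the element-wise cross-entropy between input and reconstruction, which treats every value as a probability, so inputs are expected to lie in [0, 1]. The NumPy sketch below is illustrative and is not the Keras implementation; the function name and the epsilon of 1e-7 used to clip predictions are assumptions. It also shows a property worth knowing: for targets strictly between 0 and 1, even a perfect reconstruction does not reach zero loss.

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean element-wise binary cross-entropy between targets and predictions."""
    # Clip predictions away from 0 and 1 so the logarithms stay finite.
    y_pred = np.clip(y_pred, eps, 1 - eps)
    losses = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return float(np.mean(losses))


x = np.array([[0.0, 1.0, 0.5]])  # one sample with three features in [0, 1]

perfect = binary_crossentropy(x, x)                    # reconstruction equals input
blurry = binary_crossentropy(x, np.full_like(x, 0.5))  # every output stuck at 0.5

print(f"perfect: {perfect:.4f}, blurry: {blurry:.4f}")
```

The 0.5 feature contributes log(2) to the loss in both cases, which is why the perfect reconstruction still scores above zero.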

Development Environment

Create and activate a test virtual environment for the project.

$ virtualenv env
$ source env/bin/activate

Install requirements

$ pip install -r requirements.txt

Install the backend of your choice.

$ pip install tensorflow

Run tests

$ pytest tests/test_autoencoder.py

Development container

To set up the VS Code development container, follow the instructions at: https://github.com/aspamers/vscode-devcontainer

You will also need to install the NVIDIA Docker GPU pass-through layer: https://github.com/NVIDIA/nvidia-docker

Contributors

aspamers
