Importance Sampling

This Python package accelerates the training of arbitrary neural networks created with Keras by using importance sampling.

# Keras imports

from importance_sampling.training import ImportanceTraining

x_train, y_train, x_val, y_val = load_data()
model = create_keras_model()
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)

ImportanceTraining(model).fit(
    x_train, y_train,
    batch_size=32,
    epochs=10,
    verbose=1,
    validation_data=(x_val, y_val)
)

model.evaluate(x_val, y_val)

Importance sampling for Deep Learning is an active research field, and this library is still under development, so your mileage may vary.
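To give a feel for the underlying idea, here is a minimal sketch of generic importance sampling over a dataset. It is not the library's implementation, and the names importance_sample and weighted_batch_mean are hypothetical helpers invented for this illustration. Samples are drawn with probability proportional to a positive score (for example the per-sample loss), and each drawn sample is reweighted by 1 / (N * p_i) so that the weighted batch mean remains an unbiased estimate of the mean over the full dataset.

```python
import random


def importance_sample(scores, batch_size, rng):
    """Draw indices with probability proportional to `scores`.

    Hypothetical helper, not part of the importance_sampling API.
    `scores` must be strictly positive. Returns the sampled indices and
    the weights 1 / (N * p_i) that compensate for the non-uniform draw.
    """
    n = len(scores)
    total = sum(scores)
    probs = [s / total for s in scores]
    idxs = rng.choices(range(n), weights=probs, k=batch_size)
    weights = [1.0 / (n * probs[i]) for i in idxs]
    return idxs, weights


def weighted_batch_mean(values, idxs, weights):
    """Average the sampled values after applying the importance weights."""
    return sum(w * values[i] for i, w in zip(idxs, weights)) / len(idxs)


if __name__ == "__main__":
    # Toy per-sample losses: one hard sample dominates the mean.
    losses = [0.1, 0.2, 3.0, 0.05]
    idxs, weights = importance_sample(losses, batch_size=8, rng=random.Random(0))
    print("estimate :", weighted_batch_mean(losses, idxs, weights))
    print("true mean:", sum(losses) / len(losses))
```

When the score equals the quantity being averaged, as in this toy run, every weighted term equals the true mean and the estimate has zero variance. In practice the score is only an approximation of each sample's contribution, so the variance is reduced rather than eliminated.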

Relevant Research

Ours

  • Not All Samples Are Created Equal: Deep Learning with Importance Sampling [preprint]
  • Biased Importance Sampling for Deep Neural Network Training [preprint]

By others

  • Stochastic optimization with importance sampling for regularized loss minimization [pdf]
  • Variance reduction in SGD by distributed importance sampling [pdf]

Dependencies & Installation

If you already have a working Keras installation, you normally only need to run pip install keras-importance-sampling. The library depends on:

  • Keras > 2
  • A Keras backend: one of TensorFlow, Theano or CNTK
  • blinker
  • numpy
  • matplotlib, seaborn, scikit-learn are optional (used by the plot scripts)
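If you want to check which of the dependencies listed above are importable before installing, a small script along these lines may help. It is only a sketch that uses the standard library: missing_modules is a hypothetical helper, the module names are the usual import names of the packages above (scikit-learn is imported as sklearn), and the script does not check versions.

```python
from importlib.util import find_spec

# Usual import names for the dependencies listed above.
REQUIRED = ["keras", "blinker", "numpy"]
BACKENDS = ["tensorflow", "theano", "cntk"]  # at least one is needed
OPTIONAL = ["matplotlib", "seaborn", "sklearn"]  # used by the plot scripts


def missing_modules(names):
    """Return the subset of `names` that cannot be found for import."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    print("missing required:", missing_modules(REQUIRED))
    if len(missing_modules(BACKENDS)) == len(BACKENDS):
        print("no Keras backend found among:", BACKENDS)
    print("missing optional:", missing_modules(OPTIONAL))
```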

Documentation

The module has a dedicated documentation site, but you can also read the source code and the examples to get an idea of how the library should be used and extended.

Examples

In the examples folder you can find some Keras examples that have been edited to use importance sampling.

Code examples

In this section we will showcase part of the API that can be used to train neural networks with importance sampling.

# Import what is needed to build the Keras model
from keras import backend as K
from keras.layers import Dense, Activation, Flatten
from keras.models import Sequential

# Import a toy dataset and the importance training
from importance_sampling.datasets import MNIST
from importance_sampling.training import ImportanceTraining


def create_nn():
    """Build a simple fully connected NN"""
    model = Sequential([
        Flatten(input_shape=(28, 28, 1)),
        Dense(40, activation="tanh"),
        Dense(40, activation="tanh"),
        Dense(10),
        Activation("softmax") # Needs to be separate to automatically
                              # get the preactivation outputs
    ])

    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )

    return model


if __name__ == "__main__":
    # Load the data
    dataset = MNIST()
    x_train, y_train = dataset.train_data[:]
    x_test, y_test = dataset.test_data[:]

    # Create the NN and keep the initial weights
    model = create_nn()
    weights = model.get_weights()

    # Train with uniform sampling
    K.set_value(model.optimizer.lr, 0.01)
    model.fit(
        x_train, y_train,
        batch_size=64, epochs=10,
        validation_data=(x_test, y_test)
    )

    # Train with importance sampling
    model.set_weights(weights)
    K.set_value(model.optimizer.lr, 0.01)
    ImportanceTraining(model).fit(
        x_train, y_train,
        batch_size=64, epochs=2,
        validation_data=(x_test, y_test)
    )

Using the script

The following terminal commands train a small VGG-like network to ~0.65% error on MNIST (the numbers are from a CPU).

$ # Train a small cnn with mnist for 500 mini-batches using importance
$ # sampling with bias to achieve ~ 0.65% error (on the CPU).
$ time ./importance_sampling.py \
>   small_cnn \
>   oracle-gnorm \
>   model \
>   predicted \
>   mnist \
>   /tmp/is \
>   --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000' \
>   --train_for 500 --validate_every 500
real    1m41.985s
user    8m14.400s
sys     0m35.900s
$
$ # And with uniform sampling to achieve ~ 0.9% error.
$ time ./importance_sampling.py \
>   small_cnn \
>   oracle-loss \
>   uniform \
>   unweighted \
>   mnist \
>   /tmp/uniform \
>   --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000' \
>   --train_for 3000 --validate_every 3000
real    9m23.971s
user    47m32.600s
sys     3m4.188s
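The timings above can be compared with a little arithmetic. The sketch below copies the "real" (wall-clock) times and mini-batch counts from the two runs; parse_time is a hypothetical helper written for this illustration. Keep in mind that the two runs reach different errors (~0.65% versus ~0.9%) and that the totals include startup and validation time, so the figures are only a rough indication.

```python
import re


def parse_time(text):
    """Convert a `time`-style duration such as '1m41.985s' to seconds."""
    match = re.fullmatch(r"(\d+)m(\d+(?:\.\d+)?)s", text.strip())
    if match is None:
        raise ValueError(f"unrecognized duration: {text!r}")
    minutes, seconds = match.groups()
    return int(minutes) * 60 + float(seconds)


# Wall-clock ("real") times and mini-batch counts copied from the runs above.
importance = {"real": parse_time("1m41.985s"), "batches": 500}
uniform = {"real": parse_time("9m23.971s"), "batches": 3000}

speedup = uniform["real"] / importance["real"]
per_batch_importance = importance["real"] / importance["batches"]
per_batch_uniform = uniform["real"] / uniform["batches"]

print(f"wall-clock speedup: {speedup:.2f}x")
print(f"seconds per mini-batch (importance): {per_batch_importance:.3f}")
print(f"seconds per mini-batch (uniform):    {per_batch_uniform:.3f}")
```

In these two runs the importance sampling run finishes roughly 5.5 times sooner. Each of its mini-batches is slightly more expensive on average, and the saving comes from needing 500 mini-batches instead of 3000.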
