LRFinder for Keras

Learning Rate Finder callback for Keras. The technique was proposed by Leslie Smith in https://arxiv.org/abs/1506.01186 and popularized by Jeremy Howard in the fast.ai deep learning course.

Usage

from keras.models import Sequential
from keras.layers import Flatten, Dense
from keras.datasets import fashion_mnist
# Notebook shell escape (Jupyter/Colab); in a terminal, run the command without the leading "!"
!git clone https://github.com/WittmannF/LRFinder.git
from LRFinder.keras_callback import LRFinder

# 1. Input Data
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()

mean, std = X_train.mean(), X_train.std()
X_train, X_test = (X_train-mean)/std, (X_test-mean)/std

# 2. Define and Compile Model
model = Sequential([Flatten(),
                    Dense(512, activation='relu'),
                    Dense(10, activation='softmax')])

model.compile(loss='sparse_categorical_crossentropy',
              metrics=['accuracy'], optimizer='sgd')


# 3. Fit using Callback
lr_finder = LRFinder(min_lr=1e-4, max_lr=1)

model.fit(X_train, y_train, batch_size=128, callbacks=[lr_finder], epochs=2)

Last Updates

  • Reload weights when training ends
  • Autoreload model's weights for each change of learning rate
  • For each learning rate, trains the model over batches_lr_update batches
  • Compatible with both model.fit_generator and model.fit method
  • Allow usage of more than one epoch
  • Added momentum to smooth the recorded loss curve
  • Number of iterations is automatically inferred as the number of batches (i.e., it will always run over a full epoch)
  • The learning rates are spaced evenly on a log scale (a geometric progression) using np.geomspace
  • Automatic stop criteria if current_loss > stop_multiplier * lowest_loss
  • stop_multiplier is a linear function of momentum: it is 10 when momentum is 0 and 4 when momentum is 0.9
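
The schedule and stop criterion described in the list above can be sketched roughly as follows. This is an illustration, not the callback's actual code: the function names are hypothetical, the stop multiplier is simply the straight line through the two points given above, and the exponential moving average used for smoothing is an assumed form of the momentum term.

```python
import numpy as np

def learning_rates(min_lr, max_lr, n_steps):
    """Learning rates spaced evenly on a log scale (a geometric progression)."""
    return np.geomspace(min_lr, max_lr, num=n_steps)

def stop_multiplier(momentum):
    """Straight line through (momentum=0, 10) and (momentum=0.9, 4)."""
    return 10.0 - (6.0 / 0.9) * momentum

def should_stop(current_loss, lowest_loss, momentum):
    """Stop criterion: current_loss > stop_multiplier * lowest_loss."""
    return current_loss > stop_multiplier(momentum) * lowest_loss

def smooth(losses, momentum):
    """Exponential moving average of the losses (assumed smoothing form)."""
    smoothed, avg = [], None
    for loss in losses:
        avg = loss if avg is None else momentum * avg + (1 - momentum) * loss
        smoothed.append(avg)
    return smoothed

# Five learning rates between 1e-4 and 1, one decade apart
lrs = learning_rates(1e-4, 1.0, 5)
```

With a higher momentum the smoothed loss reacts more slowly to divergence, which is consistent with using a lower stop multiplier in that case.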

Next Updates

  • Use exponential annealing instead of np.geomspace (start * (end/start) ** pct)
  • Add a test framework
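
As a small sketch of the exponential annealing formula mentioned above, assuming pct is the fraction of the search completed (from 0 to 1), the expression start * (end/start) ** pct produces the same values as NumPy's geomspace when pct is sampled evenly. The function name exp_annealing is hypothetical.

```python
import numpy as np

def exp_annealing(start, end, pct):
    """Exponential annealing: start * (end/start) ** pct, with pct assumed in [0, 1]."""
    return start * (end / start) ** pct

n_steps = 5
# Evenly spaced progress fractions: 0, 0.25, 0.5, 0.75, 1
pcts = np.arange(n_steps) / (n_steps - 1)
annealed = np.array([exp_annealing(1e-4, 1.0, p) for p in pcts])

# Same geometric progression as np.geomspace over the same range
reference = np.geomspace(1e-4, 1.0, num=n_steps)
```

The annealing form computes each learning rate from the current progress alone, so the full array of rates does not need to be built in advance.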

License

  • MIT

Contributors

  • wittmannf
