Weight Watcher

the Stat Mech Edition

Current Version: 0.2.7

Weight Watcher analyzes the Fat Tails in the weight matrices of Deep Neural Networks (DNNs).

This tool can predict the trends in the generalization accuracy of a series of DNNs, such as VGG11, VGG13, ..., or even the entire series of ResNet models, without needing a test set!

This relies upon recent research into Heavy (Fat) Tailed Self Regularization in DNNs.

The tool lets one compute an average capacity, or quality, metric for a series of DNNs trained on the same data but with different hyperparameters, or even different but related architectures. For example, it can predict that VGG19_BN generalizes better than VGG19, and better than VGG16_BN, VGG16, etc.

Types of Capacity Metrics:

There are 2 metrics available. The average log Norm is much faster but less accurate. The average weighted alpha is more accurate but much slower, because it needs to compute the SVD of each layer weight matrix and then fit the singular values/eigenvalues to a power law.

  • log Norm (default, fast, less accurate)
  • weighted alpha (slow, more accurate)
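
To make the cost difference concrete, here is a minimal NumPy sketch of the two ingredients for a single weight matrix. It is not the weightwatcher implementation: the helper names (`log_norm`, `eigenvalues`, `power_law_alpha`) are hypothetical, the base-10 log is an assumption, and the power-law exponent uses the standard continuous maximum-likelihood estimator with a hand-picked `xmin` rather than whatever fitting procedure the tool applies.

```python
import numpy as np


def log_norm(W):
    """Base-10 log of the Frobenius norm of a 2-D weight matrix (cheap)."""
    return float(np.log10(np.linalg.norm(W)))


def eigenvalues(W):
    """Eigenvalues of W^T W, obtained as squared singular values of W (needs an SVD)."""
    singular_values = np.linalg.svd(W, compute_uv=False)
    return singular_values ** 2


def power_law_alpha(evals, xmin):
    """Continuous maximum-likelihood exponent for the tail evals >= xmin.

    Assumption: xmin is supplied by the caller; a real fit would also select it.
    """
    tail = evals[evals >= xmin]
    return 1.0 + tail.size / np.sum(np.log(tail / xmin))


rng = np.random.default_rng(0)
W = rng.standard_normal((300, 100))   # random stand-in for a layer weight matrix
evals = eigenvalues(W)
print("log norm:", log_norm(W))
# A Gaussian matrix is not heavy tailed; the call only shows the mechanics.
print("alpha:", power_law_alpha(evals, xmin=np.median(evals)))
```

The log norm needs one pass over the matrix entries, while the alpha estimate needs the full singular value spectrum first, which is where the extra time goes.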

Here is an example of the Weighted Alpha capacity metric for all the current pretrained VGG models.

Notice: we did not peek at the ImageNet test data to build this plot.

Frameworks supported

  • Keras
  • PyTorch

Layers supported

  • Dense / Linear / Fully Connected (and Conv1D)
  • Conv2D

Installation

pip install weightwatcher

Usage

Weight Watcher works with both Keras and PyTorch models.

import weightwatcher as ww
watcher = ww.WeightWatcher(model=model)
results = watcher.analyze()

watcher.get_summary()
watcher.print_results()

Advanced Usage

The analyze function has several features described below

def analyze(self, model=None, layers=[], min_size=50, max_size=0,
                alphas=False, softranks=True, spectralnorms=True, 
                mp_fit=True,  plot=False):
...

and in the Demo Notebook

weighted alpha (SLOW)

Power law fit, shown here with a PyTorch example

import weightwatcher as ww
import torchvision.models as models

data = []  # collects one summary dict per analyzed model

model = models.vgg19_bn(pretrained=True)
watcher = ww.WeightWatcher(model=model)
results = watcher.analyze(alphas=True)  # alphas=True turns on the slow power law fit
data.append({"name": "vgg19bntorch", "summary": watcher.get_summary()})


### data:
{'name': 'vgg19bntorch',
  'summary': {'lognorm': 0.81850576,
   'lognorm_compound': 0.9365272010550088,
   'alpha': 2.9646726379493287,
   'alpha_compound': 2.847975521455623,
   'alpha_weighted': 1.1588882728052485,
   'alpha_weighted_compound': 1.5002343912892515}},

Capacity Metrics (averages over all layers):

  • lognorm: average log norm, fast

  • alpha_weighted: average weighted alpha, slow

  • alpha: average alpha, not weighted (slow, not as useful)

Compound averages:

Same as above, but the averages are computed slightly differently. This will be described in an upcoming paper.

Results are also provided for every layer; see Demo Notebook
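
Once several summaries have been collected in a list like `data`, they can be ordered by a chosen metric to compare models in a series. The following is a small sketch in plain Python; `rank_models` is a hypothetical helper, not part of weightwatcher, and the second entry uses made-up placeholder values purely to make the example runnable. The sketch only sorts in ascending order; which direction corresponds to better generalization should be taken from the linked papers.

```python
def rank_models(data, metric="alpha_weighted"):
    """Return model names ordered by the chosen summary metric, smallest first."""
    ordered = sorted(data, key=lambda entry: entry["summary"][metric])
    return [entry["name"] for entry in ordered]


data = [
    # Values copied from the vgg19bntorch summary shown above.
    {"name": "vgg19bntorch",
     "summary": {"lognorm": 0.81850576, "alpha_weighted": 1.1588882728052485}},
    # Hypothetical placeholder entry, NOT a measured result.
    {"name": "hypothetical_model",
     "summary": {"lognorm": 0.9, "alpha_weighted": 1.3}},
]

print(rank_models(data))                    # ordered by alpha_weighted
print(rank_models(data, metric="lognorm"))  # ordered by lognorm
```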

Additional options

filter by layer types

results = watcher.analyze(layers=ww.LAYER_TYPE.CONV1D|ww.LAYER_TYPE.DENSE)

filter by ids

results = watcher.analyze(layers=[20])

minimum, maximum size of weight matrix

Sets the minimum and maximum size of the weight matrices analyzed. Setting max_size is useful for quick debugging.

results = watcher.analyze(min_size=50, max_size=500)

plots (for alphas=True)

Create log-log plots for each layer weight matrix to observe how well the power law fits work.

results = watcher.analyze(alphas=True, plot=True)
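
For intuition about what such a log-log plot contains, the sketch below builds the underlying data with NumPy alone: a histogram of eigenvalues on logarithmically spaced bins, returned in log-log coordinates. It is an illustration under assumptions, not the plotting code used by weightwatcher; `loglog_density` and its `bins` parameter are hypothetical names. On these axes a power law tail shows up as an approximately straight line.

```python
import numpy as np


def loglog_density(evals, bins=50):
    """Histogram positive eigenvalues on log-spaced bins.

    Returns (log10 bin centers, log10 density) for the non-empty bins.
    """
    evals = np.asarray(evals, dtype=float)
    edges = np.geomspace(evals.min(), evals.max(), bins + 1)
    density, edges = np.histogram(evals, bins=edges, density=True)
    centers = np.sqrt(edges[:-1] * edges[1:])   # geometric bin centers
    keep = density > 0                          # empty bins have no logarithm
    return np.log10(centers[keep]), np.log10(density[keep])


rng = np.random.default_rng(0)
W = rng.standard_normal((300, 100))             # random stand-in for a layer weight matrix
evals = np.linalg.svd(W, compute_uv=False) ** 2
x, y = loglog_density(evals, bins=30)
slope, intercept = np.polyfit(x, y, 1)          # straight-line fit in log-log space
print("points:", x.size, "fitted slope:", slope)
```

The returned arrays can be passed to any plotting library; comparing the points against the fitted line gives a rough visual check of how well a power law describes the spectrum.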

Links

Demo Notebook

Calculation Consulting homepage

Calculated Content Blog


Implicit Self-Regularization in Deep Neural Networks: Evidence from Random Matrix Theory and Implications for Learning

Traditional and Heavy Tailed Self Regularization in Neural Network Models

Notebook for above 2 papers (https://github.com/CalculatedContent/ImplicitSelfRegularization)

Recent talk (presented at NERSC Summer 2018)


Heavy-Tailed Universality Predicts Trends in Test Accuracies for Very Large Pre-Trained Deep Neural Networks

Notebook for paper (https://github.com/CalculatedContent/PredictingTestAccuracies)

Latest Talk (presented at UC Berkeley/ICSI 12/13/2018)

ICML 2019 Theoretical Physics Workshop Paper


KDD 2019 Workshop: Statistical Mechanics Methods for Discovering Knowledge from Production-Scale Neural Networks (slides only, video coming soon)

Data Science at Home Podcast

Aggregate Intellect Podcast


Predicting trends in the quality of state-of-the-art neural networks without access to training or testing data

Repo for latest paper

Release

Publishing to the PyPI repository:

# 1. Check in the latest code with the correct revision number (__version__ in __init__.py)
vi weightwatcher/__init__.py # Increase the release number, remove -dev from the revision number
git commit
# 2. Check out latest version from the repo in a fresh directory
cd ~/temp/
git clone https://github.com/CalculatedContent/WeightWatcher
cd WeightWatcher/
# 3. Use the latest version of the tools
python -m pip install --upgrade setuptools wheel twine
# 4. Create the package
python setup.py sdist bdist_wheel
# 5. Test the package
twine check dist/*
# 6. Upload the package to PyPI
twine upload dist/*
# 7. Tag/Release on GitHub by creating a new release (https://github.com/CalculatedContent/WeightWatcher/releases/new)

License

Apache License 2.0

Contributors

Charles H Martin, PhD Calculation Consulting

Serena Peng
