spectral-density's Introduction

Large Scale Spectral Density Estimation for Deep Neural Networks

This repository contains two implementations of the stochastic Lanczos quadrature algorithm for deep neural networks, as used and described in Ghorbani, Krishnan and Xiao, "An Investigation into Neural Net Optimization via Hessian Eigenvalue Density" (ICML 2019).

To run the example notebooks, first install tensorflow_datasets (pip install tensorflow_datasets).

TensorFlow Implementation

The main class that runs the distributed Lanczos algorithm is LanczosExperiment. The Jupyter notebook demonstrates how to use this class.

In addition to single-machine (potentially multi-GPU) setups, this implementation is also suitable for multi-GPU, multi-worker setups. The crucial step is manually partitioning the input data across the available GPUs.
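To illustrate why partitioning works, here is a minimal NumPy sketch, not the LanczosExperiment API. It assumes a toy least-squares loss, for which the per-example Hessian is the outer product of the input with itself. The function names, the shard count, and the data are hypothetical. It shows that the full-dataset Hessian-vector product can be accumulated one partition at a time.

```python
import numpy as np

def shard_hvp(x_shard, v):
    """Sum of per-example Hessian-vector products over one shard.

    Assumes the toy loss 0.5 * (x @ w - y)**2, whose per-example
    Hessian is outer(x, x), so the shard sum is X^T (X v).
    """
    return x_shard.T @ (x_shard @ v)

def partitioned_hvp(x, v, num_partitions):
    """Average Hessian-vector product, accumulated shard by shard."""
    total = np.zeros_like(v, dtype=np.float64)
    for x_shard in np.array_split(x, num_partitions):
        total += shard_hvp(x_shard, v)
    return total / len(x)

rng = np.random.default_rng(0)
x = rng.normal(size=(103, 5))  # 103 examples, deliberately not divisible by 4
v = rng.normal(size=5)

full = x.T @ (x @ v) / len(x)                       # all data at once
sharded = partitioned_hvp(x, v, num_partitions=4)   # one shard per "GPU"
```

In a real setup each shard would live on a different device and the partial sums would be reduced across devices. The sketch only shows that the partial sums must be added before dividing by the total number of examples.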

The algorithm outputs two NumPy files, tridiag_1 and lanczos_vec_1, which contain the tridiagonal matrix and the Lanczos vectors, respectively. The tridiagonal matrix can then be used to generate spectral densities using tridiag_to_density.
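As a rough illustration of how a tridiagonal matrix turns into a density, the sketch below applies the standard Gauss quadrature rule: the eigenvalues of the tridiagonal matrix are the nodes, and the squared first components of its eigenvectors are the weights. The function name density_from_tridiag, its arguments, and the hand-written 3x3 matrix are assumptions for this example. It is not the signature or the exact behaviour of the repository's tridiag_to_density.

```python
import numpy as np

def density_from_tridiag(tridiag, grid, sigma):
    """Gaussian-smoothed spectral density from one tridiagonal matrix."""
    nodes, vecs = np.linalg.eigh(tridiag)
    # Quadrature weights: squared first component of each eigenvector.
    weights = vecs[0, :] ** 2
    # Place a Gaussian bump of width sigma at every node.
    diffs = grid[:, None] - nodes[None, :]
    kernels = np.exp(-0.5 * (diffs / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
    return kernels @ weights, nodes, weights

# Hypothetical tridiagonal matrix standing in for a saved Lanczos result.
tridiag = np.array([[2.0, 1.0, 0.0],
                    [1.0, 2.0, 1.0],
                    [0.0, 1.0, 2.0]])
grid = np.linspace(-2.0, 6.0, 801)
density, nodes, weights = density_from_tridiag(tridiag, grid, sigma=0.1)
```

The weights sum to one because the eigenvectors are orthonormal, so the smoothed density integrates to approximately one over a grid that covers all nodes. The kernel width sigma is a free choice that trades resolution for smoothness.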

Jax Implementation (by Justin Gilmer)

The Jax version is fantastic for fast experimentation (especially in conjunction with trax). The Jupyter notebook demonstrates how to run Lanczos in Jax.

The main function is lanczos_alg, which returns a tridiagonal matrix and Lanczos vectors. The tridiagonal matrix can then be used to generate spectral densities using tridiag_to_density.
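For readers unfamiliar with the recurrence, the following is a minimal NumPy sketch of Lanczos with full reorthogonalization. It is not the repository's lanczos_alg: the function name lanczos, its arguments, and the small random symmetric matrix standing in for a Hessian are assumptions made for this example.

```python
import numpy as np

def lanczos(matvec, dim, num_steps, seed=0):
    """Return (tridiag, vecs) after num_steps Lanczos iterations.

    matvec: function mapping a vector to (symmetric matrix) @ vector.
    """
    rng = np.random.default_rng(seed)
    v = rng.normal(size=dim)
    v /= np.linalg.norm(v)
    v_prev = np.zeros(dim)
    beta = 0.0
    vecs = np.zeros((num_steps, dim))
    tridiag = np.zeros((num_steps, num_steps))
    for i in range(num_steps):
        vecs[i] = v
        w = matvec(v) - beta * v_prev
        alpha = w @ v
        tridiag[i, i] = alpha
        w -= alpha * v
        # Full reorthogonalization against all stored Lanczos vectors.
        w -= vecs[: i + 1].T @ (vecs[: i + 1] @ w)
        beta = np.linalg.norm(w)
        if i + 1 < num_steps:
            tridiag[i, i + 1] = tridiag[i + 1, i] = beta
            # No guard here: a beta close to zero would produce inf or NaN.
            v_prev, v = v, w / beta
    return tridiag, vecs

rng = np.random.default_rng(1)
b = rng.normal(size=(6, 6))
a = b + b.T  # symmetric stand-in for a Hessian
tridiag, vecs = lanczos(lambda u: a @ u, dim=6, num_steps=6)
```

When the number of steps equals the dimension, as in this toy case, the eigenvalues of the tridiagonal matrix coincide with those of the original matrix. For a neural network, the matrix-vector product would be a Hessian-vector product and the number of steps would be far smaller than the parameter count.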

Differences between implementations

  1. The TensorFlow version performs Hessian-vector product accumulation and the actual Lanczos algorithm in float64, whereas the Jax version performs all calculations in float32.
  2. The TensorFlow version targets multi-worker distributed setups, whereas the Jax version targets single-worker (potentially multi-GPU) setups.
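The precision difference matters most when many terms are accumulated. The snippet below is a generic NumPy illustration of accumulation error in float32 versus float64. It does not measure either implementation, and the helper name accumulate and the input values are made up for the example.

```python
import numpy as np

def accumulate(values, dtype):
    """Sum values one at a time, rounding to dtype after every addition."""
    total = dtype(0.0)
    for value in values:
        total = dtype(total + dtype(value))
    return total

values = [0.1] * 100_000
exact = 10_000.0

err32 = abs(float(accumulate(values, np.float32)) - exact)
err64 = abs(float(accumulate(values, np.float64)) - exact)
print(f"float32 error: {err32:.3e}, float64 error: {err64:.3e}")
```

The float32 running sum drifts visibly once the total is much larger than each added term, while the float64 sum stays accurate to many more digits.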

This is not an official Google product.

spectral-density's People

Contributors

fabianp, jmgilmer, pforet, yingusxiaous

spectral-density's Issues

Version of Tensorflow

Hi,

Could you please specify which version of tf should be used, in order to run everything?

Thanks!

NaNs

Hi,

I am getting NaNs in the output for a CNN. It seems to work fine for an RNN (although the eigenvalues are extremely high; any idea why?).
Do you have an idea what could be the reason for getting NaNs?

I am using the Jax version.
