Neural Architecture Search by Preserving Gradient Flow

This repo started as a fork of Picking Winning Tickets Before Training by Preserving Gradient Flow.

  1. The code has been refactored to use Hydra for configuration. See the hydra directory for the default configuration. Any Hydra config field can be overridden from the command line.
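
As a rough illustration of what a dotted command-line override such as op_pruner.target_ratio=0.75 does to a nested config, here is a minimal standard-library sketch. It is not Hydra's implementation: apply_overrides, parse_value, and the default values are hypothetical, and the real defaults live in the hydra directory.

```python
from copy import deepcopy


def parse_value(text):
    """Interpret a raw string as bool, int, or float; otherwise keep the string."""
    if text.lower() in ("true", "false"):
        return text.lower() == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(config, overrides):
    """Return a copy of `config` with `key.path=value` overrides applied."""
    result = deepcopy(config)
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            # Walk down the nested dicts, creating levels that do not exist yet.
            node = node.setdefault(name, {})
        node[leaf] = parse_value(raw_value)
    return result


# Hypothetical defaults, for illustration only.
defaults = {
    "debug": False,
    "subsample_ratio": 1.0,
    "op_pruner": {"target_ratio": 0.5},
}
config = apply_overrides(defaults, ["op_pruner.target_ratio=0.75", "debug=True"])
print(config)
```

The sketch copies the defaults before modifying them, so the original dictionary is left unchanged. Hydra itself also handles config groups (for example dataset=cifar10 selects a whole config file), which this sketch does not model.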

Requirements

python 3.6, graphviz

conda create --name grasp-nas-env python=3.6
conda activate grasp-nas-env
sudo apt update
sudo apt install graphviz
pip install -r requirements.txt
git clone https://github.com/samuelstanton/upcycle.git
pip install -e upcycle
git clone ssh://git.amazon.com/pkg/Grasp-nas
cd Grasp-nas
pip install -e .

Quickstart

# CIFAR-10, GraphNet32.75.50 (75% operation pruning ratio, 50% weight pruning ratio)
$ python scripts/image_classification.py dataset=cifar10 network=graphnet op_pruner.target_ratio=0.75 weight_pruner.target_ratio=0.5

# CIFAR-100, ResNet32.90 (90% weight pruning ratio)
$ python scripts/image_classification.py dataset=cifar100 network=resnet weight_pruner.target_ratio=0.90

For the default behavior of all experiments, please refer to the hydra directory. Use command-line overrides debug=True, subsample_ratio=0.125, and num_train_epochs=2 for quick testing.

Loggers

By default, your Hydra config and all outputs, including dataframes and checkpoints, are logged locally in config.log_dir. If you have the AWS CLI configured, you can instead use logger=s3 to save your results to s3://${logger.params.bucket_name}/${logger.params.log_dir}.

Datasets

  1. By default, datasets are expected to be in ./data (e.g. ./data/CIFAR10, ./data/CIFAR100). You can override the default dataset directory using dataset.dataset_dir.
  2. CIFAR-10 & CIFAR-100 will automatically be downloaded if not present.
  3. Download Tiny ImageNet from "https://tiny-imagenet.herokuapp.com" and place it in ./data/TinyImageNet. Make sure there are two folders, train and val, under ./data/TinyImageNet. Each of train and val should contain 200 folders, one per category, holding that category's images. Alternatively, you can download the processed data from here.
  4. MNIST is not currently supported, since it has single-channel images (TODO)
  5. ImageNet can no longer be downloaded automatically through PyTorch, so it must be downloaded manually (TODO add links)
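
The Tiny ImageNet layout described above (train and val folders, each holding 200 category folders) can be checked before launching a run. The following is a minimal sketch using only the standard library; check_tiny_imagenet is a hypothetical helper and is not part of this repo.

```python
from pathlib import Path


def check_tiny_imagenet(root, expected_classes=200):
    """Return a list of layout problems under `root`; an empty list means it looks right."""
    problems = []
    root = Path(root)
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing folder: {split_dir}")
            continue
        # Count only sub-directories; stray files are ignored.
        num_classes = sum(1 for entry in split_dir.iterdir() if entry.is_dir())
        if num_classes != expected_classes:
            problems.append(
                f"{split_dir}: found {num_classes} class folders, expected {expected_classes}"
            )
    return problems


if __name__ == "__main__":
    # Assumes the default dataset directory; adjust if dataset.dataset_dir is overridden.
    for problem in check_tiny_imagenet("./data/TinyImageNet"):
        print(problem)
```

The check only counts category folders. It does not verify that each folder actually contains images.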

Networks

Options for network

  1. VGG (TODO add ref)
  2. ResNet (TODO add ref)
  3. GraphNet

Pruners

Options for op_pruner.type

  1. grasp_val (rank operations by GraSP score value, prune largest target_percent)
  2. grasp_mag (rank operations by GraSP score magnitude, prune smallest target_percent)
  3. weight_mag (rank operations by weight magnitude, prune smallest target_percent)
  4. random (prune random target_percent of operations)

Options for weight_pruner.type

  1. grasp_val (rank weights by GraSP score value, prune largest target_percent)
  2. grasp_mag (rank weights by GraSP score magnitude, prune smallest target_percent)
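
The two ranking rules above differ only in what is sorted and which end is pruned. The sketch below illustrates that selection step on made-up scores. select_pruned is a hypothetical helper, not the repo's pruner. It assumes the target is a fraction between 0 and 1 and that the pruned count is rounded down. How the GraSP scores themselves are computed is out of scope here.

```python
def select_pruned(scores, target_ratio, mode):
    """Return the sorted indices to prune from `scores`.

    mode="value":     prune the largest scores (as grasp_val is described).
    mode="magnitude": prune the smallest |score| (as grasp_mag and weight_mag are described).
    """
    num_pruned = int(len(scores) * target_ratio)  # assumption: round down
    indices = range(len(scores))
    if mode == "value":
        order = sorted(indices, key=lambda i: scores[i], reverse=True)
    elif mode == "magnitude":
        order = sorted(indices, key=lambda i: abs(scores[i]))
    else:
        raise ValueError(f"unknown mode: {mode}")
    return sorted(order[:num_pruned])


scores = [0.9, -0.1, 0.05, -0.7]  # made-up scores for four operations or weights
print(select_pruned(scores, 0.5, "value"))      # [0, 2]: the two largest values
print(select_pruned(scores, 0.5, "magnitude"))  # [1, 2]: the two smallest magnitudes
```

Note that the two modes can disagree sharply: index 3 has a large magnitude but the lowest value, so it survives under both rules here, while index 0 is pruned by value and kept by magnitude.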

Citation

To cite this work, please use

@inproceedings{
Wang2020Picking,
title={Picking Winning Tickets Before Training by Preserving Gradient Flow},
author={Chaoqi Wang and Guodong Zhang and Roger Grosse},
booktitle={International Conference on Learning Representations},
year={2020},
url={https://openreview.net/forum?id=SkgsACVKPH}
}
