This repo started as a fork of the code for Picking Winning Tickets Before Training by Preserving Gradient Flow (GraSP, Wang et al., ICLR 2020).
- The code has been refactored to use Hydra for configuration. See the `hydra` directory for all default behavior. Any Hydra config field can be overridden from the command line.
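As a rough illustration of what a dotted override such as `op_pruner.target_ratio=0.75` does, the sketch below applies `key=value` strings to a nested Python dict. It is a simplified stand-in, not Hydra's implementation: Hydra has its own override grammar and validates keys against the composed config. The helper names and the `defaults` values are hypothetical; the real defaults live in the `hydra` directory.

```python
import ast
from typing import Any, Dict, List


def _parse_value(text: str) -> Any:
    """Best-effort literal parsing: '0.75' -> 0.75, 'True' -> True, else the raw string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text  # e.g. 'cifar10' or 'graphnet' stays a string


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Apply 'a.b.c=value' overrides in place to a nested dict (simplified, Hydra-like)."""
    for override in overrides:
        dotted_key, raw_value = override.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = config
        for part in parents:
            if part not in node:
                raise KeyError(f"unknown config group: {dotted_key}")
            node = node[part]
        # Like Hydra, refuse to silently add fields that are not in the config.
        if leaf not in node:
            raise KeyError(f"unknown config field: {dotted_key}")
        node[leaf] = _parse_value(raw_value)
    return config


# Hypothetical defaults, for illustration only.
defaults = {
    "dataset": "cifar10",
    "network": "resnet",
    "op_pruner": {"type": "grasp_mag", "target_ratio": 0.0},
    "weight_pruner": {"type": "grasp_mag", "target_ratio": 0.0},
}

cfg = apply_overrides(defaults, ["network=graphnet", "op_pruner.target_ratio=0.75"])
print(cfg["network"], cfg["op_pruner"]["target_ratio"])  # graphnet 0.75
```

Later overrides of the same key win, which matches how repeated command-line overrides usually behave.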
Requirements: Python 3.6 and graphviz.
conda create --name grasp-nas-env python=3.6
conda activate grasp-nas-env
sudo apt update
sudo apt install graphviz
pip install -r requirements.txt
git clone https://github.com/samuelstanton/upcycle.git
pip install -e upcycle
git clone ssh://git.amazon.com/pkg/Grasp-nas
cd Grasp-nas
pip install -e .
# CIFAR-10, GraphNet32.75.50 (75% operation pruning ratio, 50% weight pruning ratio)
$ python scripts/image_classification.py dataset=cifar10 network=graphnet op_pruner.target_ratio=0.75 weight_pruner.target_ratio=0.5
# CIFAR-100, ResNet32.90 (90% weight pruning ratio)
$ python scripts/image_classification.py dataset=cifar100 network=resnet weight_pruner.target_ratio=0.90
For the default behavior of all experiments, please refer to the `hydra` directory. For quick testing, use the command-line overrides `debug=True`, `subsample_ratio=0.125`, and `num_train_epochs=2`.
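The quick-test overrides can be appended to any of the example commands. The sketch below builds such a command as an argument list suitable for `subprocess.run`; the helper name `quick_test_command` is hypothetical, and only the script path and override names come from this README.

```python
import sys
from typing import List

# Overrides recommended above for fast smoke tests.
QUICK_TEST_OVERRIDES = ["debug=True", "subsample_ratio=0.125", "num_train_epochs=2"]


def quick_test_command(*overrides: str) -> List[str]:
    """Build a quick-test invocation of the training script (hypothetical helper).

    The quick-test overrides go last so they take precedence over earlier ones.
    """
    return [sys.executable, "scripts/image_classification.py", *overrides, *QUICK_TEST_OVERRIDES]


cmd = quick_test_command("dataset=cifar100", "network=resnet", "weight_pruner.target_ratio=0.90")
print(" ".join(cmd[1:]))
# To launch it from the repo root:
#   import subprocess; subprocess.run(cmd, check=True)
```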
By default, your Hydra config and all outputs, including dataframes and checkpoints, are logged locally in `config.log_dir`. If you have the AWS CLI configured, you can instead use `logger=s3` to save your results to `s3://${logger.params.bucket_name}/${logger.params.log_dir}`.
- By default, datasets are expected to be in `./data` (e.g. `./data/CIFAR10`, `./data/CIFAR100`). You can override the default dataset directory using `dataset.dataset_dir`.
- CIFAR-10 and CIFAR-100 will be downloaded automatically if not present.
- Download Tiny ImageNet from "https://tiny-imagenet.herokuapp.com" and place it in `./data/TinyImageNet`. Make sure there are two folders, `train` and `val`, under `./data/TinyImageNet`. Each of `train` and `val` should contain 200 folders, one per category, holding that category's images. Alternatively, you can download the processed data from here.
- MNIST is not currently supported because its images are single-channel (TODO).
- ImageNet can no longer be downloaded automatically through PyTorch, so it must also be downloaded manually (TODO: add links).
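Before training on Tiny ImageNet, it can help to verify the folder layout described above. The sketch below assumes the default `./data/TinyImageNet` location and uses a hypothetical helper name; it checks that `train` and `val` both exist and that each contains 200 class subdirectories. It inspects only the directory structure, not the images themselves.

```python
from pathlib import Path
from typing import List

EXPECTED_NUM_CLASSES = 200  # Tiny ImageNet has 200 categories


def check_tiny_imagenet_layout(root: str = "./data/TinyImageNet") -> List[str]:
    """Return a list of layout problems; an empty list means the layout looks right."""
    problems = []
    root_path = Path(root)
    for split in ("train", "val"):
        split_dir = root_path / split
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split_dir}")
            continue
        # Count only subdirectories (one per category), ignoring stray files.
        class_dirs = [p for p in split_dir.iterdir() if p.is_dir()]
        if len(class_dirs) != EXPECTED_NUM_CLASSES:
            problems.append(
                f"{split_dir} has {len(class_dirs)} class folders, expected {EXPECTED_NUM_CLASSES}"
            )
    return problems


for problem in check_tiny_imagenet_layout():
    print(problem)
```

Pass `dataset.dataset_dir`'s value plus `/TinyImageNet` as `root` if you moved the data elsewhere.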
Options for `network`:
- VGG (TODO add ref)
- ResNet (TODO add ref)
- GraphNet
Options for `op_pruner.type`:
- `grasp_val` (rank operations by GraSP score value, prune the largest `target_percent`)
- `grasp_mag` (rank operations by GraSP score magnitude, prune the smallest `target_percent`)
- `weight_mag` (rank operations by weight magnitude, prune the smallest `target_percent`)
- `random` (prune a random `target_percent` of operations)
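The selection logic behind these four rules can be sketched in plain Python. This illustrates only the ranking step, assuming each operation already has a single scalar score (a GraSP score, or a weight magnitude for `weight_mag`); how the repo aggregates scores per operation is not shown. The function name, the `op_scores` dictionary, and its numbers are hypothetical, and `target` plays the role of `target_percent`, expressed as a fraction.

```python
import random
from typing import Dict, List, Optional


def select_ops_to_prune(
    scores: Dict[str, float], rule: str, target: float, seed: Optional[int] = None
) -> List[str]:
    """Return the names of the operations to prune under the given rule.

    `target` is the fraction of operations to prune (0.75 prunes 75% of them).
    """
    num_prune = int(round(target * len(scores)))
    names = list(scores)
    if rule == "grasp_val":
        # Largest raw GraSP score first; the head of this ranking is pruned.
        ranked = sorted(names, key=lambda n: scores[n], reverse=True)
    elif rule in ("grasp_mag", "weight_mag"):
        # Smallest magnitude first; the head of this ranking is pruned.
        ranked = sorted(names, key=lambda n: abs(scores[n]))
    elif rule == "random":
        # Uniformly random order, reproducible when a seed is given.
        ranked = random.Random(seed).sample(names, len(names))
    else:
        raise ValueError(f"unknown op_pruner.type: {rule}")
    return ranked[:num_prune]


op_scores = {"conv3x3": 0.8, "conv5x5": -1.2, "maxpool": 0.1, "skip": -0.05}
print(select_ops_to_prune(op_scores, "grasp_val", 0.5))  # ['conv3x3', 'maxpool']
print(select_ops_to_prune(op_scores, "grasp_mag", 0.5))  # ['skip', 'maxpool']
```

Note that `grasp_val` and `grasp_mag` can disagree sharply: `conv5x5` has the largest magnitude but the most negative value, so neither rule prunes it here.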
Options for `weight_pruner.type`:
- `grasp_val` (rank weights by GraSP score value, prune the largest `target_percent`)
- `grasp_mag` (rank weights by GraSP score magnitude, prune the smallest `target_percent`)
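At the weight level, the same two GraSP rules produce a binary keep-mask over individual weights. The sketch below assumes the per-weight GraSP scores have already been computed and flattened into one list (computing them is described in the GraSP paper and is not reproduced here); `weight_keep_mask` and the example scores are hypothetical. Because Python's sort is stable, ties are broken by position, and exactly `round(target * n)` weights are pruned.

```python
from typing import List


def weight_keep_mask(scores: List[float], rule: str, target: float) -> List[bool]:
    """Return a keep-mask (True = keep) that prunes a `target` fraction of weights."""
    n = len(scores)
    num_prune = int(round(target * n))
    if rule == "grasp_val":
        # Prune the weights with the largest GraSP score values.
        order = sorted(range(n), key=lambda i: scores[i], reverse=True)
    elif rule == "grasp_mag":
        # Prune the weights with the smallest GraSP score magnitudes.
        order = sorted(range(n), key=lambda i: abs(scores[i]))
    else:
        raise ValueError(f"unknown weight_pruner.type: {rule}")
    pruned = set(order[:num_prune])
    return [i not in pruned for i in range(n)]


grasp_scores = [0.3, -0.9, 0.05, 1.4, -0.2, 0.0]
# A 50% weight pruning ratio, as in weight_pruner.target_ratio=0.5.
print(weight_keep_mask(grasp_scores, "grasp_val", 0.5))  # [False, True, False, False, True, True]
print(weight_keep_mask(grasp_scores, "grasp_mag", 0.5))  # [True, True, False, True, False, False]
```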
To cite this work, please use
@inproceedings{
Wang2020Picking,
title={Picking Winning Tickets Before Training by Preserving Gradient Flow},
author={Chaoqi Wang and Guodong Zhang and Roger Grosse},
booktitle={International Conference on Learning Representations},
year={2020},
url={https://openreview.net/forum?id=SkgsACVKPH}
}