randwire_tensorflow's Introduction

RandWire_tensorflow

TensorFlow implementation of "Exploring Randomly Wired Neural Networks for Image Recognition", trained on Cifar10 and MNIST.

Requirements

  • Tensorflow 1.x - GPU version recommended
  • Python 3.x
  • networkx 2.x
  • pyyaml 5.x

Dataset

Please download the dataset from this link. Both the Cifar10 and MNIST datasets are converted into tfrecords format for convenience. Put the train.tfrecords and test.tfrecords files into dataset/cifar10 and dataset/mnist.

You can create a tfrecord file from your own dataset with dataset/dataset_generator.py.

python dataset_generator.py --image_dir ./cifar10/test/images --label_dir ./cifar10/test/labels --output_dir ./cifar10 --output_filename test.tfrecord

Options:

  • --image_dir (str) - directory of your image files. It is recommended to name the images with integers, like 0.png.
  • --label_dir (str) - directory of your label files. It is recommended to name the label files with integers, like 0.txt. Each label text file must contain the class label as an integer, like 8.
  • --output_dir (str) - directory for the output tfrecord file.
  • --output_filename (str) - filename of the output tfrecord file.
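
The naming convention described in the options (0.png paired with 0.txt, each label file holding one integer) can be illustrated with a short sketch. This is not the code in dataset_generator.py; pair_images_and_labels is a hypothetical helper, and it assumes every image file name is an integer with a matching .txt label file.

```python
from pathlib import Path


def pair_images_and_labels(image_dir, label_dir):
    """Return (image_path, label) pairs matched by file stem.

    Hypothetical helper: assumes names like 0.png / 0.txt and that each
    label file contains a single integer class label.
    """
    pairs = []
    for image_path in Path(image_dir).iterdir():
        label_path = Path(label_dir) / (image_path.stem + '.txt')
        if not label_path.is_file():
            raise FileNotFoundError('no label file for ' + image_path.name)
        label = int(label_path.read_text().strip())
        pairs.append((image_path, label))
    # Sort by the integer value of the name so that 10.png follows 2.png.
    pairs.sort(key=lambda pair: int(pair[0].stem))
    return pairs
```

Integer file names make the pairing and the ordering unambiguous, which is presumably why the options recommend them.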

Experiments

Dataset    Model                       Parameters  Accuracy  Epochs
CIFAR-10   ResNet110 (Paper)           1.7M        93.57%    300
CIFAR-10   RandWire (my_small_regime)  1.2M        93.64%    60
CIFAR-100  RandWire (my_regime)        8M          74.49%    100

(19.04.18 changed) I trained on the Cifar10 dataset and got 6.36% error on the test set. You can download the pretrained network from here. Unzip the file, move all files into the checkpoint directory (or your own checkpoint directory), and run the test script to check the accuracy. The Cifar10 model uses about 1.2M parameters and reaches a result similar to ResNet-110 (6.43% error), which uses 1.7M parameters.

(19.04.16 added) I trained on the Cifar100 dataset and got 74.49% accuracy on the test set. You can download the pretrained network from the same link above.

Training

Cifar 10

python train.py --class_num 10 --image_shape 32 32 3 --stages 4 --channel_count 78 --graph_model ws --graph_param 32 4 0.75 --dropout_rate 0.2 --learning_rate 0.1 --momentum 0.9 --weight_decay 0.0001 --train_set_size 50000 --val_set_size 10000 --batch_size 100 --epochs 100 --checkpoint_dir ./checkpoint --checkpoint_name randwire_cifar10 --train_record_dir ./dataset/cifar10/train.tfrecord --val_record_dir ./dataset/cifar10/test.tfrecord

Options:

  • --class_num (int) - number of output classes. Cifar10 has 10 classes.
  • --image_shape (int nargs) - shape of the input image. Cifar10 images have shape 32 32 3.
  • --stages (int) - number of stages (or blocks) in the randwire network.
  • --channel_count (int) - channel count of the randwire network. Please refer to the paper.
  • --graph_model (str) - randwire currently has 3 random graph models. You can choose from 'er', 'ba' and 'ws'.
  • --graph_param (float nargs) - the first value is the node count. For 'er' and 'ba' there is one extra parameter, so it would be like 32 0.4 or 32 7. For 'ws' there are two extra parameters, as in the command above.
  • --dropout_rate (float) - dropout rate. Set it to 0.0 to disable dropout.
  • --learning_rate (float) - initial learning rate.
  • --momentum (float) - momentum for the momentum optimizer.
  • --weight_decay (float) - weight decay factor.
  • --train_set_size (int) - number of training samples. Cifar10 has 50000.
  • --val_set_size (int) - number of validation samples. I used the test data for validation, so there are 10000.
  • --batch_size (int) - size of a mini batch.
  • --epochs (int) - number of epochs.
  • --checkpoint_dir (str) - directory to save checkpoints.
  • --checkpoint_name (str) - file name of the checkpoint.
  • --train_record_dir (str) - file location of the training set tfrecord.
  • --val_record_dir (str) - file location of the test set tfrecord (used for validation).

MNIST

python train.py --class_num 10 --image_shape 28 28 1 --stages 4 --channel_count 78 --graph_model ws --graph_param 32 4 0.75 --dropout_rate 0.2 --learning_rate 0.1 --momentum 0.9 --weight_decay 0.0001 --train_set_size 50000 --val_set_size 10000 --batch_size 100 --epochs 100 --checkpoint_dir ./checkpoint --checkpoint_name randwire_mnist --train_record_dir ./dataset/mnist/train.tfrecord --val_record_dir ./dataset/mnist/test.tfrecord

Options:

  • Options are the same as for Cifar10.

Cifar100 (19.04.16 added)

python train.py --class_num 100 --image_shape 32 32 3 --stages 4 --channel_count 78 --graph_model ws --graph_param 32 4 0.75 --dropout_rate 0.2 --learning_rate 0.1 --momentum 0.9 --weight_decay 0.0001 --train_set_size 50000 --val_set_size 10000 --batch_size 100 --epochs 100 --checkpoint_dir ./checkpoint --checkpoint_name randwire_cifar100 --train_record_dir ./dataset/cifar100/train.tfrecord --val_record_dir ./dataset/cifar100/test.tfrecord

Options:

  • Options are the same as for Cifar10.

Testing

python test.py --class_num 10 --checkpoint_dir ./checkpoint/best --test_record_dir ./dataset/cifar10/test.tfrecord --batch_size 256

Options:

  • --class_num (int) - the number of classes.
  • --checkpoint_dir (str) - directory of the checkpoint you want to load and test.
  • --test_record_dir (str) - file location of the test set tfrecord.
  • --batch_size (int) - batch size for testing.

test.py loads the network graph and tensors from the meta data and evaluates the model.

Implementation Details

  • The learning rate is multiplied by 0.1 at 50% and again at 75% of the total training steps.

  • I added an init_subsample option to my_regime, my_small_regime and small_regime in RandWire.py to avoid stride 2 in the initial convolutional layer, since Cifar10 and MNIST images have low resolution. If you set init_subsample to False, stride 2 is used.

  • While training, it saves the checkpoint with the best validation accuracy.

  • While training, it saves a tensorboard log of the training and validation accuracy and loss in [YOUR_CHECKPOINT_DIRECTORY]/log. You can visualize it yourself with tensorboard.

  • I'm currently working on drop connection for regularization and downloading the ImageNet dataset to train my implementation on it.

  • I added a dropout layer after the ReLU-Conv-BN triplet unit for regularization. You can set dropout_rate to 0.0 to disable it.

  • In train.py, you can use small_regime or regular_regime instead of my_regime and my_small_regime. Both avoid stride 2 to prevent subsampling and preserve spatial information, since the images in the Cifar datasets are small.

  # output logit from NN
  output = RandWire.my_small_regime(images, args.stages, args.channel_count, args.class_num, args.dropout_rate,
                              args.graph_model, args.graph_param, args.checkpoint_dir + '/' + 'graphs', False, training)
  # output = RandWire.small_regime(images, args.stages, args.channel_count, args.class_num, args.dropout_rate,
  #                             args.graph_model, args.graph_param, args.checkpoint_dir + '/' + 'graphs', False,
  #                             training)
  # output = RandWire.regular_regime(images, args.stages, args.channel_count, args.class_num, args.dropout_rate,
  #                             args.graph_model, args.graph_param, args.checkpoint_dir + '/' + 'graphs', training)
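
The learning rate schedule described in the first item above can be sketched in plain Python. This is a minimal illustration and not the code in train.py: learning_rate_at is a hypothetical helper, and it assumes the boundaries are counted in training steps, with the step count derived from the values in the Cifar10 training command.

```python
def learning_rate_at(step, total_steps, initial_lr=0.1, decay=0.1):
    """Piecewise-constant schedule: decay at 50% and at 75% of training.

    Hypothetical helper; assumes boundaries are measured in training steps.
    """
    first_boundary = total_steps // 2
    second_boundary = total_steps * 3 // 4
    if step < first_boundary:
        return initial_lr
    if step < second_boundary:
        return initial_lr * decay
    return initial_lr * decay * decay


# Values from the Cifar10 command: 50000 images, batch size 100, 100 epochs.
steps_per_epoch = 50000 // 100
total_steps = steps_per_epoch * 100

# Print the rate just before and just after each boundary.
for step in (0, 24999, 25000, 37499, 37500, total_steps - 1):
    print(step, learning_rate_at(step, total_steps))
```

With --learning_rate 0.1 the rate moves from 0.1 to roughly 0.01 and then to roughly 0.001, up to floating-point rounding.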

randwire_tensorflow's People

Contributors

plemeri

randwire_tensorflow's Issues

Environment issues

Hi, could you tell me which requirements you used specifically, especially the TensorFlow version? Thanks a lot.

ba and ws don't work in latest env

er works well

> python train.py --class_num 10 --image_shape 32 32 3 --stages 4 --channel_count 78 --graph_model er --graph_param 32 0.2 --dropout_rate 0.2 --learning_rate 0.1 --momentum 0.9 --weight_decay 0.0001 --train_set_size 50000 --val_set_size 10000 --batch_size 100 --epochs 100 --checkpoint_dir ./checkpoint --checkpoint_name randwire_cifar10 --train_record_dir ./dataset/cifar10/train.tfrecord --val_record_dir ./dataset/cifar10/test.tfrecord

my env is

(venv) RandWire_tensorflow(master)> pip list
Package               Version  
--------------------- ---------
absl-py               0.7.1    
apparmor              2.13.1   
asn1crypto            0.24.0   
astor                 0.7.1    
Brlapi                0.6.7    
certifi               2018.1.18
chardet               3.0.4    
cryptography          2.3      
cupshelpers           1.0      
cycler                0.10.0   
decorator             4.4.0    
Django                2.2.1    
gast                  0.2.2    
glinux-homedir-helper 1        
glinux-rebootd        0.1      
goobuntu-config-tools 0.1      
gpg                   1.12.0   
grpcio                1.20.1   
h5py                  2.9.0    
idna                  2.6      
image                 1.5.27   
IPy                   0.83     
Keras-Applications    1.0.7    
Keras-Preprocessing   1.0.9    
keyring               10.6.0   
keyrings.alt          3.0      
kiwisolver            1.1.0    
LibAppArmor           2.13.1   
louis                 3.7.0    
Markdown              3.1      
matplotlib            3.0.3    
meld                  3.18.0   
mock                  3.0.3    
monotonic             1.0      
networkx              2.3      
numpy                 1.16.3   
obno                  39       
onboard               1.4.1    
Pillow                6.0.0    
pip                   19.1     
protobuf              3.7.1    
psutil                5.4.2    
pycairo               1.16.2   
pycrypto              2.6.1    
pycups                1.9.73   
pycurl                7.43.0.2 
pygobject             3.26.1   
pyinotify             0.9.6    
pyOpenSSL             17.5.0   
pyparsing             2.4.0    
pysmbc                1.0.15.6 
python-apt            1.8.3    
python-dateutil       2.8.0    
python-debian         0.1.32   
python-xapp           1.0.0    
python-xlib           0.20     
pytz                  2019.1   
pyxattr               0.6.0    
pyxdg                 0.25     
PyYAML                5.1      
requests              2.20.0   
scour                 0.36     
SecretStorage         2.3.1    
setproctitle          1.1.10   
setuptools            41.0.1   
six                   1.12.0   
sqlparse              0.3.0    
tensorboard           1.13.1   
tensorflow            1.13.1   
tensorflow-estimator  1.13.0   
termcolor             1.1.0    
ufw                   0.35     
urllib3               1.24     
virtualenv            16.5.0   
Werkzeug              0.15.2   
wheel                 0.33.1   
youtube-dl            2018.4.25

other graph generators don't work.

> python train.py --class_num 10 --image_shape 32 32 3 --stages 4 --channel_count 78 --graph_model ws --graph_param 32 4 0.75 --dropout_rate 0.2 --learning_rate 0.1 --momentum 0.9 --weight_decay 0.0001 --train_set_size 50000 --val_set_size 10000 --batch_size 100 --epochs 100 --checkpoint_dir ./checkpoint --checkpoint_name randwire_cifar10 --train_record_dir ./dataset/cifar10/train.tfrecord --val_record_dir ./dataset/cifar10/test.tfrecord
+++++++++++++++++++++++++++++++++++++++++++++++++
[Input Arguments]
class_num -> 10
image_shape -> [32, 32, 3]
stages -> 4
channel_count -> 78
graph_model -> ws
graph_param -> [32.0, 4.0, 0.75]
dropout_rate -> 0.2
learning_rate -> 0.1
momentum -> 0.9
weight_decay -> 0.0001
train_set_size -> 50000
val_set_size -> 10000
batch_size -> 100
epochs -> 100
checkpoint_dir -> ./checkpoint
checkpoint_name -> randwire_cifar10
train_record_dir -> ./dataset/cifar10/train.tfrecord
val_record_dir -> ./dataset/cifar10/test.tfrecord
+++++++++++++++++++++++++++++++++++++++++++++++++
WARNING:tensorflow:From /usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From /usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/network/RandWire.py:120: separable_conv2d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.separable_conv2d instead.
WARNING:tensorflow:From /usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/network/RandWire.py:121: batch_normalization (from tensorflow.python.layers.normalization) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.batch_normalization instead.
WARNING:tensorflow:From /usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/network/RandWire.py:13: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.dropout instead.
WARNING:tensorflow:From /usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/tensorflow/python/keras/layers/core.py:143: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Traceback (most recent call last):
  File "train.py", line 170, in <module>
    main(args)
  File "train.py", line 57, in main
    args.graph_model, args.graph_param, args.checkpoint_dir + '/' + 'graphs', False, training)
  File "/usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/network/RandWire.py", line 126, in my_small_regime
    graph_data = gg.graph_generator(graph_model, graph_param, graph_file_path, 'conv' + str(stage) + '_' + graph_model)
  File "/usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/utils/graph_generator.py", line 8, in graph_generator
    graph = nx.random_graphs.connected_watts_strogatz_graph(*graph_param)
  File "</usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/decorator.py:decorator-gen-548>", line 2, in connected_watts_strogatz_graph
  File "/usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/networkx/utils/decorators.py", line 464, in _random_state
    return func(*new_args, **kwargs)
  File "/usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/networkx/generators/random_graphs.py", line 480, in connected_watts_strogatz_graph
    G = watts_strogatz_graph(n, k, p, seed)
  File "</usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/decorator.py:decorator-gen-546>", line 2, in watts_strogatz_graph
  File "/usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/networkx/utils/decorators.py", line 464, in _random_state
    return func(*new_args, **kwargs)
  File "/usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/networkx/generators/random_graphs.py", line 411, in watts_strogatz_graph
    for j in range(1, k // 2 + 1):
TypeError: 'float' object cannot be interpreted as an integer

> python train.py --class_num 10 --image_shape 32 32 3 --stages 4 --channel_count 78 --graph_model ba --graph_param 32 5 --dropout_rate 0.2 --learning_rate 0.1 --momentum 0.9 --weight_decay 0.0001 --train_set_size 50000 --val_set_size 10000 --batch_size 100 --epochs 100 --checkpoint_dir ./checkpoint --checkpoint_name randwire_cifar10 --train_record_dir ./dataset/cifar10/train.tfrecord --val_record_dir ./dataset/cifar10/test.tfrecord
+++++++++++++++++++++++++++++++++++++++++++++++++
[Input Arguments]
class_num -> 10
image_shape -> [32, 32, 3]
stages -> 4
channel_count -> 78
graph_model -> ba
graph_param -> [32.0, 5.0]
dropout_rate -> 0.2
learning_rate -> 0.1
momentum -> 0.9
weight_decay -> 0.0001
train_set_size -> 50000
val_set_size -> 10000
batch_size -> 100
epochs -> 100
checkpoint_dir -> ./checkpoint
checkpoint_name -> randwire_cifar10
train_record_dir -> ./dataset/cifar10/train.tfrecord
val_record_dir -> ./dataset/cifar10/test.tfrecord
+++++++++++++++++++++++++++++++++++++++++++++++++
WARNING:tensorflow:From /usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From /usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/network/RandWire.py:120: separable_conv2d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.separable_conv2d instead.
WARNING:tensorflow:From /usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/network/RandWire.py:121: batch_normalization (from tensorflow.python.layers.normalization) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.batch_normalization instead.
WARNING:tensorflow:From /usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/network/RandWire.py:13: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.dropout instead.
WARNING:tensorflow:From /usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/tensorflow/python/keras/layers/core.py:143: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Traceback (most recent call last):
  File "/usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/networkx/utils/decorators.py", line 295, in _nodes_or_number
    nodes = list(range(n))
TypeError: 'float' object cannot be interpreted as an integer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "train.py", line 170, in <module>
    main(args)
  File "train.py", line 57, in main
    args.graph_model, args.graph_param, args.checkpoint_dir + '/' + 'graphs', False, training)
  File "/usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/network/RandWire.py", line 126, in my_small_regime
    graph_data = gg.graph_generator(graph_model, graph_param, graph_file_path, 'conv' + str(stage) + '_' + graph_model)
  File "/usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/utils/graph_generator.py", line 12, in graph_generator
    graph = nx.random_graphs.barabasi_albert_graph(*graph_param)
  File "</usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/decorator.py:decorator-gen-552>", line 2, in barabasi_albert_graph
  File "/usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/networkx/utils/decorators.py", line 464, in _random_state
    return func(*new_args, **kwargs)
  File "/usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/networkx/generators/random_graphs.py", line 649, in barabasi_albert_graph
    G = empty_graph(m)
  File "</usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/decorator.py:decorator-gen-22>", line 2, in empty_graph
  File "/usr/local/google/home/dongseong/workspaces/RandWire_tensorflow/venv/lib/python3.6/site-packages/networkx/utils/decorators.py", line 297, in _nodes_or_number
    nodes = tuple(n)
TypeError: 'float' object is not iterable
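
Both tracebacks end in a TypeError raised inside networkx because a float from --graph_param (32.0, 4.0 or 5.0 in the printed arguments) reached a parameter that must be an integer. A minimal sketch of one possible workaround follows, assuming the values are converted before the generator is called. cast_graph_param and _as_int are hypothetical helpers and are not part of the repository.

```python
def _as_int(name, value):
    """Convert a float such as 32.0 to int, rejecting values like 4.5."""
    if float(value) != int(value):
        raise ValueError(name + ' must be a whole number, got ' + repr(value))
    return int(value)


def cast_graph_param(graph_model, graph_param):
    """Return graph_param with the types each random graph model expects.

    Hypothetical helper; the parameter meanings follow the README options:
    'er' -> (node count, probability), 'ba' -> (node count, edges per node),
    'ws' -> (node count, neighbours, rewiring probability).
    """
    values = list(graph_param)
    if graph_model == 'er':
        n, p = values
        return _as_int('node count', n), float(p)
    if graph_model == 'ba':
        n, m = values
        return _as_int('node count', n), _as_int('m', m)
    if graph_model == 'ws':
        n, k, p = values
        return _as_int('node count', n), _as_int('k', k), float(p)
    raise ValueError('unknown graph model: ' + repr(graph_model))


print(cast_graph_param('ws', [32.0, 4.0, 0.75]))  # (32, 4, 0.75)
print(cast_graph_param('ba', [32.0, 5.0]))        # (32, 5)
```

The returned tuple could then be unpacked into the generator call that appears in the traceback, for example nx.random_graphs.connected_watts_strogatz_graph(*cast_graph_param('ws', graph_param)).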
