rethinking-bnn-optimization's Introduction

Rethinking Binarized Neural Network Optimization

arXiv:1906.02107 License: Apache 2.0 Code style: black

Implementation for the paper "Latent Weights Do Not Exist: Rethinking Binarized Neural Network Optimization".

A poster illustrating the proposed algorithm and its relation to the previous BNN optimization strategy is included at ./poster.pdf.

Note: Bop is now added to Larq, the open source training library for BNNs. We recommend using the Larq implementation of Bop: it is compatible with more versions of TensorFlow and will be more actively maintained.

Requirements

You can also check out one of our prebuilt Docker images.

Installation

The code is packaged as a complete Python module. To install it in your local Python environment, cd into the folder containing setup.py and run:

pip install -e .
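
Installing in editable mode makes the bnno command-line interface available on your path. A quick sanity check (standard click behaviour, so the exact help text may differ):

bnno --help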

Train

To train a model locally, you can use the CLI:

bnno train binarynet --dataset cifar10

Reproduce Paper Experiments

Hyperparameter Analysis (section 5.1)

To reproduce the runs exploring various hyperparameters, run:

bnno train binarynet \
    --dataset cifar10 \
    --preprocess-fn resize_and_flip \
    --hparams-set bop \
    --hparams threshold=1e-6,gamma=1e-3

substituting the appropriate values for threshold and gamma.
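
For intuition about what these two hyperparameters control, here is a minimal NumPy sketch of the Bop update described in the paper: gamma is the adaptivity rate of the exponential moving average of the gradients, and threshold sets how strong that averaged signal must be before a weight is flipped. This is an illustration only, not the optimizer implementation used in this repository.

import numpy as np

def bop_step(w, m, grad, gamma=1e-3, threshold=1e-6):
    # w: binary weights in {-1, +1}; m: exponential moving average of gradients.
    m = (1.0 - gamma) * m + gamma * grad
    # Flip a weight only when the accumulated gradient signal is strong enough
    # and points in the same direction as the weight itself.
    flip = (np.abs(m) > threshold) & (np.sign(m) == np.sign(w))
    return np.where(flip, -w, w), m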

CIFAR-10 (section 5.2)

To achieve the 91.3% accuracy reported in the paper, run:

bnno train binarynet \
    --dataset cifar10 \
    --preprocess-fn resize_and_flip \
    --hparams-set bop_sec52

ImageNet (section 5.3)

To reproduce the reported results on ImageNet, run:

bnno train alexnet --dataset imagenet2012 --hparams-set bop
bnno train xnornet --dataset imagenet2012 --hparams-set bop
bnno train birealnet --dataset imagenet2012 --hparams-set bop

This should give the results listed below. Follow the tensorboard links to see training and validation accuracy curves of the reported runs.

Network          Bop top-1 accuracy
Binary AlexNet   41.1%  (tensorboard)
XNOR-Net         45.9%  (tensorboard)
Bi-Real Net      56.6%  (tensorboard)

rethinking-bnn-optimization's People

Contributors

jamescook106, koenhelwegen, lgeiger


rethinking-bnn-optimization's Issues

kernel_initializer="glorot_normal"

In birealnet.py, I observe that kernel_initializer is set to "glorot_normal" for Conv2D, and in QuantConv2D the kernel_initializer is also set to "glorot_normal".
Does kernel_initializer="glorot_normal" mean the kernel weights are all set to 1?

Thank you very much.
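
For context, a note not from the original thread: "glorot_normal" does not set the weights to 1. It is the Glorot/Xavier normal initializer, which draws each weight from a truncated normal distribution with standard deviation sqrt(2 / (fan_in + fan_out)). A quick way to check, assuming TF 2.x eager execution:

import tensorflow as tf

init = tf.keras.initializers.glorot_normal()
# Sample a 3x3 kernel with 16 input and 32 output channels; the values are
# drawn around zero rather than being set to 1.
kernel = init(shape=(3, 3, 16, 32))
print(float(tf.math.reduce_std(kernel)))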

Running Nets with ImageNet2012

Hello everyone,

I've been trying to run the nets using the ImageNet dataset, but it cannot be downloaded from the default URL. I have the tars already downloaded, but I don't know how to use them. Do you know how?

Thank you
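
For context, a note not from the original thread: the imagenet2012 dataset in tensorflow_datasets, which the training CLI presumably loads data through, cannot be downloaded automatically; the tars have to be supplied manually. A sketch of the usual workflow, assuming the default tfds directory layout:

import os
import tensorflow_datasets as tfds

# Copy ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar into the manual
# download directory, then prepare the dataset once; afterwards
# --dataset imagenet2012 can find it.
manual_dir = os.path.expanduser("~/tensorflow_datasets/downloads/manual/")
builder = tfds.builder("imagenet2012")
builder.download_and_prepare(
    download_config=tfds.download.DownloadConfig(manual_dir=manual_dir)
)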

BinaryNet setting in section 5.2

Hi Koen,
Thanks for the great work!
I noticed that, to reproduce the BinaryNet on CIFAR-10 experiment of section 5.2, you use the following setting:

class bop_sec52(default):
    epochs = 500
    batch_size = 50
    kernel_quantizer = None
    kernel_constraint = None
    threshold = 1e-8
    gamma = 1e-4
    gamma_decay = 0.1
    decay_step = int((50000 / 50) * 100)

Here kernel_quantizer and kernel_constraint are set to None, so only the inputs are binarized while the weights stay real-valued. Is that expected?
I thought BinaryNet should also have binarized weights?

Best,
Junru

some questions about the paper

Thanks for your work! I have some questions about the paper.
① "For example, when using SGD and Glorot initialization [11] a learning rate of 1 performs much better than 0.01; but when we multiply the initialized weights by 0.01 before starting training, we obtain the same improvement in performance." But according to Theorem 1, the binary weights are invariant if the initial weights and the learning rate are scaled by the same scalar simultaneously, so I want to know why multiplying the initialized weights by 0.01 achieves the same improvement as using a learning rate of 1, and how decreasing the learning rate builds up inertia.
② How does the moving average filter out short-lived signals?
③ In Algorithm 2, why must the condition sign(mi) = sign(wi) hold before a weight is flipped?
@jamescook106 @lgeiger @koenhelwegen
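
On the last question, a small worked example of my own, not from the thread: gradient descent moves a weight against the gradient, so when the averaged gradient m_i shares the sign of the binary weight w_i, the descent direction points toward the opposite sign, which a binary weight can only follow by flipping; when the signs differ, the descent direction points toward the value the weight already has, so no flip is needed.

threshold = 1e-6

w, m = 1.0, 0.3
# Descent direction is -m = -0.3, i.e. toward -1, so Bop flips the weight.
assert abs(m) > threshold and (m > 0) == (w > 0)

w, m = 1.0, -0.3
# Descent direction is -m = +0.3, i.e. toward +1, which w already is: no flip.
assert not ((m > 0) == (w > 0))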

OpKernel wrong attributes

I'm trying to reproduce the results, but I'm stuck on the exception seen below. Would you have a direction to look into?

Running the code with:

  • python 3.6.4
  • CUDA 10.0
  • tensorflow and tensorflow-gpu 1.14.0rc0

with an NVIDIA K80

bnno train binarynet \
    --dataset cifar10 \
    --preprocess-fn resize_and_flip \
    --hparams-set bop \
    --hparams threshold=1e-6,gamma=1e-3

throws this exception:

Traceback (most recent call last):
  File "/usr/local/bin/bnno", line 11, in <module>
    load_entry_point('rethinking-bnn-optimization', 'console_scripts', 'bnno')()
  File "/usr/local/lib/python3.6/site-packages/click/core.py", line 764, in __call__
    return self.main(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/click/core.py", line 717, in main
    rv = self.invoke(ctx)
  File "/usr/local/lib/python3.6/site-packages/click/core.py", line 1137, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/usr/local/lib/python3.6/site-packages/click/core.py", line 956, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/local/lib/python3.6/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/zookeeper/cli.py", line 100, in train
    function(build_model, dataset, hparams, output_dir, epochs, **kwargs)
  File "/home/nik/rethinking-bnn-optimization/bnn_optimization/train.py", line 64, in train
    callbacks=callbacks,
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 780, in fit
    steps_name='steps_per_epoch')
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 142, in model_iteration
    input_iterator = _get_iterator(inputs, model._distribution_strategy)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 517, in _get_iterator
    return training_utils.get_iterator(inputs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py", line 1315, in get_iterator
    initialize_iterator(iterator)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py", line 1322, in initialize_iterator
    K.get_session((init_op,)).run(init_op)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: No registered 'Const' OpKernel for GPU devices compatible with node {{node random_crop/Assert/Assert/data_0}}
         (OpKernel was found, but attributes didn't match) Requested Attributes: dtype=DT_STRING, value=Tensor<type: string shape: [] values: Need value.shape >= size, got >
        .  Registered:  device='XLA_CPU_JIT'; dtype in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, ..., DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64, DT_STRING]
  device='XLA_GPU_JIT'; dtype in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, ..., DT_BFLOAT16, DT_HALF, DT_UINT32, DT_UINT64, DT_STRING]
  device='XLA_CPU'; dtype in [DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16]
  device='GPU'; dtype in [DT_VARIANT]
  device='GPU'; dtype in [DT_BOOL]
  device='GPU'; dtype in [DT_COMPLEX128]
  device='GPU'; dtype in [DT_COMPLEX64]
  device='GPU'; dtype in [DT_UINT64]
  device='GPU'; dtype in [DT_INT64]
  device='GPU'; dtype in [DT_QINT32]
  device='GPU'; dtype in [DT_UINT32]
  device='GPU'; dtype in [DT_QUINT16]
  device='GPU'; dtype in [DT_QINT16]
  device='GPU'; dtype in [DT_INT16]
  device='GPU'; dtype in [DT_UINT16]
  device='GPU'; dtype in [DT_QINT8]
  device='GPU'; dtype in [DT_INT8]
  device='GPU'; dtype in [DT_UINT8]
  device='GPU'; dtype in [DT_DOUBLE]
  device='GPU'; dtype in [DT_FLOAT]
  device='GPU'; dtype in [DT_BFLOAT16]
  device='GPU'; dtype in [DT_HALF]
  device='GPU'; dtype in [DT_INT32]
  device='CPU'
  device='XLA_GPU'; dtype in [DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16]

         [[random_crop/Assert/Assert/data_0]]
         [[MakeIterator]]
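
For what it's worth, a hedged direction to try rather than an answer from the maintainers: the traceback says a string Const op inside the random_crop assertion has no GPU kernel, which is the kind of placement failure that soft device placement is intended to work around by letting such ops fall back to the CPU. Assuming a TF 1.x-style session, before the training call:

import tensorflow as tf

# Let TensorFlow place ops that have no GPU kernel (such as string Consts)
# on the CPU instead of raising NotFoundError.
config = tf.ConfigProto(allow_soft_placement=True)
tf.keras.backend.set_session(tf.Session(config=config))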
