
Neural Architecture Search for Neural Network Libraries

NNablaNAS is a Python package that provides methods for hardware-aware neural architecture search for NNabla:

  • A top-level graph to define candidate architectures for convolutional neural networks (CNNs)
  • Profilers to measure the hardware demands of neural architectures (latency, number of parameters, etc.)
  • Searcher algorithms to learn the architecture and model parameters (e.g., DartsSearcher and ProxylessNasSearcher)
  • Regularizers (e.g., LatencyEstimator and MemoryEstimator) which can be used to enforce hardware constraints
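The regularizers listed above add a hardware cost to the search objective. As an illustration only, the sketch below computes an expected latency for one relaxed layer: the softmax of hypothetical architecture logits `alphas` weights per-candidate latencies (as a profiler might report them), and the result is added to the task loss with a hypothetical `weight`. This is a common formulation in latency-aware search (ProxylessNAS-style methods use it), not necessarily how NNablaNAS's LatencyEstimator is implemented.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_latency(alphas, op_latencies):
    """Expected latency of one relaxed layer.

    alphas: hypothetical architecture logits, one per candidate operation.
    op_latencies: latency of each candidate, e.g. as measured by a profiler.
    """
    probs = softmax(alphas)
    return sum(p * lat for p, lat in zip(probs, op_latencies))


def regularized_loss(task_loss, alphas, op_latencies, weight=0.1):
    """Task loss plus a weighted latency penalty (`weight` is hypothetical)."""
    return task_loss + weight * expected_latency(alphas, op_latencies)


# Equal logits give a uniform distribution, so the penalty is the mean latency.
penalty = expected_latency([0.0, 0.0, 0.0], [3.0, 1.5, 0.0])
```

Because the penalty is a smooth function of the logits, it can be minimized jointly with the task loss by gradient descent, pushing probability mass toward cheaper operations.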

NNablaNAS aims to make architecture search research more reusable and reproducible by providing researchers with a modular framework that they can use to implement new search algorithms and new search spaces while reusing code.

Getting started

Here we show how to install NNablaNAS and build a simple search space.

Installation

For a local installation, run the following commands:

git clone git@github.com:sony/nnabla-nas.git
cd nnabla-nas

Install the dependencies of NNablaNAS with the following command:

pip install -r requirements.txt

Run tests to check for correctness:

pytest .

Examples

The example below shows how to use NNablaNAS.

We construct a search space by relaxing the choice of operation for a layer of the network. Our search space encodes that the network chooses between Convolution, MaxPooling, and Identity for the first layer.

from collections import OrderedDict

from nnabla_nas import module as Mo
from nnabla_nas.contrib.model import Model


class MyModel(Model):
    def __init__(self):
        self._block = Mo.MixedOp(
            operators=[
                Mo.Conv(in_channels=3, out_channels=3, kernel=(3, 3), pad=(1, 1)),
                Mo.MaxPool(kernel=(3, 3), stride=(1, 1), pad=(1, 1)),
                Mo.Identity()
            ],
            mode='full'
        )
        self._classifier = Mo.Sequential(
            Mo.ReLU(),
            Mo.GlobalAvgPool(),
            Mo.Linear(3, 10)
        )

    def call(self, input):
        out = self._block(input)
        out = self._classifier(out)
        return out

    def get_arch_parameters(self, grad_only=False):
        r"""Returns an `OrderedDict` containing architecture parameters."""
        p = self.get_parameters(grad_only)
        return OrderedDict([(k, v) for k, v in p.items() if 'alpha' in k])

    def get_net_parameters(self, grad_only=False):
        r"""Returns an `OrderedDict` containing model parameters."""
        p = self.get_parameters(grad_only)
        return OrderedDict([(k, v) for k, v in p.items() if 'alpha' not in k])

if __name__ == '__main__':
    net = MyModel()
    print(net)
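With mode='full', the MixedOp above presumably evaluates all three candidates and combines them as in the DARTS continuous relaxation, with weights given by the softmax of its architecture parameters (the 'alpha' parameters that get_arch_parameters() selects). The pure-Python sketch below shows that weighted sum on a 1-D list; the toy operators stand in for Conv, MaxPool and Identity and are hypothetical, not NNablaNAS modules.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mixed_op_full(x, operators, alphas):
    """Weighted sum of all candidate outputs, weights = softmax(alphas).

    x: input as a flat list of floats (stand-in for an NNabla Variable).
    operators: callables mapping a list to a list of the same length.
    alphas: one architecture logit per operator.
    """
    weights = softmax(alphas)
    outputs = [op(x) for op in operators]
    return [sum(w * out[i] for w, out in zip(weights, outputs))
            for i in range(len(x))]


# Hypothetical toy stand-ins for the three candidates on a 1-D signal.
def scale(v):
    """Plays the role of a convolution (here just multiplies by 2)."""
    return [2.0 * a for a in v]


def max_pool3(v):
    """1-D max pooling with kernel 3, stride 1, pad 1."""
    padded = [-math.inf] + list(v) + [-math.inf]
    return [max(padded[i:i + 3]) for i in range(len(v))]


def identity(v):
    return list(v)


# Equal logits: each candidate contributes one third of the output.
y = mixed_op_full([1.0, 3.0, 2.0], [scale, max_pool3, identity], [0.0, 0.0, 0.0])
```

During the search, gradients with respect to the logits tell the searcher which candidate to favor; the relaxed layer is later replaced by the candidate with the largest weight.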

The tutorials and examples cover additional aspects of NNablaNAS.

Features

The main features of NNablaNAS are:

Search spaces

Search spaces are constructed using Modules. Modules are composed of layers, which receive NNabla Variables as input and compute Variables as output. Modules can also contain other Modules, allowing them to be nested in a tree structure. Submodules can be assigned as regular attributes. All search space components should inherit from nnabla_nas.module.Module and override the call() method. Please refer to nnabla_nas/module/module.py for details.

from nnabla_nas.module import Module

class MyModule(Module):

    def __init__(self):
        # TODO: create the submodules of this module here
        pass

    def call(self, input):
        # TODO: compute the output Variable from `input` and return it
        raise NotImplementedError
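The attribute-based registration described above can be illustrated with a small, self-contained sketch. ToyModule, Scale and TwoScales are hypothetical classes written for this example; they are not part of NNablaNAS, whose actual base class lives in nnabla_nas/module/module.py, and plain floats stand in for NNabla parameter Variables.

```python
from collections import OrderedDict


class ToyModule:
    """Hypothetical module base class (not nnabla_nas code).

    Submodules and parameters assigned as attributes are registered
    automatically, so nested modules form a tree.
    """

    def __init__(self):
        # Bypass our own __setattr__ while creating the registries.
        object.__setattr__(self, '_modules', OrderedDict())
        object.__setattr__(self, '_params', OrderedDict())

    def __setattr__(self, name, value):
        if isinstance(value, ToyModule):
            self._modules[name] = value
        elif isinstance(value, float):  # floats stand in for parameters
            self._params[name] = value
        object.__setattr__(self, name, value)

    def get_parameters(self, prefix=''):
        """Collect the parameters of the whole tree under dotted names."""
        params = OrderedDict((prefix + k, v) for k, v in self._params.items())
        for name, module in self._modules.items():
            params.update(module.get_parameters(prefix + name + '.'))
        return params

    def call(self, x):
        raise NotImplementedError

    def __call__(self, x):
        return self.call(x)


class Scale(ToyModule):
    def __init__(self, w):
        super().__init__()
        self.w = w  # registered as a parameter

    def call(self, x):
        return self.w * x


class TwoScales(ToyModule):
    def __init__(self):
        super().__init__()
        self.first = Scale(2.0)  # registered as a submodule
        self.second = Scale(3.0)

    def call(self, x):
        return self.second(self.first(x))


net = TwoScales()
print(list(net.get_parameters()))  # ['first.w', 'second.w']
print(net(1.0))                    # 6.0
```

Overriding __setattr__ is one common way to get this behavior (PyTorch's nn.Module works similarly); whether NNablaNAS uses exactly this mechanism is an assumption.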

A search space is defined as a Model, which should inherit from the class nnabla_nas.contrib.model.Model. The base API for Model has two methods, get_arch_parameters() and get_net_parameters(), which return the architecture parameters and the model parameters, respectively.

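from nnabla_nas.contrib.model import Model

class MyModel(Model):

    def get_arch_parameters(self, grad_only=False):
        # TODO: return an OrderedDict of the architecture parameters
        raise NotImplementedError

    def get_net_parameters(self, grad_only=False):
        # TODO: return an OrderedDict of the model parameters
        raise NotImplementedError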
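The MyModel example earlier implements both methods by filtering parameter names for the substring 'alpha'. The sketch below isolates that filtering on a plain OrderedDict; split_parameters and the parameter names are hypothetical, chosen only for illustration (real names come from Model.get_parameters()).

```python
from collections import OrderedDict


def split_parameters(params, arch_key='alpha'):
    """Split a name -> parameter mapping into architecture and model parts.

    Names containing `arch_key` are treated as architecture parameters,
    mirroring the filtering in the MyModel example.
    """
    arch = OrderedDict((k, v) for k, v in params.items() if arch_key in k)
    net = OrderedDict((k, v) for k, v in params.items() if arch_key not in k)
    return arch, net


# Hypothetical names; the string values stand in for NNabla Variables.
params = OrderedDict([
    ('block.alpha', 'A'),
    ('block.conv.W', 'W1'),
    ('classifier.linear.W', 'W2'),
    ('classifier.linear.b', 'b2'),
])
arch_params, net_params = split_parameters(params)
print(list(arch_params))  # ['block.alpha']
print(list(net_params))   # ['block.conv.W', 'classifier.linear.W', 'classifier.linear.b']
```

Filtering by name is simple but fragile: a model parameter whose name happens to contain 'alpha' would be misclassified, so parameter names must follow a consistent convention.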

Searcher algorithms

A Searcher interacts with the search space through a simple API. A searcher samples a model from the search space by assigning values to the architecture parameters. The results from the sampled architecture are then used to update the architecture parameters of the search space. A searcher also updates the model parameters. A new Searcher should inherit from nnabla_nas.runner.searcher.search.Searcher. This class has two methods, train_on_batch() and valid_on_batch(), which users should override.

from nnabla_nas.runner.searcher.search import Searcher

class MyAlgorithm(Searcher):

    def callback_on_start(self):
        # TODO: set up anything needed before the search starts
        pass

    def train_on_batch(self, key='train'):
        # TODO: process one training batch
        pass

    def valid_on_batch(self):
        # TODO: process one validation batch
        pass

    def callback_on_finish(self):
        # TODO: clean up or save results after the search
        pass
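To make the role of these hooks concrete, here is a plain-Python sketch of a driver loop. ToySearcher and its run() method are hypothetical. The call order, and the assumption that train_on_batch() updates the model parameters while valid_on_batch() updates the architecture parameters (in the spirit of DARTS-style alternating optimization), are illustrative and may differ from the actual loop in nnabla_nas.runner.

```python
class ToySearcher:
    """Hypothetical searcher skeleton; not nnabla_nas's Searcher class.

    It only records the order in which the hooks are called.
    """

    def __init__(self, epochs, batches_per_epoch):
        self.epochs = epochs
        self.batches_per_epoch = batches_per_epoch
        self.calls = []  # hook names, in call order

    def callback_on_start(self):
        self.calls.append('start')

    def train_on_batch(self, key='train'):
        # Assumption: a real searcher would update the model parameters here.
        self.calls.append('train')

    def valid_on_batch(self):
        # Assumption: a real searcher would update the architecture parameters here.
        self.calls.append('valid')

    def callback_on_finish(self):
        self.calls.append('finish')

    def run(self):
        """Alternate model and architecture updates batch by batch."""
        self.callback_on_start()
        for _ in range(self.epochs):
            for _ in range(self.batches_per_epoch):
                self.train_on_batch()
                self.valid_on_batch()
        self.callback_on_finish()


s = ToySearcher(epochs=1, batches_per_epoch=2)
s.run()
print(s.calls)  # ['start', 'train', 'valid', 'train', 'valid', 'finish']
```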

Two searcher algorithms are implemented in NNablaNAS: DartsSearcher and ProxylessNasSearcher.

Logging

When running the architecture search, the evaluations in the search space are logged. We maintain an output folder that keeps track of the parameters and of metrics such as loss, error, number of parameters, and latency. Users can easily monitor the training curves with TensorboardX.
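As a rough illustration of per-epoch metric logging to an output folder, the sketch below appends one CSV row per epoch. log_metrics, the file name metrics.csv and the column layout are hypothetical, not the format NNablaNAS writes, and the numbers are dummy values rather than real results.

```python
import csv
import os
import tempfile


def log_metrics(output_path, epoch, metrics):
    """Append one row of metrics to <output_path>/metrics.csv.

    Hypothetical helper: every call is expected to pass the same
    metric names, since the header is written only once.
    """
    os.makedirs(output_path, exist_ok=True)
    path = os.path.join(output_path, 'metrics.csv')
    write_header = not os.path.exists(path)
    with open(path, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=['epoch'] + sorted(metrics))
        if write_header:
            writer.writeheader()
        writer.writerow({'epoch': epoch, **metrics})
    return path


# Dummy values for illustration only (not real search results).
out_dir = tempfile.mkdtemp()
for epoch, (loss, error) in enumerate([(2.3, 0.9), (1.7, 0.6)]):
    path = log_metrics(out_dir, epoch, {'loss': loss, 'error': error})

with open(path, newline='') as f:
    rows = list(csv.DictReader(f))
print(rows[-1]['loss'])  # 1.7
```

TensorboardX provides the same kind of per-step logging through `SummaryWriter(logdir).add_scalar(tag, value, global_step)`, which is what makes the curves viewable in TensorBoard.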


Visualization

Visualization is useful for debugging and illustrating the search space. One can easily check whether the search space was built correctly.
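A lightweight way to inspect a relaxed layer is to emit a Graphviz DOT description of it. The helper mixed_op_to_dot below is hypothetical and is not the NNablaNAS visualization code; the edge weights passed to it are made-up values standing in for, e.g., softmax-normalized architecture parameters.

```python
def mixed_op_to_dot(name, candidates, weights=None):
    """Return a Graphviz DOT string for one relaxed layer.

    Each candidate operation becomes a node between `input` and `output`;
    optional weights label the edges into the output node.
    """
    lines = ['digraph %s {' % name, '  rankdir=LR;']
    for i, op in enumerate(candidates):
        node = 'op%d' % i
        lines.append('  %s [label="%s"];' % (node, op))
        lines.append('  input -> %s;' % node)
        label = '' if weights is None else ' [label="%.2f"]' % weights[i]
        lines.append('  %s -> output%s;' % (node, label))
    lines.append('}')
    return '\n'.join(lines)


dot = mixed_op_to_dot('block', ['Conv', 'MaxPool', 'Identity'], [0.5, 0.3, 0.2])
print(dot)
```

The resulting text can be saved to a file and rendered with Graphviz, e.g. `dot -Tpng block.dot -o block.png`, to check that every candidate is wired as intended.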


Experiments

NNablaNAS has a command-line interface utility:

usage: main.py [-h] [--context CONTEXT] [--device-id DEVICE_ID]
               [--type-config TYPE_CONFIG] [--search]
               [--algorithm {DartsSearcher,ProxylessNasSearcher,Trainer}]
               [--config-file CONFIG_FILE] [--output-path OUTPUT_PATH]

optional arguments:
  -h, --help            show this help message and exit
  --context CONTEXT, -c CONTEXT
                        Extension module. 'cudnn' is highly recommended.
  --device-id DEVICE_ID, -d DEVICE_ID
                        A list of device ids to use, e.g., `0,1,2,3`. This is
                        only valid if you specify `-c cudnn`.
  --type-config TYPE_CONFIG, -t TYPE_CONFIG
                        Type configuration.
  --search, -s          Whether it is searching for the architecture.
  --algorithm {DartsSearcher,ProxylessNasSearcher,Trainer}, -a {DartsSearcher,ProxylessNasSearcher,Trainer}
                        Which algorithm to use.
  --config-file CONFIG_FILE, -f CONFIG_FILE
                        The configuration file for the experiment.
  --output-path OUTPUT_PATH, -o OUTPUT_PATH
                        Path to save the monitoring log files.
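For readers who want to script experiments, the documented options map naturally onto Python's argparse. The parser below is reconstructed from the help text only and is not the actual main.py; defaults are not given above, so argparse's own defaults are kept (None, or False for --search).

```python
import argparse


def build_parser():
    """Rebuild the documented command-line options (sketch, not main.py)."""
    parser = argparse.ArgumentParser(prog='main.py')
    parser.add_argument('--context', '-c',
                        help="Extension module. 'cudnn' is highly recommended.")
    parser.add_argument('--device-id', '-d',
                        help='A list of device ids to use, e.g., `0,1,2,3`. '
                             'This is only valid if you specify `-c cudnn`.')
    parser.add_argument('--type-config', '-t', help='Type configuration.')
    parser.add_argument('--search', '-s', action='store_true',
                        help='Whether it is searching for the architecture.')
    parser.add_argument('--algorithm', '-a',
                        choices=['DartsSearcher', 'ProxylessNasSearcher', 'Trainer'],
                        help='Which algorithm to use.')
    parser.add_argument('--config-file', '-f',
                        help='The configuration file for the experiment.')
    parser.add_argument('--output-path', '-o',
                        help='Path to save the monitoring log files.')
    return parser


# Parse the DARTS search command shown below.
args = build_parser().parse_args([
    '--search',
    '-f', 'examples/classification/darts/cifar10_search.json',
    '-a', 'DartsSearcher',
    '-o', 'log/classification/darts/cifar10/search',
])
print(args.search, args.algorithm)  # True DartsSearcher
```

Restricting --algorithm with `choices` makes argparse reject unknown searcher names before any experiment starts.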

You can start the architecture search using DartsSearcher with the command below:

# search DARTS
python main.py --search \
               -f examples/classification/darts/cifar10_search.json  \
               -a DartsSearcher \
               -o log/classification/darts/cifar10/search

To re-train the model using the architecture found in the architecture search, run:

# train DARTS
python main.py -f examples/classification/darts/cifar10_train.json \
               -a Trainer \
               -o log/classification/darts/cifar10/train

Documentation

To build documentation in various formats, you will need Sphinx and the readthedocs theme.

cd docs/
pip install -r requirements.txt

You can then build the documentation by running make <format> from the docs/ folder. Run make to get a list of all available output formats.

Contact

NNablaNAS is currently maintained by the SSG-DL group, R&D Center Europe Stuttgart Laboratory 1. For bug reports, questions, and suggestions, use GitHub issues.

License

NNablaNAS is Apache-style licensed, as found in the LICENSE file.

Contributors

fabiencardinaux, kazukiyoshiyama-sony, qiiajia, te-bacnguyencong, te-javieralonsogarcia, te-lukasmauch, te-stephentiedemann, tomonobutsujikawa, yuikosakuma1, yukiooobuchi
