

scikit-chainer

A scikit-learn-like interface to Chainer

How to install

$ pip install scikit-chainer

What's this?

This is a scikit-learn-like interface to the Chainer deep learning framework. You can use it to build a network model and then use the model through scikit-learn APIs (e.g. fit, predict). It provides the base classes ChainerRegresser for regression, ChainerClassifier for classification, and ChainerTransformer for transformation. To use one, inherit from it and implement the following methods:

  1. _setup_network: the network definition (a Chainer FunctionSet, Chain, or ChainList)
  2. forward: compute the intermediate result z from the input x (note that this is not the final predicted value)
  3. loss_func: the loss function to minimize (e.g. mean_squared_error, softmax_cross_entropy)
  4. output_func: compute the final result y from the forwarded values z (e.g. identity for regression and softmax for classification)
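To make the division of labour concrete, here is a minimal pure-Python sketch of how a base class could wire `fit` and `predict` to the four hooks. `SketchEstimator`, `LinearSketch`, `lr`, and `n_iter` are hypothetical names, not part of skchainer, and the forward-difference gradient descent only stands in for Chainer's automatic differentiation and optimizers; it is not how the library trains.

```python
from abc import ABC, abstractmethod


class SketchEstimator(ABC):
    """Hypothetical stand-in for skchainer's base classes (not the library's code).

    fit() and predict() only ever call the four hooks a subclass provides.
    """

    def __init__(self, lr=0.05, n_iter=2000, **params):
        self.lr = lr            # hypothetical learning-rate parameter
        self.n_iter = n_iter    # hypothetical number of passes
        # Hook 1: build the "network", here just a dict of scalar parameters.
        self.network = self._setup_network(**params)

    @abstractmethod
    def _setup_network(self, **params): ...

    @abstractmethod
    def forward(self, x): ...

    @abstractmethod
    def loss_func(self, z, t): ...

    @abstractmethod
    def output_func(self, z): ...

    def _loss(self, X, T):
        # Hooks 2 and 3: the loss is computed on forward() outputs.
        return self.loss_func([self.forward(x) for x in X], T)

    def fit(self, X, T, eps=1e-6):
        # Gradient descent with forward-difference gradients, standing in
        # for Chainer's automatic differentiation and optimizer.
        for _ in range(self.n_iter):
            for key in self.network:
                base = self._loss(X, T)
                self.network[key] += eps
                grad = (self._loss(X, T) - base) / eps
                self.network[key] -= eps              # undo the probe
                self.network[key] -= self.lr * grad   # descent step
        return self

    def predict(self, X):
        # Hook 4 maps forward() outputs to final predictions.
        return [self.output_func(self.forward(x)) for x in X]


class LinearSketch(SketchEstimator):
    """One-dimensional linear regression written against the four hooks."""

    def _setup_network(self, **params):
        return {"w": 0.0, "b": 0.0}

    def forward(self, x):
        return self.network["w"] * x + self.network["b"]

    def loss_func(self, z, t):
        # Mean squared error, like F.mean_squared_error.
        return sum((zi - ti) ** 2 for zi, ti in zip(z, t)) / len(t)

    def output_func(self, z):
        # Identity, as in the regression example.
        return z
```

Fitting on points of y = 2x + 1, e.g. `LinearSketch().fit([0, 1, 2, 3], [1, 3, 5, 7])`, drives `w` and `b` close to 2 and 1. Note that training only touches `forward` and `loss_func`, while `predict` layers `output_func` on top, which is why `forward` need not return the final prediction.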

Examples

Linear Regression

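import chainer.functions as F
import chainer.links as L
from chainer import Chain
from skchainer import ChainerRegresser  # assumed: base classes live in the skchainer package

class LinearRegression(ChainerRegresser):
    def _setup_network(self, **params):
        # one fully connected layer: n_dim inputs -> 1 output
        return Chain(l1=L.Linear(params["n_dim"], 1))

    def forward(self, x):
        y = self.network.l1(x)
        return y

    def loss_func(self, y, t):
        return F.mean_squared_error(y, t)

    def output_func(self, h):
        # regression: predictions are the forwarded values unchanged
        return F.identity(h)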

Logistic Regression

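import chainer.functions as F
import chainer.links as L
from chainer import Chain
from skchainer import ChainerClassifier  # assumed: base classes live in the skchainer package

class LogisticRegression(ChainerClassifier):
    def _setup_network(self, **params):
        # one fully connected layer: n_dim inputs -> n_class scores
        return Chain(l1=L.Linear(params["n_dim"], params["n_class"]))

    def forward(self, x):
        # raw (unnormalized) class scores
        y = self.network.l1(x)
        return y

    def loss_func(self, y, t):
        # softmax_cross_entropy applies the softmax itself, so y must be raw scores
        return F.softmax_cross_entropy(y, t)

    def output_func(self, h):
        # class probabilities for prediction
        return F.softmax(h)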
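The classification example splits the work so that `forward` returns raw class scores and the softmax is applied only in `output_func`. Chainer's `F.softmax_cross_entropy` expects unnormalized scores and applies the softmax internally, so feeding it probabilities would apply softmax twice. The stdlib-only sketch below (hypothetical helper names, not Chainer's implementation) checks that the loss computed from raw scores equals the negative log of the probability `output_func` assigns to the true class.

```python
import math


def softmax(z):
    # output_func role: turn raw class scores into probabilities.
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]


def softmax_cross_entropy(z, t):
    # loss_func role: takes the raw scores z and an integer label t and
    # returns -log(softmax(z)[t]) via the log-sum-exp trick.
    m = max(z)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in z))
    return log_sum_exp - z[t]


scores = [2.0, 1.0, 0.1]  # what forward() might return for one sample
probs = softmax(scores)
loss = softmax_cross_entropy(scores, 0)
```

Subtracting the maximum score keeps `math.exp` from overflowing on large scores without changing the result, since softmax is invariant to adding a constant to every score.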

AutoEncoder

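import chainer.functions as F
import chainer.links as L
from chainer import Chain
from skchainer import ChainerTransformer  # assumed: base classes live in the skchainer package

class AutoEncoder(ChainerTransformer):
    def __init__(self, activation=F.relu, **params):
        # name this class (not the parent) so ChainerTransformer.__init__ actually runs
        super(AutoEncoder, self).__init__(**params)
        self.activation = activation

    def _setup_network(self, **params):
        return Chain(
            encoder=L.Linear(params["input_dim"], params["hidden_dim"]),
            decoder=L.Linear(params["hidden_dim"], params["input_dim"])
        )

    def _forward(self, x, train=False):
        # encode, then decode back to the input space
        z = self._transform(x, train)
        y = self.network.decoder(z)
        return y

    def _loss_func(self, y, t):
        # reconstruction error
        return F.mean_squared_error(y, t)

    def _transform(self, x, train=False):
        # the encoded (hidden) representation
        return self.activation(self.network.encoder(x))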
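The AutoEncoder example composes its hooks: `_transform` produces the hidden code, and `_forward` decodes that code back to the input space so the loss can compare the reconstruction with the input. The pure-Python sketch below mirrors that structure with hand-picked weights. `TinyAutoEncoder`, `linear`, and `mse` are hypothetical names, and using the input as its own target is the standard autoencoder setup, assumed here rather than taken from the skchainer source.

```python
def relu(v):
    return [max(0.0, a) for a in v]


def linear(W, b, x):
    # y = W x + b, with W given as a list of rows.
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


def mse(y, t):
    return sum((a - c) ** 2 for a, c in zip(y, t)) / len(t)


class TinyAutoEncoder:
    """Hypothetical pure-Python mirror of the AutoEncoder example's structure."""

    def __init__(self, enc_W, enc_b, dec_W, dec_b, activation=relu):
        self.enc = (enc_W, enc_b)
        self.dec = (dec_W, dec_b)
        self.activation = activation

    def transform(self, x):
        # Like _transform: activation(encoder(x)) gives the hidden code.
        return self.activation(linear(*self.enc, x))

    def forward(self, x):
        # Like _forward: decode the hidden code back to input space.
        return linear(*self.dec, self.transform(x))

    def loss(self, x):
        # Assumed autoencoder setup: the reconstruction target is the input.
        return mse(self.forward(x), x)


# Hand-picked weights: input_dim = 3, hidden_dim = 2.
ae = TinyAutoEncoder(
    enc_W=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], enc_b=[0.0, 0.0],
    dec_W=[[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]], dec_b=[0.0, 0.0, 0.0],
)
code = ae.transform([1.0, -2.0, 3.0])
recon = ae.forward([1.0, -2.0, 3.0])
```

With these weights the encoder keeps the first two coordinates and the ReLU zeroes the negative one, so `transform([1.0, -2.0, 3.0])` gives `[1.0, 0.0]` and the reconstruction loses both the second and third coordinates.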

Contributors

lucidfrontier45


Issues

Exception in BaseChainerEstimator

Hi,

The test code does not work:

File "/usr/local/lib/python2.7/site-packages/skchainer/__init__.py", line 11
class BaseChainerEstimator(base.BaseEstimator, metaclass=ABCMeta):
^
SyntaxError: invalid syntax

Please help.
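The traceback points at `metaclass=ABCMeta` in the class statement. That keyword is Python 3-only syntax, so Python 2.7 rejects the file before running anything, which suggests the package as written needs Python 3. If Python 2 support were wanted, one common spelling that parses on both versions is to create the abstract base by calling `ABCMeta` directly (the `six` package's `with_metaclass` helper serves the same purpose). The sketch below uses a hypothetical stand-in for `sklearn.base.BaseEstimator` so it runs without scikit-learn; it is not the package's actual code.

```python
from abc import ABCMeta, abstractmethod

# Calling the metaclass directly builds an abstract base class without the
# Python 3-only "metaclass=" keyword, so this also parses on Python 2.7.
_AbstractBase = ABCMeta("_AbstractBase", (object,), {})


class BaseEstimatorStandIn(object):
    """Hypothetical stand-in for sklearn.base.BaseEstimator."""


class BaseChainerEstimatorSketch(BaseEstimatorStandIn, _AbstractBase):
    """Hypothetical sketch: gets ABCMeta as its metaclass via _AbstractBase."""

    @abstractmethod
    def _setup_network(self, **params):
        pass
```

Because `ABCMeta` is still the class's metaclass, abstract methods are enforced exactly as with the `metaclass=` keyword: the base cannot be instantiated until a subclass implements `_setup_network`.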
