pytorch2keras

PyTorch to Keras model converter. Still in beta.

Installation

pip install pytorch2keras 

Important notice

At the moment, only PyTorch 0.2 (deprecated) and PyTorch 0.4 (latest stable) are supported.

To use the converter properly, make the following changes in your ~/.keras/keras.json:

...
"backend": "tensorflow",
"image_data_format": "channels_first",
...
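If you prefer to apply these settings from a script, here is a minimal sketch using only the standard library. The function name `set_keras_channels_first` is hypothetical and not part of pytorch2keras. It keeps any other keys already in the file and only overwrites the two shown above. Keras reads this file when it is first imported, so run it before importing keras.

```python
import json
import os


def set_keras_channels_first(path=os.path.expanduser("~/.keras/keras.json")):
    """Hypothetical helper: set backend and image_data_format, keep other keys.

    Returns the configuration dict that was written.
    """
    config = {}
    if os.path.exists(path):
        with open(path) as f:
            config = json.load(f)
    # The two settings requested above; keys such as epsilon or floatx are preserved.
    config["backend"] = "tensorflow"
    config["image_data_format"] = "channels_first"
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, "w") as f:
        json.dump(config, f, indent=4)
    return config
```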

Since the latest releases, multiple inputs are also supported.
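The input shapes passed to pytorch_to_keras omit the batch dimension: in the How to use section, the (1, 10, 32, 32) dummy input is described as (10, 32, 32). A small sketch that derives that list from dummy NumPy arrays; `input_shapes_from` is a hypothetical helper, and listing several inputs in the same order as the model's forward() arguments is an assumption, not something this README specifies.

```python
import numpy as np


def input_shapes_from(*dummy_inputs):
    """Hypothetical helper: per-input shapes with the leading batch dim dropped."""
    return [tuple(a.shape[1:]) for a in dummy_inputs]


image_np = np.random.uniform(0, 1, (1, 10, 32, 32))
print(input_shapes_from(image_np))  # [(10, 32, 32)]

# Hypothetical second input: one dummy array per model input, assumed to be
# given in the same order as the model's forward() arguments.
mask_np = np.ones((1, 1, 32, 32))
print(input_shapes_from(image_np, mask_np))  # [(10, 32, 32), (1, 32, 32)]
```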

TensorFlow.js

For proper conversion to the TensorFlow.js format, use the new flag short_names=True.

How to build the latest PyTorch

Please follow this guide to compile the latest version.

How to use

It converts a PyTorch graph to a Keras (TensorFlow backend) graph.
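Besides rebuilding the layers, a converter also has to copy weights across, and the two frameworks store them in different layouts. A PyTorch Conv2d weight is (out_channels, in_channels, kH, kW), while a Keras Conv2D kernel is (kH, kW, in_channels, out_channels). A PyTorch Linear weight is (out_features, in_features), while a Keras Dense kernel is (in_features, out_features). The NumPy sketch below shows this rearrangement. It is an illustration, not pytorch2keras's own code, and the function names are hypothetical.

```python
import numpy as np


def torch_conv2d_to_keras_kernel(weight):
    """(out, in, kH, kW) PyTorch Conv2d weight -> (kH, kW, in, out) Keras kernel."""
    return np.transpose(weight, (2, 3, 1, 0))


def torch_linear_to_keras_kernel(weight):
    """(out_features, in_features) Linear weight -> (in, out) Dense kernel."""
    return weight.T


# Shapes matching the TestConv2d example: 10 input channels, 16 filters, 3x3 kernel.
w_torch = np.random.randn(16, 10, 3, 3).astype(np.float32)
w_keras = torch_conv2d_to_keras_kernel(w_torch)
print(w_keras.shape)  # (3, 3, 10, 16)
```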

First, we need to load (or create) a PyTorch model:

import torch
import torch.nn as nn


class TestConv2d(nn.Module):
    """Module for Conv2d conversion testing
    """

    def __init__(self, inp=10, out=16, kernel_size=3):
        super(TestConv2d, self).__init__()
        self.conv2d = nn.Conv2d(inp, out, stride=(inp % 3 + 1), kernel_size=kernel_size, bias=True)

    def forward(self, x):
        x = self.conv2d(x)
        return x


model = TestConv2d()

# load weights here
# model.load_state_dict(torch.load('path_to_weights.pth'))
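With the default arguments, the convolution's stride is inp % 3 + 1 = 10 % 3 + 1 = 2, so the spatial size shrinks by roughly half. The sketch below applies the output-size formula from the PyTorch Conv2d documentation to work out the expected output shape; `conv2d_out_size` is a hypothetical helper.

```python
def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size of a PyTorch Conv2d along one dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


inp, out, kernel_size = 10, 16, 3  # TestConv2d defaults
stride = inp % 3 + 1               # 10 % 3 + 1 == 2
h = conv2d_out_size(32, kernel_size, stride)
print((1, out, h, h))              # (1, 16, 15, 15)
```

For the (1, 10, 32, 32) dummy input created in the next step, both the PyTorch model and the converted Keras model (with channels_first) are expected to produce a (1, 16, 15, 15) output.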

The next step is to create a dummy variable with the correct shape:

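import numpy as np
import torch
from torch.autograd import Variable

# batch of 1, 10 channels, 32x32 spatial size (channels-first, as PyTorch expects)
input_np = np.random.uniform(0, 1, (1, 10, 32, 32))
input_var = Variable(torch.FloatTensor(input_np))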

We use the dummy variable to trace the model.

from pytorch2keras.converter import pytorch_to_keras
# specify the shape of the input tensor, without the batch dimension
k_model = pytorch_to_keras(model, input_var, [(10, 32, 32,)], verbose=True)

That's all! If everything goes well, the Keras model is stored in the k_model variable.
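A quick way to check the conversion is to run the same input through both models and compare the results. A minimal sketch, assuming you have already collected the two outputs as arrays, for example `model(input_var).data.numpy()` on the PyTorch side and `k_model.predict(input_np)` on the Keras side. The helper names and the default tolerance of 1e-4 are assumptions, not values taken from pytorch2keras.

```python
import numpy as np


def max_abs_error(pytorch_output, keras_output):
    """Largest element-wise absolute difference between two model outputs."""
    a = np.asarray(pytorch_output, dtype=np.float64)
    b = np.asarray(keras_output, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError("shape mismatch: %s vs %s" % (a.shape, b.shape))
    return float(np.max(np.abs(a - b)))


def outputs_match(pytorch_output, keras_output, atol=1e-4):
    """True if the outputs agree within atol (the default is an assumption)."""
    return max_abs_error(pytorch_output, keras_output) <= atol
```

Differences at the level of float32 rounding are expected; a large error usually points at a layer or parameter that did not convert as intended.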

Supported layers

Layers:

  • Linear
  • Conv2d
  • DepthwiseConv2d (with limited parameters)
  • Conv3d
  • ConvTranspose2d
  • MaxPool2d
  • MaxPool3d
  • AvgPool2d
  • Global average pooling (as a special case of AdaptiveAvgPool2d)
  • Embedding
  • UpsamplingNearest2d
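Global average pooling is AdaptiveAvgPool2d with an output size of 1. The NumPy sketch below reimplements adaptive average pooling with the bin boundaries PyTorch uses (bin i along an axis of length L spans floor(i*L/out) to ceil((i+1)*L/out)). For output size (1, 1), a single bin covers the whole feature map and the result is just the spatial mean. The function is a hypothetical illustration, not part of the converter.

```python
import numpy as np


def _ceil_div(a, b):
    """Integer ceiling division for non-negative a and positive b."""
    return -(-a // b)


def adaptive_avg_pool2d(x, output_size):
    """NumPy sketch of AdaptiveAvgPool2d for an (N, C, H, W) array."""
    n, c, h, w = x.shape
    oh, ow = output_size
    y = np.empty((n, c, oh, ow), dtype=np.float64)
    for i in range(oh):
        h0, h1 = (i * h) // oh, _ceil_div((i + 1) * h, oh)  # rows of bin i
        for j in range(ow):
            w0, w1 = (j * w) // ow, _ceil_div((j + 1) * w, ow)  # columns of bin j
            y[:, :, i, j] = x[:, :, h0:h1, w0:w1].mean(axis=(2, 3))
    return y


x = np.random.rand(2, 8, 7, 7)
pooled = adaptive_avg_pool2d(x, (1, 1))
# A single bin covers the whole map, so this equals global average pooling.
print(np.allclose(pooled[:, :, 0, 0], x.mean(axis=(2, 3))))  # True
```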

Reshape:

  • View
  • Reshape
  • Transpose
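View and Reshape are sensitive to tensor layout. PyTorch feature maps are (N, C, H, W); flattening the same data stored channels-last, (N, H, W, C), yields the same values in a different order, which would scramble the inputs of a following Linear layer. A small NumPy illustration of this general point (background only, not a description of pytorch2keras internals), consistent with the channels_first setting requested in the Important notice:

```python
import numpy as np

# A small feature map in PyTorch's (N, C, H, W) layout.
x_nchw = np.arange(2 * 3 * 4).reshape(1, 2, 3, 4)
# The same values stored channels-last, (N, H, W, C).
x_nhwc = x_nchw.transpose(0, 2, 3, 1)

flat_torch = x_nchw.reshape(1, -1)  # element order of x.view(1, -1) in PyTorch
flat_nhwc = x_nhwc.reshape(1, -1)   # element order if flattened channels-last

print(np.array_equal(flat_torch, flat_nhwc))                    # False: different order
print(np.array_equal(np.sort(flat_torch), np.sort(flat_nhwc)))  # True: same values

# Moving channels back in front before flattening restores PyTorch's order.
print(np.array_equal(flat_torch, x_nhwc.transpose(0, 3, 1, 2).reshape(1, -1)))  # True
```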

Activations:

  • ReLU
  • LeakyReLU
  • PReLU (only with 0.2)
  • SELU (only with 0.2)
  • Tanh
  • HardTanh (clamp)
  • Softmax
  • Softplus (only with 0.2)
  • Softsign (only with 0.2)
  • Sigmoid
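When a converted activation looks wrong, it can help to compare it against a plain reference. Below is a NumPy sketch of several of the activations above, using PyTorch's default parameters: HardTanh clamps to [-1, 1], LeakyReLU uses a negative slope of 0.01, and SELU uses the standard constants from the PyTorch documentation. Softplus here ignores PyTorch's threshold shortcut for large inputs, which makes no practical difference. These functions are hypothetical helpers, not part of pytorch2keras.

```python
import numpy as np

# NumPy references with PyTorch's default parameters.


def hardtanh(x, min_val=-1.0, max_val=1.0):
    return np.clip(x, min_val, max_val)  # HardTanh is a clamp


def leaky_relu(x, negative_slope=0.01):
    return np.where(x >= 0, x, negative_slope * x)


def softplus(x):
    return np.logaddexp(0.0, x)  # log(1 + exp(x)) without overflow


def softsign(x):
    return x / (1.0 + np.abs(x))


SELU_ALPHA = 1.6732632423543772848170429916717
SELU_SCALE = 1.0507009873554804934193349852946


def selu(x):
    # exp is applied to min(x, 0) so large positive inputs cannot overflow.
    return SELU_SCALE * np.where(x > 0, x, SELU_ALPHA * (np.exp(np.minimum(x, 0)) - 1.0))
```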

Element-wise:

  • Addition
  • Multiplication
  • Subtraction

Misc:

  • reduce sum (.sum() method)

Unsupported parameters

  • Pooling: count_include_pad, dilation, ceil_mode
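Parameters such as ceil_mode change a pooling layer's output size, and Keras pooling layers offer no ceil_mode option, which is presumably why SqueezeNet appears below with ceil_mode=False. The sketch follows the output-size formula in the PyTorch MaxPool2d documentation, plus PyTorch's rule that with ceil_mode a final window starting in the right padding is dropped; `pool_out_size` is a hypothetical helper.

```python
def pool_out_size(size, kernel_size, stride, padding=0, dilation=1, ceil_mode=False):
    """Output length of a PyTorch MaxPool2d along one spatial dimension."""
    span = size + 2 * padding - dilation * (kernel_size - 1) - 1
    if ceil_mode:
        out = -(-span // stride) + 1  # integer ceiling division
        # A last window that would start inside the right padding is dropped.
        if (out - 1) * stride >= size + padding:
            out -= 1
    else:
        out = span // stride + 1
    return out


print(pool_out_size(54, 3, 2))                  # 26
print(pool_out_size(54, 3, 2, ceil_mode=True))  # 27
```

For example, a 3x3, stride-2 pool on a 54x54 map gives 26x26 by default but 27x27 with ceil_mode=True, so a converted model would silently change shape.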

Models converted with pytorch2keras

  • ResNet18
  • ResNet34
  • ResNet50
  • SqueezeNet (with ceil_mode=False)
  • DenseNet
  • AlexNet
  • Inception
  • SeNet
  • Mobilenet v2

Usage

Look at the tests directory.

License

This software is covered by the MIT License.
