gradients's Introduction

Build your deep learning models with confidence

Gradients provides a self-consistency test function for gradient checking of your deep learning models. It uses the centered finite difference approximation to compare analytical and numerical gradients and reports any model parameters for which the check fails. Currently the library supports only PyTorch models, including those built with custom layers, custom loss functions, custom activation functions, and any neural network function subclassing torch.autograd.Function.
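
To illustrate the idea of a centered finite difference check, here is a minimal sketch in plain Python. It is not the library's implementation; the helper names (`centered_difference`, `relative_error`) and the step size are assumptions chosen for illustration. It compares the analytical derivative of a scalar sigmoid with its numerical approximation.

```python
import math


def sigmoid(x):
    """Scalar logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_grad(x):
    """Analytical derivative of the sigmoid: s(x) * (1 - s(x))."""
    s = sigmoid(x)
    return s * (1.0 - s)


def centered_difference(f, x, eps=1e-6):
    """Numerical derivative: (f(x + eps) - f(x - eps)) / (2 * eps)."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


def relative_error(a, b):
    """Symmetric relative error, guarded against division by zero."""
    return abs(a - b) / max(abs(a) + abs(b), 1e-12)


x = 0.5
analytical = sigmoid_grad(x)
numerical = centered_difference(sigmoid, x)
error = relative_error(analytical, numerical)

# A small relative error suggests the analytical gradient is consistent.
print(f"analytical={analytical:.8f} numerical={numerical:.8f} rel_error={error:.2e}")
```

The centered form has a truncation error that shrinks with the square of `eps`, which is why it is usually preferred over a one-sided difference for gradient checks.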

Installation

pip install gradients

Package Overview

Optimizing deep learning models is a two step process:

  1. Compute gradients with respect to parameters

  2. Update the parameters given the gradients
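
The two steps above can be sketched in PyTorch as follows. This is a minimal illustration, assuming a single weight matrix, a mean squared error loss, and plain SGD; the shapes, seed, and learning rate are arbitrary choices, not values taken from the library.

```python
import torch

torch.manual_seed(0)

# Toy data and a single trainable weight matrix (shapes are arbitrary).
x = torch.randn(10, 4)
y = torch.randn(10, 3)
w = torch.nn.Parameter(torch.randn(4, 3))

lr = 0.1
optimizer = torch.optim.SGD([w], lr=lr)

# Forward pass: a linear model with a mean squared error loss.
loss = ((x.mm(w) - y) ** 2).mean()

# Step 1: compute gradients with respect to the parameters.
optimizer.zero_grad()
loss.backward()

# Keep copies so the update can be inspected afterwards.
w_before = w.detach().clone()
grad = w.grad.clone()

# Step 2: update the parameters given the gradients.
optimizer.step()

# For plain SGD without momentum the update is w <- w - lr * grad.
print(torch.allclose(w.detach(), w_before - lr * grad))
```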

In PyTorch, step 1 is handled by the tape-based automatic differentiation system torch.autograd, and step 2 by torch.optim, the package implementing optimization algorithms. Using them, we can develop fully customized deep learning models with torch.nn.Module and test them using Gradient as follows:

Activation function with backward

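import torch

class MySigmoid(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        output = 1 / (1 + torch.exp(-input))
        # Save the output: the sigmoid derivative is output * (1 - output)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        # Chain rule: upstream gradient times the local derivative
        return grad_output * output * (1 - output)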
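
As an independent cross-check, a custom activation like this can also be verified with PyTorch's built-in `torch.autograd.gradcheck`, which likewise compares analytical gradients against finite differences. The sketch below is separate from the Gradient class; the class name `SigmoidFn`, the input shape, and the tolerances are assumptions for illustration. Note that `gradcheck` expects double-precision inputs with `requires_grad=True`.

```python
import torch


class SigmoidFn(torch.autograd.Function):
    """Hypothetical stand-in for a custom sigmoid with a hand-written backward."""

    @staticmethod
    def forward(ctx, input):
        output = torch.sigmoid(input)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        return grad_output * output * (1 - output)


torch.manual_seed(0)

# gradcheck needs float64 inputs for the finite differences to be reliable.
x = torch.randn(5, 3, dtype=torch.double, requires_grad=True)

# Returns True when analytical and numerical gradients agree within tolerance.
ok = torch.autograd.gradcheck(SigmoidFn.apply, (x,), eps=1e-6, atol=1e-4)
print(ok)
```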

Loss function with backward

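class MSELoss(torch.autograd.Function):

    @staticmethod
    def forward(ctx, y_pred, y):
        ctx.save_for_backward(y_pred, y)
        # Sum of squared errors averaged over the batch dimension
        return ((y_pred - y) ** 2).sum() / y_pred.shape[0]

    @staticmethod
    def backward(ctx, grad_output):
        y_pred, y = ctx.saved_tensors
        # Scale by grad_output so the chain rule holds when the loss is not the final node
        grad_input = grad_output * 2 * (y_pred - y) / y_pred.shape[0]
        # No gradient is returned for the target y
        return grad_input, None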
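
One way to sanity-check a hand-written loss backward is to compare it with the gradient that autograd derives from the same formula. The sketch below is an illustration under assumptions: the class name `MSELossFn`, the shapes, and the scaling factor are hypothetical. The loss is scaled by 3 before calling `backward()` so that `grad_output` is not 1, which exercises the chain rule in the custom backward.

```python
import torch


class MSELossFn(torch.autograd.Function):
    """Hypothetical custom loss: squared error summed and divided by batch size."""

    @staticmethod
    def forward(ctx, y_pred, y):
        ctx.save_for_backward(y_pred, y)
        return ((y_pred - y) ** 2).sum() / y_pred.shape[0]

    @staticmethod
    def backward(ctx, grad_output):
        y_pred, y = ctx.saved_tensors
        # Multiply by grad_output so upstream scaling is propagated.
        grad_input = grad_output * 2 * (y_pred - y) / y_pred.shape[0]
        return grad_input, None


torch.manual_seed(0)
y_pred = torch.randn(10, 3, dtype=torch.double, requires_grad=True)
y = torch.randn(10, 3, dtype=torch.double)

# Gradient from the hand-written backward, with an upstream factor of 3.
(3.0 * MSELossFn.apply(y_pred, y)).backward()
custom_grad = y_pred.grad.clone()

# Reference gradient: the same formula differentiated by autograd.
y_pred.grad = None
reference = 3.0 * ((y_pred - y) ** 2).sum() / y_pred.shape[0]
reference.backward()

match = torch.allclose(custom_grad, y_pred.grad)
print(match)
```

A backward that ignored `grad_output` would still pass when the loss is the final node of the graph, but would disagree with the reference here by the factor of 3.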

PyTorch Model

class MyModel(torch.nn.Module):
    def __init__(self, D_in, D_out):
        super().__init__()
        self.w1 = torch.nn.Parameter(torch.randn(D_in, D_out), requires_grad=True)
        self.sigmoid = MySigmoid.apply

    def forward(self, x):
        y_pred = self.sigmoid(x.mm(self.w1))
        return y_pred

Check your implementation using Gradient

import torch
from gradients import Gradient

N, D_in, D_out = 10, 4, 3

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct model by instantiating the class defined above
mymodel = MyModel(D_in, D_out)
criterion = MSELoss.apply

# Run the gradient check on the custom-built model
Gradient(mymodel, x, y, criterion, eps=1e-8)

gradients's People

Contributors

saran-nns

gradients's Issues

RNN model support

Develop and test an example of RNN models with custom kernels and nn.RNN kernels

CNN model support

Develop and test an example of CNN models with custom kernels and nn.Conv1d and nn.Conv2d kernels
