StochasTorch: stochastically rounded operations between PyTorch tensors.

This repository contains a software-based PyTorch implementation of some stochastically rounded operations.

When encoding the weights of a neural network in low precision (such as bfloat16), one runs into stagnation problems: updates end up being too small, relative to the precision of the encoding, to change the weights. The weights then become stuck and the model's accuracy is significantly reduced.
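The stagnation effect can be reproduced without PyTorch. The following is a minimal sketch that emulates bfloat16 rounding in pure Python; the helper name `to_bfloat16` is hypothetical and is not part of StochasTorch. It assumes finite inputs and round-to-nearest-even.

```python
import struct

def to_bfloat16(value: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    # View the value as the 32 bits of a float32.
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # bfloat16 keeps the top 16 bits; the bias turns truncation into round-to-nearest-even.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

weight = to_bfloat16(1.0)
update = 0.001  # much smaller than the bfloat16 spacing near 1.0, which is 2**-7

for _ in range(1000):
    weight = to_bfloat16(weight + update)

print(weight)  # 1.0: every single update was rounded away
```

In exact arithmetic the weight would have reached 2.0; in emulated bfloat16 it never moves.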

Stochastic arithmetic lets you perform the addition in such a way that the weights have a non-zero probability of being modified anyway. This avoids the stagnation problem (see figure 4 of "Revisiting BFloat16 Training") without increasing the memory usage (as might happen if one were using a compensated summation to solve the problem).
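As an illustration of the idea (not of this library's implementation), here is a pure-Python sketch of a stochastically rounded accumulation in emulated bfloat16. The function name `stochastic_round_bfloat16` is hypothetical, inputs are assumed to be finite, and the rounding probability is assumed to be proportional to the distance to each neighbouring bfloat16 value.

```python
import random
import struct

def stochastic_round_bfloat16(value: float, rng: random.Random) -> float:
    """Round a finite float to one of its two bfloat16 neighbours, chosen at random."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    low_bits = bits & 0xFFFF0000        # neighbour towards zero
    high_bits = low_bits + 0x00010000   # neighbour away from zero
    low = struct.unpack("<f", struct.pack("<I", low_bits))[0]
    high = struct.unpack("<f", struct.pack("<I", high_bits))[0]
    if value == low:
        return low  # already representable
    # The closer the value is to `high`, the more likely we are to return it.
    p_high = (value - low) / (high - low)
    return high if rng.random() < p_high else low

rng = random.Random(0)  # seeded for reproducibility
weight = 1.0
update = 0.001  # too small to survive round-to-nearest in bfloat16

for _ in range(1000):
    weight = stochastic_round_bfloat16(weight + update, rng)

print(weight)  # lands near 2.0 on average instead of staying stuck at 1.0
```

Each individual result is still a bfloat16 value, so no extra memory is needed; only the expected value of the sum tracks the exact result.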

The downside is that software-based stochastic arithmetic is significantly slower than a normal floating-point addition. It is thus viable for things like the weight update but would not be appropriate in a hot loop.

Usage

This repository introduces the add (x+y) and addcdiv (x + epsilon*t1/t2) operations. They act similarly to their PyTorch counterparts but round the result up or down randomly:

import torch
import stochastorch

# problem definition
size = 10
dtype = torch.bfloat16
x = torch.rand(size, dtype=dtype)
y = torch.rand(size, dtype=dtype)

# deterministic addition
result_det = x + y
print(f"deterministic addition: {result_det}")

# stochastic addition
result_sto = stochastorch.add(x,y)
print(f"stochastic addition: {result_sto}")
difference = result_det - result_sto
print(f"difference: {difference}")

# deterministic addcdiv
# result = x + epsilon*t1/t2
t1 = torch.rand(size, dtype=dtype)
t2 = torch.rand(size, dtype=dtype)
epsilon = -0.1
result_det = torch.addcdiv(x, t1, t2, value=epsilon)
print(f"deterministic addcdiv: {result_det}")

# stochastic addcdiv
result_sto = stochastorch.addcdiv(x, t1, t2, value=epsilon)
print(f"stochastic addcdiv: {result_sto}")
difference = result_det - result_sto
print(f"difference: {difference}")

Both functions take an optional is_biased boolean parameter. If is_biased is True (the default value), the random number generator is biased according to the relative error of the operation; otherwise, the result is rounded up half of the time on average.
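The difference between the two modes can be seen by averaging many roundings of the same value. The sketch below is pure Python and emulates bfloat16. It assumes that "biased according to the relative error" means a probability proportional to the distance to each neighbour. The `proportional` flag is a hypothetical stand-in for is_biased, not the library's code.

```python
import random
import struct

def round_bfloat16(value: float, rng: random.Random, proportional: bool = True) -> float:
    """Randomly round a finite float to one of its two bfloat16 neighbours."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    low_bits = bits & 0xFFFF0000
    low = struct.unpack("<f", struct.pack("<I", low_bits))[0]
    high = struct.unpack("<f", struct.pack("<I", low_bits + 0x00010000))[0]
    if value == low:
        return low
    # Either weight the coin by the rounding error, or flip a fair coin.
    p_high = (value - low) / (high - low) if proportional else 0.5
    return high if rng.random() < p_high else low

rng = random.Random(0)
value = 1.001  # lies between the bfloat16 values 1.0 and 1.0078125
n = 20000

mean_proportional = sum(round_bfloat16(value, rng, True) for _ in range(n)) / n
mean_fair_coin = sum(round_bfloat16(value, rng, False) for _ in range(n)) / n

print(f"exact value:       {value}")
print(f"proportional mode: {mean_proportional:.5f}")  # close to the exact value
print(f"fair-coin mode:    {mean_fair_coin:.5f}")     # close to the midpoint of the two neighbours
```

Under this assumption, the proportional mode is the one whose average matches the exact result, which is consistent with it being the default.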

When using low precision (16-bit floating-point arithmetic or less), we strongly recommend using the stochastorch.addcdiv function when possible, as it is significantly more accurate (note that PyTorch locally increases the precision to 32 bits when computing addcdiv on 16-bit floating-point numbers).

Otherwise, it is often beneficial to compute in higher precision locally and cast down to 16 bits only when summing or storing the result. add handles this automatically when its second input has a higher precision than its first.

Implementation details

We use TwoSum to measure the numerical error introduced by the addition. Our tests show that it behaves as needed on bfloat16 and on higher-precision floating-point types (some edge cases might be invalid, leading to an inexact computation of the numerical error, but it is reliable enough for our purpose).
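For readers unfamiliar with TwoSum, here is a minimal sketch of Knuth's branch-free formulation on Python floats (IEEE double precision). It illustrates the algorithm only; the repository applies the same idea to PyTorch tensors, and the variable names below are chosen for this example.

```python
from fractions import Fraction

def two_sum(a: float, b: float):
    """Return (s, e) where s is the rounded sum a + b and e is its rounding error."""
    s = a + b
    a_virtual = s - b          # the part of s that came from a
    b_virtual = s - a_virtual  # the part of s that came from b
    a_error = a - a_virtual
    b_error = b - b_virtual
    return s, a_error + b_error

s, e = two_sum(1.0, 1e-17)
print(s, e)  # 1.0 1e-17

# Check with exact rational arithmetic that s + e recovers the true sum.
assert Fraction(s) + Fraction(e) == Fraction(1.0) + Fraction(1e-17)
```

The sign of the error term says on which side of the rounded sum the exact result lies, and its magnitude says how far away it is.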

Together with the nextafter function, TwoSum lets us emulate various rounding modes in software (this is inspired by Verrou's backend).
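A rough sketch of how an error term and nextafter can be combined to emulate rounding modes, written on Python floats with math.nextafter (Python 3.9 or later). The function `add_rounded` and its mode names are hypothetical and do not reproduce this repository's code or Verrou's.

```python
import math
import random

def two_sum(a: float, b: float):
    """Rounded sum and its rounding error (Knuth's TwoSum)."""
    s = a + b
    a_virtual = s - b
    b_virtual = s - a_virtual
    return s, (a - a_virtual) + (b - b_virtual)

def add_rounded(a: float, b: float, mode: str, rng: random.Random = None) -> float:
    """Add a and b, emulating the requested rounding mode."""
    s, error = two_sum(a, b)
    if error == 0.0:
        return s  # the sum is exact, nothing to round
    # Neighbour of s on the side where the exact sum lies.
    neighbour = math.nextafter(s, math.inf if error > 0 else -math.inf)
    if mode == "up":
        return neighbour if error > 0 else s
    if mode == "down":
        return neighbour if error < 0 else s
    if mode == "stochastic":
        rng = rng or random.Random()
        # Move to the neighbour with probability proportional to the error.
        p = abs(error) / abs(neighbour - s)
        return neighbour if rng.random() < p else s
    raise ValueError(f"unknown mode: {mode}")

print(add_rounded(1.0, 1e-17, "up"))    # 1.0000000000000002
print(add_rounded(1.0, 1e-17, "down"))  # 1.0
```

The default round-to-nearest sum always sits on one side of the exact result, so a directed or stochastic rounding only ever needs to choose between that sum and one neighbour.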

Potential improvements

  • one could implement more operations,
  • one could reduce the memory usage of the operations by using more in-place operations,
  • one could improve the performance of this code by implementing it as a C++/CUDA kernel.

Do not hesitate to submit an issue or a pull request if you need additional functionality!

Crediting this work

You can use this BibTeX reference if you use StochasTorch within a published work:

@misc{StochasTorch,
  author = {Nestor, Demeure},
  title = {StochasTorch: stochastically rounded operations between Pytorch tensors.},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/nestordemeure/stochastorch}}
}

A JAX implementation called Jochastic is also available.
