Automatic Differentiation in 38 lines of F#

Automatic differentiation is a technique for computing the derivative of a function in a computer program. This particular implementation is inspired by Automatic Differentiation in 38 lines of Haskell, which is in turn inspired by Beautiful Differentiation.

Statically resolved type parameters (SRTP) and operator overloading in F# make it possible to differentiate a function that is written almost like a plain old F# function. However, due to the limitations of F#, this implementation is much simpler and less powerful than the original Haskell version. This F# version exists solely as a demonstration and learning tool.

Generic numbers

Let's start with a function that implements a simple numerical computation in F#:

> let f x = sqrt (3.0 * sin x);;
val f: x: float -> float

We can evaluate this function to determine its value at, say, x = 2.0:

> f 2.0;;
val it: float = 1.651633216

But what if we would also like this function to work with other numeric types, such as float32? As currently written, this won't work:

> f 2.0f;;

  f 2.0f;;
  --^^^^

stdin(5,3): error FS0001: This expression was expected to have type
    'float'    
but here has type
    'float32'

F#'s sqrt and sin functions are both already generic, so we just have to find a way to replace the 3.0 literal with a corresponding generic version of the number 3. There's no built-in way to do this in F#, but it's not hard to write ourselves:

module Generic =

    /// Converts an integer to the corresponding generic number.
    let inline fromInt n =
        assert(n > 0)
        LanguagePrimitives.GenericOne
            |> Seq.replicate n
            |> Seq.reduce (+)

This is an extremely impractical solution, but it works fine for our purposes. We can now rewrite our computation as:

let inline f x =
    sqrt (Generic.fromInt 3 * sin x)

We can now evaluate f for any numeric type that supports sqrt and sin:

> f 2.0;;
val it: float = 1.651633216

> f 2.0f;;
val it: float32 = 1.651633143f

Note that the implementation of the * operator that gets invoked varies with the numeric type: float multiplication in the first call and float32 multiplication in the second.
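F# won't supply a generic 3 for us, but it does let us attach our own conversion to literal syntax. If a module named NumericLiteralG is in scope, the compiler rewrites a literal such as 3G into a call on that module: FromZero for 0G, FromOne for 1G, and FromInt32 for other values that fit in an int. The sketch below is an optional alternative to calling Generic.fromInt, not part of the original project. It reuses the same repeated-addition trick, so it likewise assumes non-negative literals.

```fsharp
// Sketch: a NumericLiteralG module gives integer literals a "G" suffix.
// 0G -> FromZero (), 1G -> FromOne (), other int literals -> FromInt32 n.
module NumericLiteralG =
    let inline FromZero () = LanguagePrimitives.GenericZero
    let inline FromOne () = LanguagePrimitives.GenericOne

    /// Same impractical idea as Generic.fromInt: add One to itself n times.
    /// Assumes n > 1 here (0 and 1 go through FromZero/FromOne instead).
    let inline FromInt32 (n : int) =
        LanguagePrimitives.GenericOne
            |> Seq.replicate n
            |> Seq.reduce (+)

// The computation from above, now written with a generic literal.
let inline f x = sqrt (3G * sin x)

printfn "%A" (f 2.0)    // float version
printfn "%A" (f 2.0f)   // float32 version
```

Because FromInt32 is inline, the 3G literal is resolved separately at each call site, just as Generic.fromInt 3 is.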

Dual numbers

In fact, our generic function will now work for any type that implements the necessary underlying members, such as +, *, One, Sin, and Sqrt. We can even write our own such type, one that implements dual numbers. A dual number is a pair whose first item is a regular value and whose second item is a derivative:

type Dual<'a> = D of value : 'a * derivative : 'a   // simplified

We can perform math on dual numbers, just like regular numbers. For example, multiplication of dual numbers follows the product rule for derivatives:

static member inline (*)(D (x, x'), D (y, y')) =
    D (x * y, y' * x + x' * y)
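One way to see why this rule is correct is the standard algebraic view of dual numbers: write D (x, x') as x + x'ε, where ε is a formal symbol with ε² = 0 (the code only stores the pair; ε is notation). Multiplying out and dropping the ε² term gives exactly the two components above:

```latex
(x + x'\varepsilon)(y + y'\varepsilon)
  = xy + (x y' + x' y)\,\varepsilon + x' y'\,\varepsilon^{2}
  = xy + (y' x + x' y)\,\varepsilon
  \qquad (\varepsilon^{2} = 0)
```

More generally, a Taylor expansion gives f(x + x'ε) = f(x) + f'(x) x' ε, which is why the Sin member in the full listing produces D (sin x, x' * cos x).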

Once we implement all the required members, we can call our original function above with a suitable dual number:

> f (D (2.0, 1.0));;
val it: Dual<float> = D (1.651633216, -0.3779412092)

This gives us both the value of f and the derivative of f at 2.0!
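The second component of the input, 1.0, is the derivative of x with respect to itself. The result can be checked by hand by differentiating f with the chain rule and substituting x = 2 (values rounded to seven decimal places):

```latex
f(x) = \sqrt{3 \sin x},
\qquad
f'(x) = \frac{3 \cos x}{2 \sqrt{3 \sin x}}

f'(2) \approx \frac{3 \times (-0.4161468)}{2 \times 1.6516332}
      \approx -0.3779412
```

This agrees with the derivative component -0.3779412092 returned by the dual-number evaluation.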

More examples can be found in the unit tests.

The 38 lines

namespace AutoDiff

open LanguagePrimitives

/// A dual number, consisting of a regular value and a derivative value in tandem.
type Dual<'a
    when 'a : (static member Zero : 'a)
    and 'a : (static member One : 'a)> =

        /// A dual number.
        D of value : 'a * derivative : 'a with

    member inline d.Value = let (D(reg, _)) = d in reg
    member inline d.Deriv = let (D(_, deriv)) = d in deriv

    static member inline Const(x : 'a) = D (x, GenericZero)
    static member inline Zero = Dual.Const(GenericZero<'a>)
    static member inline One = Dual.Const(GenericOne<'a>)

    static member inline (+)(D (x, x'), D (y, y')) = D (x + y, x' + y')
    static member inline (-)(x, y) = x + (-y)
    static member inline (~-)(D (x, x')) = D (-x, -x')

    static member inline (*)(D (x, x'), D (y, y')) = D (x * y, y' * x + x' * y)
    static member inline (/)(D (x, x'), D (y, y')) =
        let deriv =
            (GenericOne / (y * y)) * (y * x' + (-x) * y')
        D (x / y, deriv)

    static member inline Sin(D (x, x')) = D (sin x, x' * cos x)
    static member inline Cos(D (x, x')) = D (cos x, x' * -(sin x))

    static member inline Pow(d, n) = pown d n
    static member inline Exp(D (x, x')) = D (exp x, x' * exp x)
    static member inline Log(D (x, x')) = D (log x, x' * (GenericOne / x))
    static member inline Sqrt(D (x, x')) =
        let two = Seq.reduce (+) [ GenericOne; GenericOne ]   // ugh
        D (sqrt x, x' / (two * sqrt x))
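To see the pieces working together outside the repository, here is a self-contained script sketch. It copies only the members of Dual that f actually needs, along with Generic.fromInt. It also adds a hypothetical diff helper, which is not part of the original library, that seeds the derivative with One and returns the derivative component.

```fsharp
open LanguagePrimitives

module Generic =

    /// Converts a positive integer to the corresponding generic number
    /// (copied from the Generic module above).
    let inline fromInt n =
        assert (n > 0)
        GenericOne
            |> Seq.replicate n
            |> Seq.reduce (+)

/// Trimmed copy of the Dual type above: only the members f needs.
type Dual<'a
    when 'a : (static member Zero : 'a)
    and 'a : (static member One : 'a)> =
    | D of value : 'a * derivative : 'a

    static member inline Const(x : 'a) = D (x, GenericZero)
    static member inline Zero = Dual.Const(GenericZero<'a>)
    static member inline One = Dual.Const(GenericOne<'a>)

    // Sum rule and product rule.
    static member inline (+)(D (x, x'), D (y, y')) = D (x + y, x' + y')
    static member inline (*)(D (x, x'), D (y, y')) = D (x * y, y' * x + x' * y)

    // Chain rule for sin and sqrt.
    static member inline Sin(D (x, x')) = D (sin x, x' * cos x)
    static member inline Sqrt(D (x, x')) =
        let two = Seq.reduce (+) [ GenericOne; GenericOne ]
        D (sqrt x, x' / (two * sqrt x))

/// Hypothetical helper (not in the original library): evaluates f on a dual
/// number whose derivative is seeded with One (dx/dx = 1) and returns f'(x).
let inline diff f x =
    let (D (_, deriv)) = f (D (x, GenericOne))
    deriv

// The example function from the introduction.
let inline f x = sqrt (Generic.fromInt 3 * sin x)

printfn "f'(2.0) = %.10f" (diff f 2.0)
```

The printed derivative should agree with the -0.3779412092 shown in the FSI session earlier. Seeding with GenericOne rather than the literal 1.0 keeps diff generic, so the same helper would also accept a float32 argument.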
