Minijyro

A simple probabilistic programming language in Julia based on effect handlers. The name comes from the fact that this project is largely based on the ideas from Pyro's effect handlers and their Mini-Pyro implementation.

The design goals of this language are:

  • Allow for concise definition of sample statements using ~ syntax
  • Use effect handlers to implement simple operations such as conditioning and computing the log joint probability
  • Leverage existing Julia packages such as Distributions.jl, AdvancedHMC and Flux

NOTE: This is not meant to be a serious PPL to be used by anyone. If you are interested in probabilistic programming in Julia, have a look at Turing.jl, Gen and Soss.jl.

Example: Bayesian Linear Regression

A simple model taken from Colin Carroll's tour of PPL APIs.

using Distributions
using LinearAlgebra: I # Identity matrix
using Random

using Minijyro

Random.seed!(42)

# Generate some data.
N = 100
D = 5
true_w = randn(D)
X = randn(N, D)
noise = 0.1 * randn(N)
y_obs = X * true_w + noise

@jyro function model(xs)
    D = size(xs)[2]
    w ~ MvNormal(zeros(D), I)
    y ~ MvNormal(xs * w, 0.1*I)
end

cond_model = condition(model, Dict(:y => y_obs))
samples, stats = nuts(cond_model, (X,), 1000)

@show abs.(true_w - mean(samples))
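
Because both the prior and the likelihood are Gaussian, this particular model also has a closed-form posterior, which gives a rough reference point for the NUTS result. The sketch below is independent of Minijyro and uses only the standard library. The names sigma2, posterior_precision and posterior_mean are introduced here for illustration, and it assumes that 0.1*I in the model is read as a covariance matrix (variance 0.1), as in current Distributions.jl.

```julia
using LinearAlgebra
using Random

Random.seed!(42)

# Same synthetic data as in the example above.
N, D = 100, 5
true_w = randn(D)
X = randn(N, D)
y_obs = X * true_w + 0.1 * randn(N)

# Likelihood variance assumed by the model: y ~ MvNormal(xs * w, 0.1 * I).
sigma2 = 0.1

# Conjugate Gaussian update for the prior w ~ N(0, I):
#   precision = I + X'X / sigma2
#   mean      = precision \ (X'y / sigma2)
posterior_precision = I + (X' * X) / sigma2
posterior_mean = posterior_precision \ (X' * y_obs / sigma2)

@show abs.(true_w - posterior_mean)
```

With enough samples, the posterior mean estimated by NUTS should land close to posterior_mean, so a large gap between the two would point to a problem in the sampler setup rather than in the model.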

Behind the Scenes

Here is a high-level overview of the inner workings of Minijyro. For more details, I recommend first reading the Pyro documentation linked above and then the full source code of Minijyro.

Minijyro models are normal Julia functions annotated with the @jyro macro. The macro applies a few source code transformations and translates the function into a MinijyroModel type. See dsl.jl for the full implementation of the @jyro macro.
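
As a rough illustration of what such a source transformation can look like, the sketch below rewrites every `x ~ d` statement in a function body into an assignment that calls a helper with the site name and the right-hand side. It is a toy, not the code in dsl.jl: ToyUniform, toy_sample, rewrite_tilde, @toy_jyro and VISITED are hypothetical names, and the real macro does more (for example, it produces a MinijyroModel).

```julia
# Hypothetical stand-in for a distribution, so the sketch needs no packages.
struct ToyUniform
    lo::Float64
    hi::Float64
end

# Records the order in which sample sites are visited.
const VISITED = Symbol[]

# Hypothetical stand-in for a sample call: log the site name, draw a value.
function toy_sample(name::Symbol, dist::ToyUniform)
    push!(VISITED, name)
    return dist.lo + (dist.hi - dist.lo) * rand()
end

# Leave literals, symbols and line number nodes untouched.
rewrite_tilde(x) = x

# Replace each `lhs ~ rhs` call by `lhs = toy_sample(:lhs, rhs)`, recursively.
function rewrite_tilde(ex::Expr)
    if ex.head == :call && length(ex.args) == 3 &&
       ex.args[1] == :~ && ex.args[2] isa Symbol
        lhs, rhs = ex.args[2], ex.args[3]
        call = Expr(:call, :toy_sample, QuoteNode(lhs), rewrite_tilde(rhs))
        return Expr(:(=), lhs, call)
    end
    return Expr(ex.head, map(rewrite_tilde, ex.args)...)
end

macro toy_jyro(fn_def)
    return esc(rewrite_tilde(fn_def))
end

@toy_jyro function toy_model()
    a ~ ToyUniform(0.0, 1.0)
    b ~ ToyUniform(a, a + 1.0)
    return (a, b)
end

a, b = toy_model()
@show VISITED
```

The key point is that `a ~ ToyUniform(0.0, 1.0)` parses as an ordinary call to `~`, so a macro can pattern-match on it and pass the variable name along as a symbol.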

Most importantly, the @jyro macro translates each ~ expression into a call to sample!:

function sample!(
    trace::Dict,
    handlers_stack::Array{AbstractHandler,1},
    name::Any,
    dist::Distributions.Distribution
)
    if length(handlers_stack) == 0
        return rand(dist)
    end

    initial_msg = Dict(
        :fn => rand,
        :args => (dist, ),
        :name => name,
        :value => nothing,
        :is_observed => false,
        :done => false,
        :stop => false,
        :continuation => nothing
    )
    msg = apply_stack!(trace, handlers_stack, initial_msg)
    return msg[:value]
end

When no handlers are active, sample! simply draws a random value from dist. Crucially, any side effects of this sampling (e.g. computing the log density or saving the sampled value in trace) can be conveniently defined as effect handlers. The function apply_stack! applies all effect handlers that are active at the given sample site:

function apply_stack!(
    trace::Dict,
    handlers_stack::Array{AbstractHandler,1},
    msg::Dict
)
    @assert length(handlers_stack) > 0

    handler_counter = 0
    # Loop through handlers from top of the stack to the bottom.
    for handler in handlers_stack[end:-1:1]
        handler_counter += 1
        process_message!(trace, handler, msg)
        if msg[:stop]
            break
        end
    end

    if !(msg[:value] != nothing || msg[:done])
        msg[:value] = msg[:fn](msg[:args]...)
    end

    # Loop through handlers from bottom of the stack to the top.
    # If we exited the first loop early then we will start looping from the
    # handler which caused the loop to exit.
    for handler in handlers_stack[end-handler_counter+1:end]
        postprocess_message!(trace, handler, msg)
    end

    if msg[:continuation] != nothing
        msg[:continuation](trace, msg)
    end

    return msg
end
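
The order in which handlers are visited is easy to get wrong, so here is a stripped-down sketch of just the two loops. Handlers are replaced by plain symbols, and traversal_order and stopper are hypothetical names used only for this illustration. The last element of the stack is treated as its top, as in apply_stack!.

```julia
# Returns the sequence of hook calls for a stack of named handlers.
# `stopper` is the handler that sets msg[:stop], or `nothing` if none does.
function traversal_order(stack::Vector{Symbol}, stopper::Union{Symbol,Nothing})
    calls = String[]
    visited = 0

    # First pass: from the top of the stack (last element) to the bottom.
    for handler in reverse(stack)
        visited += 1
        push!(calls, "process $handler")
        handler == stopper && break
    end

    # Second pass: from the last visited handler back up to the top.
    for handler in stack[end-visited+1:end]
        push!(calls, "postprocess $handler")
    end

    return calls
end

stack = [:bottom, :middle, :top]

full = traversal_order(stack, nothing)
# process top, process middle, process bottom,
# postprocess bottom, postprocess middle, postprocess top

stopped = traversal_order(stack, :middle)
# process top, process middle, postprocess middle, postprocess top
```

In the second call, the handler named :bottom never sees the message at all: once a handler sets the stop flag, everything below it in the stack is skipped in both passes.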

Effect handlers are subtypes of AbstractHandler. The default methods below do nothing, so a concrete handler only overrides the hooks it needs:

abstract type AbstractHandler end

function enter!(trace::Dict, h::AbstractHandler)
    return
end

function exit!(trace::Dict, h::AbstractHandler)
    return
end

function process_message!(trace::Dict, h::AbstractHandler, msg::Dict)
    return
end

function postprocess_message!(trace::Dict, h::AbstractHandler, msg::Dict)
    return
end

For example, conditioning on data can be implemented as:

struct ConditionHandler <: AbstractHandler
    data::Dict
end

function process_message!(trace::Dict, h::ConditionHandler, msg::Dict)
    if haskey(h.data, msg[:name])
        msg[:value] = h.data[msg[:name]]
        msg[:stop] = true
        msg[:is_observed] = true
    end
end
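
To see the condition handler at work, the following self-contained sketch repeats condensed versions of the definitions above and adds a second handler that records every site in the trace. RecordHandler, toy_apply_stack! and new_msg are hypothetical names introduced here and are not necessarily part of Minijyro. The simplified toy_apply_stack! omits the :done and :continuation fields, and rand() stands in for sampling from a distribution.

```julia
# Condensed copies of the definitions above, so the sketch runs on its own.
abstract type AbstractHandler end
process_message!(trace::Dict, h::AbstractHandler, msg::Dict) = nothing
postprocess_message!(trace::Dict, h::AbstractHandler, msg::Dict) = nothing

struct ConditionHandler <: AbstractHandler
    data::Dict
end

function process_message!(trace::Dict, h::ConditionHandler, msg::Dict)
    if haskey(h.data, msg[:name])
        msg[:value] = h.data[msg[:name]]
        msg[:stop] = true
        msg[:is_observed] = true
    end
end

# Hypothetical handler: once a site is resolved, store its value in the trace.
struct RecordHandler <: AbstractHandler end

function postprocess_message!(trace::Dict, h::RecordHandler, msg::Dict)
    trace[msg[:name]] = (value = msg[:value], is_observed = msg[:is_observed])
end

# Simplified apply_stack!: the same two passes, without :done and :continuation.
function toy_apply_stack!(trace::Dict, stack::Vector{AbstractHandler}, msg::Dict)
    visited = 0
    for handler in reverse(stack)
        visited += 1
        process_message!(trace, handler, msg)
        msg[:stop] && break
    end
    if msg[:value] === nothing
        msg[:value] = msg[:fn](msg[:args]...)
    end
    for handler in stack[end-visited+1:end]
        postprocess_message!(trace, handler, msg)
    end
    return msg
end

new_msg(name) = Dict{Symbol,Any}(
    :fn => rand, :args => (), :name => name,
    :value => nothing, :is_observed => false, :stop => false,
)

trace = Dict{Symbol,Any}()
# RecordHandler sits on top (last element) so that it still sees sites at
# which ConditionHandler sets the stop flag.
stack = AbstractHandler[ConditionHandler(Dict(:y => 2.5)), RecordHandler()]

toy_apply_stack!(trace, stack, new_msg(:y))  # observed: value taken from the data
toy_apply_stack!(trace, stack, new_msg(:w))  # latent: value drawn with rand()
@show trace
```

The placement of RecordHandler matters here. Because ConditionHandler sets msg[:stop], a recording handler placed below it in the stack would never be called for the observed site :y.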
