luiarthur / turingbnpbenchmarks

Benchmarks of Bayesian Nonparametric models in Turing and other PPLs

Home Page: https://luiarthur.github.io/TuringBnpBenchmarks/

License: MIT License

Julia 0.38% Makefile 0.02% Jupyter Notebook 98.60% Python 0.91% R 0.10%
bayesian-inference probabilistic-programming bayesian-nonparametric-models julia-language benchmarks

turingbnpbenchmarks's Introduction

TuringBnpBenchmarks

Benchmarks of Bayesian Nonparametric models in Turing and other PPLs.

This work is funded by GSoC 2020.

My mentors for this project are Hong Ge, Martin Trapp, and Cameron Pfiffer.

Abstract

Probabilistic models, which quantify uncertainty more naturally than their deterministic counterparts, are often difficult and tedious to implement. Probabilistic programming languages (PPLs) have greatly increased the productivity of probabilistic modelers, allowing practitioners to focus on modeling rather than on implementing algorithms for probabilistic (e.g. Bayesian) inference. Turing is a PPL developed entirely in Julia and is both expressive and fast, due partly to Julia’s just-in-time (JIT) compiler being built on LLVM. Consequently, Turing has a more manageable code base and has the potential to be more extensible than more established PPLs like STAN. One thing that may lead to wider adoption of Turing is more benchmarks and feature comparisons of Turing against other mainstream PPLs. The aim of this project is to provide a more systematic approach to comparing execution times and features among several PPLs, including STAN, Pyro, nimble, and TensorFlow Probability, for a variety of Bayesian nonparametric (BNP) models, a class of models that provide a great deal of modeling flexibility and often allow model complexity to increase with data size.

This project addresses the need for a more systematic approach to comparing the performance of Turing and various other PPLs (STAN, Pyro, nimble, TensorFlow Probability) under common Bayesian nonparametric (BNP) models, a class of models that provide a great deal of modeling flexibility and allow the number of model parameters, and thus model complexity, to increase with the size of the data. The following models will be implemented (if possible) and timed (both compile times and execution times) in the various PPLs (links to minimum working examples will be provided):

  • Sampling (and variational) algorithms for Dirichlet process (DP) Gaussian / non-Gaussian mixtures for different sample sizes
    • E.g. Sampling via Chinese restaurant process (CRP) representations (including collapsed Gibbs, sequential Monte Carlo, particle Gibbs), HMC/NUTS for stick-breaking (SB) constructions, variational inference for stick-breaking construction.
    • Note: DPs are a popular choice of BNP model, typically used when density estimation is of interest. They are also a popular prior for infinite mixture models, where the number of clusters is not known in advance.
  • Sampling (and variational) algorithms for Pitman-Yor process (PYP) Gaussian / non-Gaussian mixtures for different sample sizes
    • E.g. Sampling via generalized CRP representations (including collapsed Gibbs, sequential Monte Carlo, particle Gibbs), HMC/NUTS for stick-breaking (SB) constructions, variational inference for stick-breaking construction.
    • Note: PYPs are generalizations of DPs. That is, DPs are a special case of PYPs. PYPs exhibit a power-law behavior, which enables them to better model heavy-tailed distributions.
  • PYP / DP hierarchical models. Specific model to be determined.
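
As a rough illustration of the CRP representations listed above, the sketch below draws cluster labels from a generalized CRP prior using only Julia's standard library. The function name `sample_crp` and its arguments are hypothetical and are not part of Turing or this repository. Setting `d = 0` gives the DP case, and `0 < d < 1` gives the PYP case.

```julia
using Random

# Draw cluster labels z[1:n] from a generalized Chinese restaurant process.
# alpha: concentration (alpha > -d), d: discount (0 <= d < 1).
# d = 0 recovers the CRP of a Dirichlet process.
function sample_crp(rng::AbstractRNG, n::Int, alpha::Float64; d::Float64=0.0)
    z = zeros(Int, n)
    counts = Int[]                      # counts[k] = size of cluster k
    for i in 1:n
        K = length(counts)
        # Unnormalized weights: one per existing cluster, then one new cluster.
        w = Float64[c - d for c in counts]
        push!(w, alpha + K * d)
        # Sample an index proportionally to w (the total is i - 1 + alpha).
        u = rand(rng) * sum(w)
        k = 1
        acc = w[1]
        while acc < u && k < length(w)
            k += 1
            acc += w[k]
        end
        z[i] = k
        if k > K
            push!(counts, 1)            # open a new cluster
        else
            counts[k] += 1
        end
    end
    return z, counts
end

z, counts = sample_crp(MersenneTwister(1), 50, 1.0)
```

This only simulates from the prior. The collapsed Gibbs, SMC, and particle Gibbs samplers named above additionally weight each assignment by the likelihood of the observation.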

In addition, the effective sample size and inference speed of a standardised setup, e.g. HMC in truncated stick-breaking DP mixture models, for the respective PPLs will be measured.
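
For readers unfamiliar with the truncated stick-breaking construction, here is a minimal sketch of how the mixture weights of a DP can be generated at a fixed truncation level `K`. The function name `stick_breaking_weights` is hypothetical. The Beta(1, alpha) draws are obtained by inverting the CDF so that no packages beyond the standard library are assumed.

```julia
using Random

# Truncated stick-breaking weights for a DP with concentration alpha.
# v_k ~ Beta(1, alpha) is drawn by inverting its CDF, F(v) = 1 - (1 - v)^alpha.
# The last component takes the remaining stick, so the K weights sum to one.
function stick_breaking_weights(rng::AbstractRNG, alpha::Real, K::Int)
    w = zeros(K)
    remaining = 1.0                     # stick length not yet assigned
    for k in 1:(K - 1)
        v = 1 - rand(rng)^(1 / alpha)   # v ~ Beta(1, alpha)
        w[k] = v * remaining
        remaining *= 1 - v
    end
    w[K] = remaining
    return w
end

w = stick_breaking_weights(MersenneTwister(1), 1.0, 10)
```

Because the weights are continuous and the number of components is fixed, this representation is the one that gradient-based samplers such as HMC/NUTS can work with.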

What this repo contains

This repository includes (or will include) tables and other visualizations that compare the (compile and execution) speed and features of various PPLs (Turing, STAN, Pyro, Nimble, TFP) with a repository containing the minimum working examples (MWEs) for each implementation. Blog posts describing the benchmarks will also be included.
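
One simple way to separate compile time from execution time in Julia is to time the same call twice. The sketch below is an assumption about how such a measurement could be set up, not the harness used in this repository. `time_twice` and `workload` are hypothetical names, and `workload` merely stands in for a model run at sample size N.

```julia
# Time the first and second call of f(args...). In Julia, the first call of a
# method with new argument types triggers JIT compilation, so the first timing
# includes compile time, while the second mostly reflects execution time.
function time_twice(f, args...)
    first_call = @elapsed f(args...)
    second_call = @elapsed f(args...)
    return (first_call=first_call, second_call=second_call)
end

# Hypothetical workload standing in for a model run at sample size N.
workload(N) = sum(abs2, randn(N))

# Only the very first call (N = 50) pays the compilation cost here, because
# every later call reuses the method already compiled for Int arguments.
timings = [merge((N=N,), time_twice(workload, N)) for N in (50, 100, 200, 400, 800, 1600)]
```

For more stable execution timings, a package such as BenchmarkTools (already imported in one of the issues below) repeats the call many times.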

Software / Hardware

All experiments for this project were run on a c5.xlarge AWS Spot Instance. As of this writing, the specs for this instance are:

  • vCPU: 4 Intel(R) Xeon(R) Platinum 8124M CPU @ 3.00GHz
  • RAM: 8 GB
  • Storage: EBS only
  • Network Bandwidth: Up to 10 Gbps
  • EBS Bandwidth: Up to 4750 Mbps

The following software was used:

  • Julia-v1.4.1. See Project.toml and Manifest.toml for more info.

turingbnpbenchmarks's People

Contributors

dependabot[bot], luiarthur

Forkers

cameronraysmith

turingbnpbenchmarks's Issues

Can't replicate Turing CRP DPMM example

I am unsuccessful in replicating the infinite GMM example here: https://turing.ml/dev/tutorials/6-infinitemixturemodel/

I'm using the Project.toml and Manifest.toml files from: https://github.com/luiarthur/TuringBnpBenchmarks/

This is what I'm trying to run. It's almost verbatim what's in the tutorial (just replaced some Unicode):

using Turing
import Turing.RandomMeasures.DirichletProcess
import Turing.RandomMeasures.ChineseRestaurantProcess
using Distributions
using PyPlot
import Random
using BenchmarkTools
import StatsBase.countmap

# FIXME: Not working???

# Define model
@model infiniteGMM(x) = begin
    nobs = length(x)
    
    # Hyper-parameters, i.e. concentration parameter and parameters of H.
    alpha = 1.0
    mu0 = 0.0
    sig0 = 1.0
    
    # Define random measure, e.g. Dirichlet process.
    rpm = DirichletProcess(alpha)
    
    # Define the base distribution, i.e. expected value of the Dirichlet process.
    H = Normal(mu0, sig0)
    
    # Latent assignment.
    z = tzeros(Int, nobs)
        
    # Locations of the infinitely many clusters.
    mu = tzeros(Float64, 0)
    
    for i in 1:nobs
        # Number of clusters.
        K = maximum(z)
        nk = Vector{Int}(map(k -> sum(z .== k), 1:K))

        # Draw the latent assignment.
        z[i] ~ ChineseRestaurantProcess(rpm, nk)
        
        # Create a new cluster?
        if z[i] > K
            push!(mu, 0.0)

            # Draw location of new cluster.
            mu[z[i]] ~ H
        end
                
        # Draw observation.
        x[i] ~ Normal(mu[z[i]], 1.0)
    end
end

# Generate data
Random.seed!(1)
data = vcat(randn(10), randn(10) .- 5, randn(10) .+ 10)
data .-= mean(data)
data /= std(data);

# Fit model
Random.seed!(2)
iterations = 1000
model_fun = infiniteGMM(data)
chain = sample(model_fun, SMC(), iterations)

This is the error I'm getting:

BoundsError: attempt to access 0-element Array{Any,1} at index [1]

Stacktrace:
 [1] getindex(::Array{Any,1}, ::Int64) at ./array.jl:788
 [2] ess(::Chains{Union{Missing, Float64},Float64,NamedTuple{(:internals, :parameters),Tuple{Array{String,1},Array{String,1}}},NamedTuple{(),Tuple{}}}; showall::Bool, sections::Array{Symbol,1}, maxlag::Int64, digits::Int64, sorted::Bool) at /home/ubuntu/.julia/packages/MCMCChains/pVyO1/src/stats.jl:376
 [3] summarystats(::Chains{Union{Missing, Float64},Float64,NamedTuple{(:internals, :parameters),Tuple{Array{String,1},Array{String,1}}},NamedTuple{(),Tuple{}}}; append_chains::Bool, showall::Bool, sections::Array{Symbol,1}, etype::Symbol, digits::Int64, sorted::Bool, args::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/ubuntu/.julia/packages/MCMCChains/pVyO1/src/stats.jl:441
 [4] describe(::IJulia.IJuliaStdio{Base.PipeEndpoint}, ::Chains{Union{Missing, Float64},Float64,NamedTuple{(:internals, :parameters),Tuple{Array{String,1},Array{String,1}}},NamedTuple{(),Tuple{}}}; q::Array{Float64,1}, etype::Symbol, showall::Bool, sections::Array{Symbol,1}, digits::Int64, sorted::Bool, args::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/ubuntu/.julia/packages/MCMCChains/pVyO1/src/stats.jl:203
 [5] describe(::IJulia.IJuliaStdio{Base.PipeEndpoint}, ::Chains{Union{Missing, Float64},Float64,NamedTuple{(:internals, :parameters),Tuple{Array{String,1},Array{String,1}}},NamedTuple{(),Tuple{}}}) at /home/ubuntu/.julia/packages/MCMCChains/pVyO1/src/stats.jl:203
 [6] describe(::Chains{Union{Missing, Float64},Float64,NamedTuple{(:internals, :parameters),Tuple{Array{String,1},Array{String,1}}},NamedTuple{(),Tuple{}}}; args::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/ubuntu/.julia/packages/MCMCChains/pVyO1/src/stats.jl:174
 [7] describe at /home/ubuntu/.julia/packages/MCMCChains/pVyO1/src/stats.jl:174 [inlined]
 [8] show(::IOContext{Base.GenericIOBuffer{Array{UInt8,1}}}, ::Chains{Union{Missing, Float64},Float64,NamedTuple{(:internals, :parameters),Tuple{Array{String,1},Array{String,1}}},NamedTuple{(),Tuple{}}}) at /home/ubuntu/.julia/packages/MCMCChains/pVyO1/src/chains.jl:283
 [9] show at ./multimedia.jl:47 [inlined]
 [10] limitstringmime(::MIME{Symbol("text/plain")}, ::Chains{Union{Missing, Float64},Float64,NamedTuple{(:internals, :parameters),Tuple{Array{String,1},Array{String,1}}},NamedTuple{(),Tuple{}}}) at /home/ubuntu/.julia/packages/IJulia/DrVMH/src/inline.jl:43
 [11] display_mimestring(::MIME{Symbol("text/plain")}, ::Chains{Union{Missing, Float64},Float64,NamedTuple{(:internals, :parameters),Tuple{Array{String,1},Array{String,1}}},NamedTuple{(),Tuple{}}}) at /home/ubuntu/.julia/packages/IJulia/DrVMH/src/display.jl:67
 [12] display_dict(::Chains{Union{Missing, Float64},Float64,NamedTuple{(:internals, :parameters),Tuple{Array{String,1},Array{String,1}}},NamedTuple{(),Tuple{}}}) at /home/ubuntu/.julia/packages/IJulia/DrVMH/src/display.jl:96
 [13] #invokelatest#1 at ./essentials.jl:712 [inlined]
 [14] invokelatest at ./essentials.jl:711 [inlined]
 [15] execute_request(::ZMQ.Socket, ::IJulia.Msg) at /home/ubuntu/.julia/packages/IJulia/DrVMH/src/execute_request.jl:112
 [16] #invokelatest#1 at ./essentials.jl:712 [inlined]
 [17] invokelatest at ./essentials.jl:711 [inlined]
 [18] eventloop(::ZMQ.Socket) at /home/ubuntu/.julia/packages/IJulia/DrVMH/src/eventloop.jl:8
 [19] (::IJulia.var"#15#18")() at ./task.jl:358

Create summary table of all PYP benchmarks

The PYP model benchmarks for a given sample size can be summarized in the table below, with the cells being compile times (if applicable) and execution times. Timings for forward passes and gradient computations of the models will also be computed if a high-level API is available and exposed. Accompanying data visualizations (line plots, surface plots) will be included to make the results more interpretable. Links to MWEs will be included.

  • Make a dynamic table which changes N. (e.g. N=50, 100, 200, 400, 800, 1600).
  • Make interactive line plots where
    • x-axis is N, and y-axis is execution time or compile time
    • be able to select which implementations / PPL to show on the plot

Model (PPL), sample size = N | Collapsed Gibbs | SMC | Particle Gibbs | HMC/NUTS (SB) | ADVI (SB)
PYP mixture of Gaussians (Turing) | TODO | TODO | TODO | TODO | TODO
PYP mixture of non-Gaussians (Turing) | TODO | TODO | TODO | TODO | TODO
PYP mixture of Gaussians (STAN) | TODO | TODO | TODO | TODO | TODO
PYP mixture of non-Gaussians (STAN) | TODO | TODO | TODO | TODO | TODO
PYP mixture of Gaussians (Nimble) | TODO | TODO | TODO | TODO | TODO
PYP mixture of non-Gaussians (Nimble) | TODO | TODO | TODO | TODO | TODO
PYP mixture of Gaussians (Pyro) | TODO | TODO | TODO | TODO | TODO
PYP mixture of non-Gaussians (Pyro) | TODO | TODO | TODO | TODO | TODO
PYP mixture of Gaussians (TFP) | NA | NA | NA | TODO | TODO
PYP mixture of non-Gaussians (TFP) | NA | NA | NA | TODO | TODO

Notes:

  • Prefer doing the experiments by column as opposed to by row in this table. This way, if we run out of time, we have benchmarks across PPLs for a few implementations, and not benchmarks across a few implementations for one or two PPLs.
  • For now, fix the truncation level for the number of stick-breaking components. Time permitting, see if a benchmark can be done to include varying truncation levels.

Error in gradient computation during StatsFuns benchmarks

I'm benchmarking logsumexp and normlogpdf from StatsFuns, but I am running into errors when computing the gradient of normlogpdf. Here's some code for reproducing the error.

Here's the environment.

# Project.toml (Julia v1.4.1)

[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

This code snippet shows that somehow I cannot compute the gradients of normlogpdf with respect to the location parameter.

using StatsFuns
using Flux

# Define variables
x, loc, scale = 0.0, 2.0, 1.0

# My implementation of log density of Normal(location, scale), evaluated at x
function my_normlogpdf(loc, scale, x)
    z = (x - loc) / scale
    return -z * z * 0.5 - 0.5 * log(2 * pi * scale * scale)
end

# evaluate
my_normlogpdf(loc, scale, x)  # -2.9189385332046727
# gradient
Flux.gradient(mu -> my_normlogpdf(mu, scale, x), loc)  # (-2.0,)

# evaluate
normlogpdf(loc, scale, x)  # -2.9189385332046727 (same as above)
# gradient
Flux.gradient(mu -> normlogpdf(mu, scale, x), loc)  # error?!

This is the error being thrown.

ERROR: MethodError: no method matching Irrational{:log2π}(::Int64)
Closest candidates are:
Irrational{:log2π}(::T) where T<:Number at boot.jl:715
Irrational{:log2π}() where sym at irrationals.jl:18
Irrational{:log2π}(::Complex) where T<:Real at complex.jl:37
...
Stacktrace:
[1] convert(::Type{Irrational{:log2π}}, ::Int64) at ./number.jl:7
[2] one(::Type{Irrational{:log2π}}) at ./number.jl:276
[3] one(::Irrational{:log2π}) at ./number.jl:277
[4] (::Zygote.var"#603#604"{Float64,Irrational{:log2π}})(::Float64) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/lib/number.jl:29
[5] (::Zygote.var"#1590#back#605"{Zygote.var"#603#604"{Float64,Irrational{:log2π}}})(::Float64) at /home/ubuntu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[6] normlogpdf at /home/ubuntu/.julia/packages/StatsFuns/CXyCV/src/distrs/norm.jl:29 [inlined]
[7] (::typeof((normlogpdf)))(::Float64) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[8] normlogpdf at /home/ubuntu/.julia/packages/StatsFuns/CXyCV/src/distrs/norm.jl:41 [inlined]
[9] (::typeof((normlogpdf)))(::Float64) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[10] #1754 at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/lib/broadcast.jl:142 [inlined]
[11] #3 at ./generator.jl:36 [inlined]
[12] iterate at ./generator.jl:47 [inlined]
[13] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{typeof((normlogpdf)),2},Array{Float64,2}}},Base.var"#3#4"{Zygote.var"#1754#1761"}}) at ./array.jl:665
[14] map at ./abstractarray.jl:2154 [inlined]
[15] (::Zygote.var"#1753#1760"{Tuple{Array{Float64,2},Array{Float64,2},Array{Float64,2}},Val{4},Array{typeof((normlogpdf)),2}})(::Array{Float64,2}) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/lib/broadcast.jl:142
[16] #4425#back at /home/ubuntu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[17] (::Zygote.var"#174#175"{Zygote.var"#4425#back#1764"{Zygote.var"#1753#1760"{Tuple{Array{Float64,2},Array{Float64,2},Array{Float64,2}},Val{4},Array{typeof((normlogpdf)),2}}},Tuple{NTuple{4,Nothing},Tuple{Nothing}}})(::Array{Float64,2}) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:182
[18] #347#back at /home/ubuntu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[19] broadcasted at ./broadcast.jl:1238 [inlined]
[20] lpdf_gmm_sf at /home/ubuntu/repo/TuringBnpBenchmarks/dev/Benchmark_BnpUtil/benchmark_methods.jl:34 [inlined]
[21] (::typeof((lpdf_gmm_sf)))(::Float64) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[22] #46 at ./REPL[30]:2 [inlined]
[23] (::typeof((#46)))(::Float64) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[24] (::Zygote.var"#49#50"{Zygote.Params,Zygote.Context,typeof((#46))})(::Float64) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:179
[25] gradient(::Function, ::Zygote.Params) at /home/ubuntu/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:55
[26] top-level scope at REPL[30]:1

A little confused because errors aren't thrown when I use normlogpdf in a Turing model with an AD-based inference algorithm.
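
For reference, the gradient expected from normlogpdf with respect to the location can be checked without any AD package. The sketch below uses only Base Julia. The names `normal_logpdf`, `dlogpdf_dloc`, and `central_diff` are hypothetical helpers for this check and are not part of StatsFuns, Flux, or Zygote.

```julia
# Log density of Normal(loc, scale) evaluated at x (same formula as above).
function normal_logpdf(loc, scale, x)
    z = (x - loc) / scale
    return -z * z * 0.5 - 0.5 * log(2 * pi * scale * scale)
end

# Analytic derivative with respect to loc: (x - loc) / scale^2.
dlogpdf_dloc(loc, scale, x) = (x - loc) / scale^2

# Central finite difference as an AD-free cross-check.
function central_diff(f, t; h=1e-6)
    return (f(t + h) - f(t - h)) / (2h)
end

x, loc, scale = 0.0, 2.0, 1.0
analytic = dlogpdf_dloc(loc, scale, x)
numeric = central_diff(mu -> normal_logpdf(mu, scale, x), loc)
```

With the values used in this issue, both quantities should agree closely, which gives a target to compare against once the AD call succeeds.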

Create feature comparisons table

A features table will also be compiled for the various PPLs. These may include:

  • support / workarounds for missing data, with MWE.

  • support / workarounds for ragged arrays, with MWE.

  • support / workarounds for inference of discrete parameters, with MWE

    • Nimble and Pyro support direct inference on discrete parameters. The recommended workaround in other PPLs is marginalizing over discrete parameters, but this is not always possible.
  • support for automatic differentiation

  • Customizability

    • e.g. For MCMC, using a custom (user-provided) implementation to update a subset of model parameters, and use default update mechanisms (Metropolis-within-Gibbs or HMC) for the other parameters.
  • support for HMC, Metropolis-within-Gibbs, ADVI / BBVI, and auto-tuning for each PPL.

  • The table below is an example of what the Feature Comparisons table could look like.

Feature | Turing | STAN | Pyro | Nimble | TFP
Supports inference for discrete parameters | TODO | No (workaround) | Yes | Yes | No (workaround)
Supports missing data | TODO | TODO | TODO | Yes | No (workaround)
Supports AD | TODO | TODO | TODO | TODO | TODO
Supports customization of MCMC | TODO | TODO | TODO | TODO | TODO
Supports HMC | TODO | TODO | TODO | TODO | TODO
Supports NUTS | TODO | TODO | TODO | TODO | TODO
Supports ADVI | TODO | TODO | TODO | TODO | TODO
etc. | TODO | TODO | TODO | TODO | TODO

Create summary table of all GP benchmarks

The GP model benchmarks for a given sample size can be summarized in the table below, with the cells being compile times (if applicable) and execution times. Timings for forward passes and gradient computations of the models will also be computed if a high-level API is available and exposed. Accompanying data visualizations (line plots, surface plots) will be included to make the results more interpretable. Links to MWEs will be included.

  • Make a dynamic table which changes N. (e.g. N=50, 100, 200, 400, 800).
  • Make interactive line plots where
    • x-axis is N, and y-axis is execution time or compile time
    • be able to select which implementations / PPL to show on the plot
  • Demo how GPs can be integrated & used as a building block within a probabilistic model.

Model (PPL), sample size = N | Collapsed Gibbs | SMC | Particle Gibbs | HMC/NUTS (SB) | ADVI (SB)
GP (Turing) | TODO | TODO | TODO | TODO | TODO
GP (STAN) | TODO | TODO | TODO | TODO | TODO
GP (Nimble) | TODO | TODO | TODO | TODO | TODO
GP (Pyro) | TODO | TODO | TODO | TODO | TODO
GP (TFP) | NA | NA | NA | TODO | TODO
LVGP (Turing) | TODO | TODO | TODO | TODO | TODO
LVGP (STAN) | TODO | TODO | TODO | TODO | TODO
LVGP (Nimble) | TODO | TODO | TODO | TODO | TODO
LVGP (Pyro) | TODO | TODO | TODO | TODO | TODO
LVGP (TFP) | NA | NA | NA | TODO | TODO

Notes:

  • Prefer doing the experiments by column as opposed to by row in this table. This way, if we run out of time, we have benchmarks across PPLs for a few implementations, and not benchmarks across a few implementations for one or two PPLs.
  • LVGP refers to latent variable GP. GP refers to the vanilla (full-rank) Gaussian process.

Create summary table of all DP benchmarks

The DP model benchmarks for a given sample size can be summarized in the table below, with the cells being compile times (if applicable) and execution times. Timings for forward passes and gradient computations of the models will also be computed if a high-level API is available and exposed. Accompanying data visualizations (line plots, surface plots) will be included to make the results more interpretable. Links to MWEs will be included.

  • Make a dynamic table which changes N. (e.g. N=50, 100, 200, 400, 800, 1600).
  • Make interactive line plots where
    • x-axis is N, and y-axis is execution time or compile time
    • be able to select which implementations / PPL to show on the plot

Model (PPL), sample size = N | Collapsed Gibbs | SMC | Particle Gibbs | HMC/NUTS (SB) | ADVI (SB)
DP mixture of Gaussians (Turing) | TODO | TODO | TODO | TODO | TODO
DP mixture of Gaussians (STAN) | NA | NA | NA | TODO | TODO
DP mixture of Gaussians (Nimble) | TODO | TODO | TODO | NA | NA
DP mixture of Gaussians (Pyro) | NA | TODO | NA | TODO | TODO
DP mixture of Gaussians (TFP) | NA | NA | NA | TODO | TODO
DP mixture of non-Gaussians (Turing) | TODO | TODO | TODO | TODO | TODO
DP mixture of non-Gaussians (STAN) | NA | NA | NA | TODO | TODO
DP mixture of non-Gaussians (Nimble) | TODO | TODO | TODO | TODO | TODO
DP mixture of non-Gaussians (Pyro) | NA | TODO | TODO | TODO | TODO
DP mixture of non-Gaussians (TFP) | NA | NA | NA | TODO | TODO

Notes:

  • Prefer doing the experiments by column as opposed to by row in this table. This way, if we run out of time, we have benchmarks across PPLs for a few implementations, and not benchmarks across a few implementations for one or two PPLs.
  • For now, fix the truncation level for the number of stick-breaking components. Time permitting, see if a benchmark can be done to include varying truncation levels.

Poor inference for CRP DP GMM via SMC, PG, IS

I simulated some data and fit the following model to it, but I get poor inference, i.e. the results don't really match the simulation truth. In particular, my data has 4 almost-equally-sized clusters, and while I tend to learn 4 to 5 clusters, there are usually only 1 or 2 big (dominating) clusters in the posterior inference.

So, my first question is: am I abusing the API? In my DP Gaussian mixture of locations and scales, I use a base measure (H) which is Normal x InverseGamma (two independent distributions) for the location (mu) and scale (sigma). Am I doing this correctly? (It runs, but I suspect I'm doing something outside the intended use.)

My second question is, if the model is implemented correctly, what might be the cause of the poor inference? Admittedly, I'm not familiar with SMC/PG. But does increasing the number of particles generally lead to better inference?

# DP GMM model under CRP construction
@model dp_gmm_crp(y) = begin
    nobs = length(y)
    
    alpha ~ Gamma(1, 0.1)  # mean = a*b
    rpm = DirichletProcess(alpha)
    
    # Base measure.
    H = arraydist([Normal(0, 3), InverseGamma(2, 0.05)])  # is this OK?
    
    # Latent assignment.
    z = tzeros(Int, nobs)
    
    # Locations and scales of infinitely many clusters.
    mu_sigma = TArray(Vector{Float64}, 0)  # is this OK?
    
    for i in 1:nobs
        # Number of clusters.
        K = maximum(z)
        n = Vector{Int}([sum(z .== k) for k in 1:K])
        
        # Sample cluster label.
        z[i] ~ ChineseRestaurantProcess(rpm,  n)
        
        # Create a new cluster.
        if z[i] > K
            push!(mu_sigma, [0.0, 0.1])  # is this OK?
            mu_sigma[z[i]] ~ H  # is this OK?
        end
        
        # Sampling distribution.
        mu, sigma = mu_sigma[z[i]]  # is this OK?
        y[i] ~ Normal(mu, sigma)
    end
end
;
# Set random seed for reproducibility
Random.seed!(0);

# Sample from posterior
@time chain = begin
    burn = 2000  # NOTE: The burn in is also returned. Discard manually.
    n_samples = 1000
    iterations = burn + n_samples

    sample(dp_gmm_crp(y), SMC(), iterations)
    # sample(dp_gmm_crp(y), IS(), iterations)
    # sample(dp_gmm_crp(y), Gibbs(PG(5, :z), PG(5, :mu_sigma)), iterations)
end;

Here is the complete notebook.
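
One quantity that may be worth checking alongside the model above is how many clusters the CRP prior itself expects for a given concentration. The sketch below evaluates the standard formula for the prior expected number of occupied clusters. The function name `expected_clusters` and the sample size of 200 are hypothetical, since the issue does not state the size of the data set.

```julia
# Prior expected number of occupied clusters among n observations under a
# CRP with concentration alpha: the sum over i of alpha / (alpha + i - 1).
expected_clusters(alpha, n) = sum(alpha / (alpha + i - 1) for i in 1:n)

# Gamma(1, 0.1) in the model above has mean 0.1; compare a few values.
for alpha in (0.1, 1.0, 5.0)
    println("alpha = ", alpha, ", n = 200: E[K] = ", expected_clusters(alpha, 200))
end
```

This describes the prior only, so it does not by itself explain the posterior behaviour reported here, but it shows how strongly a small concentration favours few clusters.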
