cscherrer commented on June 2, 2024

Thanks for letting me know about this. The stack trace from `xform` includes this line:

```
 [2] xform(d::Distributions.Dirichlet{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}, _data::NamedTuple{(), Tuple{}})
   @ Soss ~/git/Soss.jl/src/primitives/xform.jl:79
```

Following that takes you to

```julia
function xform(d, _data::NamedTuple)
    if hasmethod(support, (typeof(d),))
        return asTransform(support(d))
    end

    error("Not implemented:\nxform($d)")
end
```

The problem here is that `Distributions.Dirichlet` has no `support` method, so it falls through and throws the error. So you're right that the fix is to add this method.
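The fallback logic above can be reproduced in plain Julia to see why adding a method is enough. In this sketch, `HasSupport`, `NoSupport`, `support` and `toy_xform` are hypothetical stand-ins, not Soss or Distributions code: the generic method only succeeds when `hasmethod` finds a `support` method for the argument type, and a method written for the concrete type takes precedence by ordinary multiple dispatch.

```julia
# Hypothetical stand-ins; none of these names come from Soss or Distributions.
struct HasSupport end   # like a distribution that defines `support`
struct NoSupport end    # like Distributions.Dirichlet, which does not

support(::HasSupport) = :interval

# Generic fallback with the same shape as Soss.xform(d, _data::NamedTuple)
function toy_xform(d, _data::NamedTuple)
    if hasmethod(support, Tuple{typeof(d)})
        return support(d)
    end
    error("Not implemented:\ntoy_xform($d)")
end

# Without a more specific method, NoSupport falls through to the error.
fell_through = try
    toy_xform(NoSupport(), NamedTuple())
    false
catch
    true
end

# A method for the concrete type is more specific than the fallback,
# so dispatch picks it instead (the analogue of the Dirichlet fix).
toy_xform(::NoSupport, _data::NamedTuple) = :simplex
```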

`xform` is kind of a legacy name. It was originally going to be `transform`, but that name was already taken in TransformVariables.jl. Since building it, though, I've realized it should have just been called `as`, since it has the same functionality as that function from TransformVariables. Docs on that are here:
https://tamaspapp.eu/TransformVariables.jl/dev/#The-as-constructor-and-aggregations

Also, in the current setup, the _data argument is only used when you have nested models. I'll be cleaning up the dispatch patterns for this, but for a quick fix let's just go with it.

Anyway, the missing method is

```julia
Soss.xform(d::Dists.Dirichlet, _data::NamedTuple) = TransformVariables.UnitSimplex(length(d.alpha))
```
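`TransformVariables.UnitSimplex(n)` maps n - 1 unconstrained reals onto the n-dimensional probability simplex, which is why the method only needs `length(d.alpha)`. One common way to build such a map is logistic stick-breaking; the sketch below assumes that construction purely for illustration and is not TransformVariables' actual implementation (`stick_breaking` and `logistic` are hypothetical helpers).

```julia
# Hypothetical helper: the standard logistic function.
logistic(t) = 1 / (1 + exp(-t))

# Hypothetical helper: map y ∈ ℝ^(K-1) to w on the K-simplex
# (w[k] >= 0, sum(w) == 1) by logistic stick-breaking.
function stick_breaking(y::AbstractVector{<:Real})
    K = length(y) + 1
    T = float(eltype(y))
    w = Vector{T}(undef, K)
    remaining = one(T)               # length of stick not yet broken off
    for k in 1:K-1
        # The -log(K - k) offset sends y = 0 to the uniform point (1/K, ..., 1/K).
        z = logistic(y[k] - log(K - k))
        w[k] = remaining * z
        remaining -= w[k]
    end
    w[K] = remaining                 # last weight is whatever stick is left
    return w
end
```

Because the last weight is determined by the others, only K - 1 free coordinates are needed, matching the dimension of the unconstrained space HMC works in.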

But this doesn't fix everything, because you still have

```julia
z ~ For(N) do _
    Distributions.Categorical(w)
end
```

This is discrete, so there's no way to set up a bijection to the reals. This is not Soss-specific, you'd have the same issue in Turing or Stan with HMC.

We'll be adding ways to make this easier in MeasureTheory, but for now in Distributions, say you have

```julia
julia> μ = rand(Normal() |> iid(3))
3-element Vector{Float64}:
  1.0510484386874308
 -0.8007745046155319
  0.48629964893183536
```

Then you can do `using FillArrays, MappedArrays` and then

```julia
julia> paramvec = mappedarray(μ) do μj
           (Fill(μj, 2), 1.)
       end
3-element mappedarray(var"#7#8"(), ::Vector{Float64}) with eltype Tuple{Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}:
 (Fill(1.0510484386874308, 2), 1.0)
 (Fill(-0.8007745046155319, 2), 1.0)
 (Fill(0.48629964893183536, 2), 1.0)
```

These are the mixture components, which you can combine like

```julia
julia> Dists.MixtureModel(Dists.MvNormal, paramvec)
MixtureModel{Distributions.MvNormal}(K = 3)
components[1] (prior = 0.3333): Distributions.MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(1.0510484386874308, 2)
Σ: [1.0 0.0; 0.0 1.0]
)

components[2] (prior = 0.3333): Distributions.MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(-0.8007745046155319, 2)
Σ: [1.0 0.0; 0.0 1.0]
)

components[3] (prior = 0.3333): Distributions.MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(0.48629964893183536, 2)
Σ: [1.0 0.0; 0.0 1.0]
)
```

This uses equal weights, but you can change that by adding another parameter.
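In Distributions the weights go in as a third argument, `Dists.MixtureModel(Dists.MvNormal, paramvec, w)`, as in the later variant of the model in this thread. For generating data, a weighted mixture amounts to drawing a component index with probabilities `w` and then drawing from that component, independently for every observation. The stdlib-only sketch below does this for 2-D components like the ones above (mean `(μ_k, μ_k)`, identity covariance); `draw_component` and `sample_mixture` are hypothetical helpers, not Distributions API.

```julia
using Random

# Hypothetical helper: sample k with probability w[k] / sum(w).
function draw_component(rng::AbstractRNG, w::AbstractVector{<:Real})
    u = rand(rng) * sum(w)     # u is in [0, sum(w))
    c = zero(u)
    for k in eachindex(w)
        c += w[k]
        u < c && return k
    end
    return lastindex(w)        # guard against floating-point round-off
end

# Hypothetical helper: N independent draws from the mixture
# sum_k w[k] * MvNormal((μ[k], μ[k]), I), one observation per column.
function sample_mixture(rng::AbstractRNG, μ::AbstractVector{<:Real},
                        w::AbstractVector{<:Real}, N::Integer)
    X = Matrix{Float64}(undef, 2, N)
    for n in 1:N
        k = draw_component(rng, w)          # fresh component index per column
        X[:, n] = μ[k] .+ randn(rng, 2)     # fresh noise per column
    end
    return X
end

X = sample_mixture(MersenneTwister(1), [-3.5, 0.0], [0.3, 0.7], 100)
```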

Anyway, this works but it's not pretty. We're working on making it much easier to stay in MeasureTheory for all of this.

```julia
using MeasureTheory
using Soss
using SampleChainsDynamicHMC
import Distributions
const Dists = Distributions   # the code below refers to Distributions as `Dists`
using FillArrays
using LinearAlgebra

m = @model N begin
    σ0 ~ Lebesgue(ℝ)
    μ0 ~ Lebesgue(ℝ)
    α ~ Lebesgue(ℝ₊)
    K = 2
    μ ~ Normal(μ0, σ0) |> iid(K)
    w ~ Distributions.Dirichlet(K, abs(α))
    xdist = Dists.MixtureModel(Dists.Normal, μ, w)
    x ~ Dists.MatrixReshaped(Dists.Product(Fill(xdist, K*N)), K, N)
end

using TransformVariables
const TV = TransformVariables
Soss.xform(d::Dists.Dirichlet, _data::NamedTuple) = TV.UnitSimplex(length(d.alpha))

prior_data = predict(m(N=30), (N=30, σ0=1., μ0=0., α=1.))

# data generation with assumption on μ and w
predx = predictive(m, , :w)
data = predict(m(N=30), (μ=[-3.5, 0.0], w=[0.5, 0.5]))

# estimating the posterior
posterior = m(N=30)|(x=data.x,)

sample(posterior, dynamichmc())
```


gzagatti commented on June 2, 2024

@cscherrer many thanks for the detailed answer.

I made some further explorations on my own and had a few issues:

  1. I went through the transform documentation. I do not quite understand the implementation. For instance with `asℝ`, I do not quite get why they are using the exponential distribution.

  2. For some reason `Fill` inside of the proposed model tends to replicate a single draw from the mixed distribution `K*N` times rather than performing `K*N` draws. When I plotted the samples from `prior_data` I basically got two points. If I switch to base `fill` the model works as expected. The problem with the fixed model is that it does not ensure that both components are the same for each row.

  3. I attempted a different variation of the model as follows. Since the `Distributions.jl` package does not have a definition for the `Product` of a multivariate distribution, I implemented a basic version to get the work done:

    ```julia
    struct Product{
            S<:Distributions.ValueSupport,
            T<:Distributions.MultivariateDistribution{S},
            V<:AbstractVector{T},
        } <: Distributions.MultivariateDistribution{S}
        v::V
        function Product(v::V) where
                V<:AbstractVector{T} where
                T<:Distributions.MultivariateDistribution{S} where
                S<:Distributions.ValueSupport
            return new{S, T, V}(v)
        end
    end

    Base.length(d::Product) = length(d.v)
    function Base.eltype(::Type{<:Product{S, T}}) where {S<:Distributions.ValueSupport, T<:Distributions.MultivariateDistribution{S}}
        return eltype(T)
    end

    function _rand!(rng::Distributions.AbstractRNG, d::Product, x::AbstractVector)
        broadcast!(dn->rand(rng, dn), x, d.v)
    end

    function _logpdf(d::Product, x::AbstractVector)
        sum(n -> logpdf(d.v[n], x[n]), 1:length(d))
    end

    function Distributions.rand(rng::Distributions.AbstractRNG, s::Product)
        _rand!(rng, s, Vector{Vector{eltype(s)}}(undef, length(s)))
    end
    ```

    I then redefined the model as follows:

    ```julia
    components = mappedarray(μ) do μk
        (Fill(μk, 2), 1.)
    end
    mixture = Dists.MixtureModel(Dists.MvNormal, components, w)
    x ~ Product(fill(mixture, N))
    ```

    Unfortunately, I still have problems with running the chain. It complains the method _logpdf is not implemented even though I did implement it.

  4. I could not use the function `sample`. Issue #293 was raised with a similar problem.

  5. I am not quite sure what you meant by:

    > This is discrete, so there's no way to set up a bijection to the reals. This is not Soss-specific, you'd have the same issue in Turing or Stan with HMC.

Do you have plans to add additional examples to the docs sometime soon? I know the project is developing quite fast at the moment. Please do let me know if there is a need for help. I am currently learning about these models and it would be good practice to write a few examples.


cscherrer commented on June 2, 2024

> 1. I went through the transform documentation. I do not quite understand the implementation. For instance with `asℝ`, I do not quite get why they are using the exponential distribution.

Sorry, I don't understand. Can you point me to a line?
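Assuming item 1 is about `asℝ₊` (the transform to the positive reals), the exponential involved is most likely the function `exp`, which maps y ∈ ℝ to x = exp(y) > 0, rather than the exponential distribution. Working on the unconstrained scale then requires adding the log-Jacobian log|dx/dy| = y to the log density. The sketch below checks this numerically with an Exponential(1) target chosen only for illustration; `logp_x` and `logp_y` are hypothetical names.

```julia
# Log density of x ~ Exponential(1) on its constrained scale, x > 0.
logp_x(x) = x > 0 ? -x : -Inf

# Unconstrained parameterization: y ∈ ℝ, x = exp(y).
# The change of variables adds log|dx/dy| = log(exp(y)) = y.
logp_y(y) = logp_x(exp(y)) + y

# Sanity check: the transformed density should still integrate to 1 over ℝ
# (Riemann sum on a grid wide enough that both tails are negligible).
h = 1e-3
ys = -20:h:5
total = sum(exp(logp_y(y)) for y in ys) * h
```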

> 2. for some reason `Fill` inside of the proposed model tends to replicate a single draw from the mixed distribution `K*N` rather than performing `K*N` draws. When I plotted the samples from `prior_data` I basically got two points. If I switch to base `fill` the model works as expected. The problem with the fixed model is that it does not ensure that both components are the same for each row.

Ah ok, I think the issue here is that there's no way to tell FillArrays that sampling is nondeterministic. Ideally Distributions would account for this, but it seems they don't. I guess this is the point of `filldist` in DistributionsAD. And come to think of it, you probably need that anyway, since it will make gradients much more efficient for Distributions.

In MeasureTheory we'll have all of this built in. Currently we don't have any custom AD, but the implementations are also much simpler, so AD should have an easier time of it. We'll be adding more optimized methods as we go.
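The failure mode described in item 2 can be reproduced without any packages. `LazyFill`, `lazymap`, `to_vector` and `noisy` below are hypothetical, written in the spirit of a lazy fill array that stores a single value: if `map` over such an array assumes the function is pure, it evaluates it once and reuses the result, which is fine for `sqrt` but wrong for a random draw. This is an analogy for the reported behavior, not FillArrays' actual code.

```julia
using Random

# Hypothetical lazy "fill": one stored value repeated n times.
struct LazyFill{T}
    value::T
    n::Int
end

# A map that assumes f is pure: call f once, repeat the result n times.
lazymap(f, x::LazyFill) = LazyFill(f(x.value), x.n)

# Materialize as an ordinary Vector.
to_vector(x::LazyFill) = fill(x.value, x.n)

rng = MersenneTwister(42)
noisy(μ) = μ + randn(rng)   # not pure: a different result on every call

# One draw, copied four times.
lazy_draws = to_vector(lazymap(noisy, LazyFill(0.0, 4)))

# Four independent draws, one call per element.
eager_draws = [noisy(0.0) for _ in 1:4]
```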

> 3. I attempted a different variation of the model as following. Since the `Distributions.jl` package does not have a definition for the [`Product` of a multivariate](https://github.com/JuliaStats/Distributions.jl/blob/59df675409a7e2490e4a45edd32c0267df435c55/src/multivariate/product.jl), I implemented a basic version to get the work done:
>
> Unfortunately, I still have problems with running the chain. It complains the method `_logpdf` is not implemented even though I did implement it.

Thanks for letting me know about this, I'll have a look and see if I can work it out.

In general, I think there are a lot of fundamental problems with Distributions, especially when it comes to PPL. Making this better is a lot of the motivation behind MeasureTheory. It's not yet a full workaround, but most of my energy this year has been directed toward this.
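One plausible cause of the `_logpdf` error, offered as an assumption rather than a confirmed diagnosis: the methods in item 3 define new functions `_logpdf` and `_rand!` in the user's own module, whereas Distributions' `logpdf` and `rand` dispatch to its internal `Distributions._logpdf` and `Distributions._rand!`. Methods only reach those functions if they are added under the qualified name, e.g. `Distributions._logpdf(d::Product, x::AbstractVector) = ...`. The stdlib-only sketch below reproduces the pattern with a hypothetical module `MiniDists`.

```julia
# Hypothetical stand-in for a package whose public function forwards
# to an internal hook, the way Distributions.logpdf forwards to _logpdf.
module MiniDists
    abstract type Dist end
    logpdf(d::Dist, x) = _logpdf(d, x)
    function _logpdf end   # no methods yet: unknown types raise a MethodError
end

struct Shadowed <: MiniDists.Dist end
struct Extended <: MiniDists.Dist end

# Creates a *new* function Main._logpdf; MiniDists.logpdf never sees it.
_logpdf(::Shadowed, x) = 0.0

# Adds a method to MiniDists._logpdf, so MiniDists.logpdf dispatches to it.
MiniDists._logpdf(::Extended, x) = 0.0
```

Separately, Distributions expects a sample from a multivariate distribution to be a vector of real numbers, while this `Product` produces a vector of vectors, so further changes may be needed even once the methods are attached to the right functions.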

> 4. I could not use the function `sample`.  Issue [sample(...) does not work #293](https://github.com/cscherrer/Soss.jl/issues/293) was raised with a similar problem.

Thanks for letting me know about this. What error are you getting? I'll need to be able to reproduce the problem before I can make progress on it.

> 5. I am not quite sure what you meant by:
> > This is discrete, so there's no way to set up a bijection to the reals. This is not Soss-specific, you'd have the same issue in Turing or Stan with HMC.

Hamiltonian Monte Carlo (HMC) was popularized by the Stan language. It's a great way to do inference, but it only works when the sample space is unconstrained Euclidean space.

The standard way to work around this is to marginalize out the discrete parameters and set up a bijection between the remaining (continuous) sample space and ℝⁿ.
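For a Gaussian mixture, marginalizing the discrete assignments z_i means replacing each z_i by a sum over components: log p(x | μ, w) = Σ_i log Σ_k w_k N(x_i | μ_k, σ). The result depends only on continuous parameters, so HMC can work with it once `w` is mapped through a simplex transform. Below is a stdlib-only sketch for 1-D data with a shared, known σ; `logsumexp`, `normal_logpdf` and `mixture_loglik` are hypothetical helpers, and computing the inner sum in log space avoids underflow.

```julia
# Numerically stable log(sum(exp.(v))).
function logsumexp(v::AbstractVector{<:Real})
    m = maximum(v)
    isfinite(m) || return m
    return m + log(sum(exp.(v .- m)))
end

# Log density of Normal(μ, σ) at x.
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# log p(x | μ, w) with each discrete assignment z_i summed out:
# sum over i of log(sum over k of w[k] * Normal(x[i] | μ[k], σ)).
function mixture_loglik(x, μ, w; σ = 1.0)
    sum(x) do xi
        logsumexp([log(w[k]) + normal_logpdf(xi, μ[k], σ) for k in eachindex(μ, w)])
    end
end
```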

> Do you have plans to add additional examples to the docs sometime soon? I know the project is developing quite fast at the moment. Please do let me know if there is a need for help. I am currently learning about these models and it would be good practice to write a few examples.

This would be great!! Yes, we definitely need documentation, examples, tutorials, etc. The only limitation here is that I'm stretched in a few different directions, so it's hard to get everything done at once.


gzagatti commented on June 2, 2024

Thanks for the help again.

I have been studying the topic in more detail and have developed a better understanding of transforms. TransformVariables.jl seems to be similar to Bijectors.jl.

I have created a gist with the Gaussian mixture example using the different options I have played around with. The gist is a Pluto notebook, so you should be able to replicate it with the exact environment I am using. The version that is not commented out runs without any issues except for the last command, which calls `sample`. It complains that this function is not defined.

As I get more familiar with PPLs, I will try to write some examples, add them to the documentation and open a PR with them.
