Implementation of Gaussian mixture model fails when sampling the posterior #296
Thanks for letting me know about this. The relevant frame of the stack trace is

```
[2] xform(d::Distributions.Dirichlet{Float64, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}, _data::NamedTuple{(), Tuple{}})
  @ Soss ~/git/Soss.jl/src/primitives/xform.jl:79
```

Following that takes you to

```julia
function xform(d, _data::NamedTuple)
    if hasmethod(support, (typeof(d),))
        return asTransform(support(d))
    end
    error("Not implemented:\nxform($d)")
end
```

The problem here is that there's no `support` method for `Dirichlet`, so dispatch falls through to the `error` call.
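A quick way to see why dispatch falls through for `Dirichlet` (a hypothetical REPL check, assuming `support` here resolves to `Distributions.support`, which only has methods for univariate distributions):

```julia
julia> using Distributions

julia> hasmethod(support, (typeof(Dirichlet(2, 1.0)),))
false

julia> hasmethod(support, (typeof(Normal()),))
true
```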
Anyway, the missing method is

```julia
Soss.xform(d::Dists.Dirichlet, _data::NamedTuple) = TransformVariables.UnitSimplex(length(d.alpha))
```

But this doesn't fix everything, because you still have

```julia
z ~ For(N) do _
    Distributions.Categorical(w)
end
```

This is discrete, so there's no way to set up a bijection to the reals. This is not Soss-specific; you'd have the same issue in Turing or Stan with HMC.

We'll be adding ways to make this easier in MeasureTheory, but for now in Distributions, say you have

```julia
julia> μ = rand(Normal() |> iid(3))
3-element Vector{Float64}:
  1.0510484386874308
 -0.8007745046155319
  0.48629964893183536
```

Then you can do

```julia
julia> paramvec = mappedarray(μ) do μj
           (Fill(μj, 2), 1.0)
       end
3-element mappedarray(var"#7#8"(), ::Vector{Float64}) with eltype Tuple{Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Float64}:
 (Fill(1.0510484386874308, 2), 1.0)
 (Fill(-0.8007745046155319, 2), 1.0)
 (Fill(0.48629964893183536, 2), 1.0)
```

These are the mixture components, which you can combine like

```julia
julia> Dists.MixtureModel(Dists.MvNormal, paramvec)
MixtureModel{Distributions.MvNormal}(K = 3)
components[1] (prior = 0.3333): Distributions.MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(1.0510484386874308, 2)
Σ: [1.0 0.0; 0.0 1.0]
)
components[2] (prior = 0.3333): Distributions.MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(-0.8007745046155319, 2)
Σ: [1.0 0.0; 0.0 1.0]
)
components[3] (prior = 0.3333): Distributions.MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(0.48629964893183536, 2)
Σ: [1.0 0.0; 0.0 1.0]
)
```

This uses equal weights, but you can change that by adding another parameter (a sketch of that follows below). Anyway, this works, but it's not pretty. We're working on making it much easier to stay in MeasureTheory for all of this.
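As a sketch of that "another parameter" (hypothetical weights, using the `MixtureModel(::Type, params, p)` constructor from Distributions):

```julia
using Distributions, FillArrays
const Dists = Distributions

# three isotropic MvNormal components, as (μ, σ) parameter tuples
params = [(Fill(-1.0, 2), 1.0), (Fill(0.0, 2), 1.0), (Fill(1.0, 2), 1.0)]
w = [0.5, 0.3, 0.2]  # mixing weights; must sum to 1
Dists.MixtureModel(Dists.MvNormal, params, w)
```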
Here's the full example:

```julia
using MeasureTheory
using Soss
using SampleChainsDynamicHMC
import Distributions
using FillArrays
using LinearAlgebra
m = @model N begin
    σ0 ~ Lebesgue(ℝ)
    μ0 ~ Lebesgue(ℝ)
    α ~ Lebesgue(ℝ₊)
    K = 2
    μ ~ Normal(μ0, σ0) |> iid(K)
    w ~ Distributions.Dirichlet(K, abs(α))
    xdist = Dists.MixtureModel(Dists.Normal, μ, w)
    x ~ Dists.MatrixReshaped(Dists.Product(Fill(xdist, K*N)), K, N)
end
using TransformVariables
const TV = TransformVariables
Soss.xform(d::Dists.Dirichlet, _data::NamedTuple) = TV.UnitSimplex(length(d.alpha))
prior_data = predict(m(N=30), (N=30, σ0=1., μ0=0., α=1.))
# data generation with assumed values for μ and w
predx = predictive(m, :μ, :w)
data = predict(m(N=30), (μ=[-3.5, 0.0], w=[0.5, 0.5]))
# estimating the posterior
posterior = m(N=30)|(x=data.x,)
sample(posterior, dynamichmc())
```
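For context on the `Soss.xform` line above: `TV.UnitSimplex(K)` maps K-1 unconstrained reals to a point on the K-simplex, which is exactly the bijection HMC needs for the Dirichlet-distributed weights. A minimal sketch, assuming TransformVariables' exported `transform` and `dimension`:

```julia
using TransformVariables
const TV = TransformVariables

t = TV.UnitSimplex(3)          # transform for a 3-component weight vector
TV.dimension(t)                # 2: a point on the 3-simplex has 2 free coordinates
w = TV.transform(t, randn(2))  # length-3 vector with w .> 0 and sum(w) ≈ 1
```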
@cscherrer, many thanks for the detailed answer. I made some further explorations on my own and ran into a few issues.
Do you have plans to add more examples to the docs sometime soon? I know the project is developing quite fast at the moment. Please let me know if there is a need for help; I am currently learning about these models, and it would be good practice to write a few examples.
Sorry, I don't understand. Can you point me to a line?
Ah ok, I think the issue here is that there's no way to tell FillArrays that sampling is nondeterministic. Ideally Distributions would account for this, but it seems they don't. In MeasureTheory we'll have all of this built in. Currently we don't have any custom AD, but the implementations are also much simpler, so AD should have an easier time of it. We'll be adding more optimized methods as we go.
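If I'm reading the problem right, a minimal illustration (my guess at the pitfall, not code from the thread): `map` over a `Fill` assumes the function is pure, so it calls it once and reuses the value, which silently breaks `rand`:

```julia
using FillArrays, Distributions

map(rand, Fill(Normal(), 3))  # Fill: one draw, repeated three times
map(rand, fill(Normal(), 3))  # Vector: three independent draws
```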
Thanks for letting me know about this; I'll have a look and see if I can work it out. In general, I think there are a lot of fundamental problems with Distributions, especially when it comes to PPL. Making this better is a lot of the motivation behind MeasureTheory. It's not yet a full replacement, but most of my energy this year has been directed toward this.
Thanks for letting me know about this. What error are you getting? I'll need to be able to reproduce the problem before I can make progress on it.
Hamiltonian Monte Carlo (HMC) was popularized by the Stan language. It's a great way to do inference, but it only works when the sample space is unconstrained Euclidean space. The standard way to work around this is to marginalize out the discrete parameters and set up a bijection between the remaining sample space and ℝⁿ.
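As a concrete sketch of that marginalization for a mixture (illustrative code, assuming `logsumexp` from LogExpFunctions): summing over the component index inside the likelihood removes the discrete variable, leaving only continuous parameters for HMC:

```julia
using Distributions
using LogExpFunctions  # for logsumexp

# log p(x | μ, w) = log Σₖ wₖ · Normal(μₖ, 1)(x), with the discrete z summed out
marginal_logpdf(x, μ, w) =
    logsumexp(log(w[k]) + logpdf(Normal(μ[k], 1.0), x) for k in eachindex(w))
```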
This would be great!! Yes, we definitely need documentation, examples, tutorials, etc. The only limitation here is that I'm stretched in a few different directions, so it's hard to get everything done at once.
Thanks for the help again. I have been studying the topic in more detail and have developed a better understanding of the transforms. I have created a gist with the Gaussian mixture example using the different options I have played around with. The gist is a Pluto notebook, so you should be able to replicate it with the exact environment I am using. The version that is not commented out runs without any issues, except for the last command.
As I get more familiar with PPLs, I will try to write some examples, add them to the documentation, and open a PR with them.
I got quite excited about the Soss.jl presentation at JuliaCon 2021, so I decided to give it a go by implementing a Gaussian mixture model. I borrowed the example from Turing.jl for comparison purposes. Please let me know if there is anything that could be improved.

Unfortunately, I was not able to estimate the posterior distribution, as an error is raised when I call `tr = xform(posterior)`. Apparently, `xform` is not defined for the Dirichlet distribution.

As I understand it, not all distributions have been implemented directly in Soss.jl. If that is something within my capabilities, I could try to implement it. However, I have no clue what `xform` is doing, and I have not been able to find much documentation on it.

In any case, thanks for putting the package together.