Skip to content

Commit

Permalink
Functions to compute lppd from Gen model
Browse files Browse the repository at this point in the history
  • Loading branch information
hsm207 committed May 2, 2020
1 parent 97cc627 commit 80f31b3
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ using Gen, RCall, Statistics, DataFrames

import StatsBase

export VariableSpecification, get_data, quap, get_posterior_samples, summarize_posterior_samples
export get_data,
get_posterior_samples,
lppd,
quap,
summarize_posterior_samples,
VariableSpecification

"""
    build_url(filename)

Return the raw-GitHub URL of the CSV file named `filename` (given without the
`.csv` extension) in the `data` directory of the `rmcelreath/rethinking`
repository's `master` branch.
"""
function build_url(filename)
    base = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/"
    return string(base, filename, ".csv")
end
"""
    retrieve_file(url)

Fetch `url` with `HTTP.get`, wrap the response body in an `IOBuffer`, and hand
that buffer to `CSV.read`, returning whatever `CSV.read` produces.
"""
function retrieve_file(url)
    response = HTTP.get(url)
    buffer = IOBuffer(response.body)
    return CSV.read(buffer)
end
Expand Down Expand Up @@ -92,4 +97,40 @@ function summarize_posterior_samples(samples, params)
return DataFrame(m, [:param, :mean, :std, Symbol("5%"), Symbol("95%")])
end

"""
    observation_log_probabilities(params::DataFrameRow, choices, X, model)

Evaluate each observation under one set of parameter values.

Every column of `params` is written into `choices` under its column name (so
`choices` is modified in place), a trace of `model` is generated with arguments
`(X,)` and those constraints, and `project` is called on the trace once per
observation address `(:y, i)`.

Returns a vector with one `project` result per observation, in order
`i = 1, 2, …, N`, where `N` is the number of `(:y, _)` addresses already
present in `choices`.
"""
function observation_log_probabilities(params::DataFrameRow, choices, X, model)
    # Observations are assumed to live at addresses (:y, i) for the i-th observation.
    is_observation(addr) = addr isa Tuple && addr[1] == :y
    n_obs = count(is_observation, keys(choices.leaf_nodes))

    # Constrain every parameter to the value it takes in this posterior sample.
    for name in keys(params)
        choices[name] = params[name]
    end

    trace, _ = generate(model, (X,), choices)

    return [project(trace, Gen.select((:y, i))) for i in 1:n_obs]
end

# Numerically stable `log(mean(exp.(log_probs)))`, where `S` is the number of terms.
# Shifting by the maximum keeps `exp` from underflowing to zero when every
# log-probability is very negative (e.g. around -800).
function _log_mean_exp(log_probs, S)
    peak = maximum(log_probs)
    # All terms are -Inf, or a +Inf / NaN is present: shifting would give NaN,
    # so propagate the extreme value directly.
    isfinite(peak) || return peak - log(S)
    return peak + log(sum(lp -> exp(lp - peak), log_probs)) - log(S)
end

"""
    lppd(posterior_params::DataFrame, X::AbstractArray, Y::Vector, model)

Compute the log pointwise predictive density of `model` for the observations `Y`.

Each row of `posterior_params` is one posterior sample; its columns are written
into the choice map under their column names. The observations are constrained
at addresses `(:y, i)`, so `model` is assumed to use those addresses for its
outcomes. For every observation `i` the result is

    log( (1 / S) * sum_s exp(logp[s, i]) )

where `S` is the number of rows of `posterior_params` and `logp[s, i]` is the
value returned by `observation_log_probabilities` for sample `s` and
observation `i`. The sum is evaluated with a max-shifted log-sum-exp so that
very negative log-probabilities do not underflow to `-Inf`.

# Returns
A vector with one entry per observation. If `posterior_params` has no rows an
empty vector is returned.
"""
function lppd(posterior_params::DataFrame, X::AbstractArray, Y::Vector, model)
    choices = Gen.choicemap()

    # Y is a constant regardless of the choice of parameters.
    for (i, y) in enumerate(Y)
        choices[(:y, i)] = y
    end

    S = size(posterior_params, 1)

    # One vector of N per-observation log-probabilities for each of the S samples.
    per_sample = [
        observation_log_probabilities(params, choices, X, model)
        for params in eachrow(posterior_params)
    ]
    isempty(per_sample) && return Float64[]

    # N x S matrix; `reduce(hcat, ...)` avoids splatting S arguments into `hcat`.
    obs_log_probs = reduce(hcat, per_sample)

    # Length-N vector: average over samples (on the probability scale) per observation.
    return map(row -> _log_mean_exp(row, S), eachrow(obs_log_probs))
end
end

0 comments on commit 80f31b3

Please sign in to comment.