Skip to content

Commit

Permalink
Utilities to compute WAIC
Browse files Browse the repository at this point in the history
  • Loading branch information
hsm207 committed May 9, 2020
1 parent 3c1d42a commit 702a1f6
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ using Gen, RCall, Statistics, DataFrames
import StatsBase

export get_data,
get_posterior_samples,
lppd,
get_posterior_samples,
get_prediction_log_probs,
lppd,
pwaic,
quap,
summarize_posterior_samples,
VariableSpecification
VariableSpecification,
waic

build_url(filename) = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/$(filename).csv"
retrieve_file(url) = HTTP.get(url).body |> IOBuffer |> CSV.read
Expand Down Expand Up @@ -98,7 +101,6 @@ function summarize_posterior_samples(samples, params)
end

function observation_log_probabilities(params::DataFrameRow, choices, X, model)

# assume the Y address is (:y, i) where i is the i-th observation
N = filter(k->k isa Tuple && k[1] == :y, keys(choices.leaf_nodes)) |> length

Expand All @@ -108,29 +110,44 @@ function observation_log_probabilities(params::DataFrameRow, choices, X, model)
trace, _ = generate(model, (X,), choices)

[project(trace, Gen.select((:y, i))) for i in 1:N]

end

function lppd(posterior_params::DataFrame, X::AbstractArray, Y::Vector, model)

function get_prediction_log_probs(posterior_params::DataFrame, X::AbstractArray, Y::Vector, model)
choices = Gen.choicemap()

# Y is a constant regardless of choice of parameters
for (i, y) in enumerate(Y)
choices[(:y, i)] = y
end

S, _ = size(posterior_params)

# S x N, where S number of parameter samples, N number of observations
obs_log_pros = map(eachrow(posterior_params)) do params
map(eachrow(posterior_params)) do params
observation_log_probabilities(params, choices, X, model)
end |>
m->hcat(m...)'
m->hcat(m...)'
end

function lppd(observed_log_scores::AbstractArray)
# observed_log_scores is S x N where S is number of parameter samples and
# N is number of observations
S, _ = size(observed_log_scores)

# 1 x N
map(eachcol(obs_log_pros)) do log_probs
map(eachcol(observed_log_scores)) do log_probs
(log_probs .|> exp |> sum |> log) - log(S)
end
end

function pwaic(observed_log_scores::AbstractArray)

# observed_log_scores is S x N where S is number of parameter samples and
# N is number of observations

# 1 x N
map(eachcol(observed_log_scores)) do log_probs
var(log_probs)
end
end

waic(observed_log_scores::AbstractArray) = lppd(observed_log_scores) - pwaic(observed_log_scores)
end

0 comments on commit 702a1f6

Please sign in to comment.