Skip to content

Commit

Permalink
Added utility functions to work with Gen models
Browse files Browse the repository at this point in the history
  • Loading branch information
hsm207 committed May 1, 2020
1 parent 1a44552 commit 97cc627
Showing 1 changed file with 50 additions and 3 deletions.
53 changes: 50 additions & 3 deletions src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ using HTTP, CSV
using JuMP, Ipopt, ForwardDiff
using Distributions
using LinearAlgebra
using Gen, RCall, Statistics, DataFrames

import StatsBase

export VariableSpecification, get_data, quap
export VariableSpecification, get_data, quap, get_posterior_samples, summarize_posterior_samples

build_url(filename) = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/$(filename).csv"
retrieve_file(url) = HTTP.get(url).body |> IOBuffer |> CSV.read
Expand All @@ -15,7 +16,7 @@ get_data(filename) = build_url(filename) |> retrieve_file
struct VariableSpecification
lower_bound::Float64
upper_bound::Float64
prior::Distribution
prior::Distributions.Distribution
end

calculate_covariance_matrix(f, optimal_points) = begin
Expand Down Expand Up @@ -45,4 +46,50 @@ function quap(objective_fn, vars_specs)
end

StatsBase.cov2cor(C::AbstractMatrix) = StatsBase.cov2cor(C, diag(C) .|> sqrt)
end

function do_inference(model, X, Y, amount_of_computation, params)

observations = Gen.choicemap()
for (i, y) in enumerate(Y)
observations[(:y, i)] = y
end

trace, = generate(model, (X,), observations)

for i = 1:amount_of_computation
trace, = mh(trace, Gen.select(params...))
end

return trace
end

function get_posterior_samples(model, sample_size, computation_budget, X, Y, params)
results = Array{Array{Float64}}(undef, sample_size)

Threads.@threads for i in 1:sample_size
trace = do_inference(model, X, Y, computation_budget, params)
results[i] = [trace[param] for param in params]
end

return hcat(results...)'
end

function summarize_posterior_samples(samples, params)
a = map(eachcol(samples)) do col
[mean(col), std(col)]
end |>
m->hcat(m...)'

b = R"""
require(rethinking)
apply($samples, 2, PI)
""" |>
rcopy |>
m->m'

m = hcat(params, a, b)

return DataFrame(m, [:param, :mean, :std, Symbol("5%"), Symbol("95%")])
end

end

0 comments on commit 97cc627

Please sign in to comment.