Skip to content

Commit

Permalink
Clarify log probability
Browse files Browse the repository at this point in the history
  • Loading branch information
hsm207 committed May 9, 2020
1 parent 78fa3dc commit 76a7453
Showing 1 changed file with 263 additions and 0 deletions.
263 changes: 263 additions & 0 deletions book/07_Information_Theory/calculate_log_probability.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Packages"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"using Gen\n",
"using Distributions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Specification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$$Y \\text{~} N(\\alpha + \\beta X, \\sigma)$$\n",
"\n",
"where:\n",
"\n",
"* $\\alpha$ follows $N(0, 1)$\n",
"* $\\beta$ follows $N(5, 5)$\n",
"* $\\sigma$ follows $exp(1)$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Parameter Values"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"y = 10.0\n",
"x = 3\n",
"α = 0.2\n",
"β = 7.0\n",
"σ = 1\n",
";"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Manual Calculation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"l_parameters = -4.5473149788434455\n",
"ll_data = -63.638938533204666\n",
"log_probability_manual = -68.18625351204811\n"
]
}
],
"source": [
"l_parameters = Distributions.logpdf(Normal(0, 1), α) +\n",
" Distributions.logpdf(Normal(5, 5), β) +\n",
" Distributions.logpdf(Exponential(1), σ)\n",
"\n",
"ll_data = Distributions.logpdf(Normal(α + β * x, σ), y)\n",
"\n",
"log_probability_manual = l_parameters + ll_data\n",
"\n",
"@show l_parameters\n",
"@show ll_data\n",
"@show log_probability_manual\n",
";"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using Gen"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"@gen function model(x)\n",
" α = @trace(normal(0, 1), :α)\n",
" β = @trace(normal(5, 5), :β)\n",
" σ = @trace(exponential(1), :σ)\n",
" \n",
" μ = α + β * x\n",
" \n",
" @trace(normal(μ, σ), :y)\n",
"end\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"choices = choicemap(:α => α,\n",
" :β => β,\n",
" :σ => σ,\n",
" :y => y)\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"trace, _ = generate(model, (x, ), choices);"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-68.18625351204811"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"log_probability_gen = get_score(trace)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"@assert(log_probability_manual == log_probability_gen)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check individual parameter score:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(k, v.score) = (:α, -0.9389385332046727)\n",
"(k, v.score) = (:σ, -1.0)\n",
"(k, v.score) = (:y, -63.638938533204666)\n",
"(k, v.score) = (:β, -2.6083764456387732)\n"
]
}
],
"source": [
"for (k, v) in trace.trie.leaf_nodes\n",
" @show k, v.score\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-67.18625351204811"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"project(trace, select(:α, :β, :y)) "
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-67.18625351204811"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[-0.9389385332046727 -2.6083764456387732 -63.638938533204666] |> sum"
]
}
],
"metadata": {
"@webio": {
"lastCommId": null,
"lastKernelId": null
},
"kernelspec": {
"display_name": "Julia 1.4.1",
"language": "julia",
"name": "julia-1.4"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.4.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

0 comments on commit 76a7453

Please sign in to comment.