From fa8def459b77ff0110b1a41f0e719172c54dc3bd Mon Sep 17 00:00:00 2001 From: hsm207 Date: Wed, 14 Oct 2020 06:38:55 +0800 Subject: [PATCH] Notebook on mysterious HMC behavior --- book/16_Madness/hard_chain.ipynb | 1505 ++++++++++++++++++++++++++++++ 1 file changed, 1505 insertions(+) create mode 100644 book/16_Madness/hard_chain.ipynb diff --git a/book/16_Madness/hard_chain.ipynb b/book/16_Madness/hard_chain.ipynb new file mode 100644 index 0000000..f8eecc5 --- /dev/null +++ b/book/16_Madness/hard_chain.ipynb @@ -0,0 +1,1505 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "cd(\"..\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Packages" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "include(\"../src/Utils.jl\")\n", + "using .Utils" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "using Pipe\n", + "using Query\n", + "using VegaLite\n", + "using DataFrames\n", + "using Turing\n", + "using StatsPlots" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "n = 1000\n", + "\n", + "h = rand(Uniform(0, 100), n)\n", + "w = exp.(5 .+ 2 .* h) .+ rand(Normal(0, 1), n);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Chain" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This model converges very fast:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Found initial step size\n", + "│ ϵ = 0.00078125\n", + "└ @ Turing.Inference /home/user/.julia/packages/Turing/G7n2S/src/inference/hmc.jl:625\n", + "\u001b[32mSampling: 100%|█████████████████████████████████████████| Time: 0:02:13\u001b[39m\n" + ] + }, + { + "data": { + "text/plain": [ + "Chains MCMC chain (500×14×1 Array{Float64,3}):\n", + "\n", + "Iterations = 1:500\n", + "Thinning interval = 1\n", + "Chains = 1\n", + "Samples per chain = 500\n", + "parameters = α, β\n", + "internals = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth\n", + "\n", + "Summary Statistics\n", + " \u001b[0m\u001b[1m parameters \u001b[0m \u001b[0m\u001b[1m mean \u001b[0m \u001b[0m\u001b[1m std \u001b[0m \u001b[0m\u001b[1m naive_se \u001b[0m \u001b[0m\u001b[1m mcse \u001b[0m \u001b[0m\u001b[1m ess \u001b[0m \u001b[0m\u001b[1m rhat \u001b[0m \u001b[0m\n", + " \u001b[0m\u001b[90m Symbol \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\n", + " \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m\n", + " \u001b[0m α \u001b[0m \u001b[0m 5.0140 \u001b[0m \u001b[0m 0.0632 \u001b[0m \u001b[0m 0.0028 \u001b[0m \u001b[0m 0.0066 \u001b[0m \u001b[0m 84.3985 \u001b[0m \u001b[0m 1.0235 \u001b[0m \u001b[0m\n", + " \u001b[0m β \u001b[0m \u001b[0m 1.9998 \u001b[0m \u001b[0m 0.0011 \u001b[0m \u001b[0m 0.0000 \u001b[0m \u001b[0m 0.0001 \u001b[0m \u001b[0m 103.0197 \u001b[0m \u001b[0m 1.0301 \u001b[0m \u001b[0m\n", + "\n", + "Quantiles\n", + " \u001b[0m\u001b[1m parameters \u001b[0m \u001b[0m\u001b[1m 2.5% \u001b[0m \u001b[0m\u001b[1m 25.0% \u001b[0m \u001b[0m\u001b[1m 50.0% \u001b[0m \u001b[0m\u001b[1m 75.0% \u001b[0m \u001b[0m\u001b[1m 97.5% \u001b[0m \u001b[0m\n", + " \u001b[0m\u001b[90m Symbol \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\n", + " \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m\n", + " \u001b[0m α \u001b[0m \u001b[0m 4.9045 \u001b[0m \u001b[0m 4.9634 \u001b[0m \u001b[0m 5.0140 \u001b[0m \u001b[0m 5.0554 \u001b[0m \u001b[0m 5.1452 \u001b[0m \u001b[0m\n", + " \u001b[0m β \u001b[0m \u001b[0m 1.9978 \u001b[0m \u001b[0m 1.9991 \u001b[0m \u001b[0m 1.9997 \u001b[0m \u001b[0m 2.0005 \u001b[0m \u001b[0m 2.0019 \u001b[0m \u001b[0m\n" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@model function model1(height, weight)\n", + " N = length(height)\n", + " \n", + " α ~ Normal(0, 10)\n", + " β ~ Normal(0, 10)\n", + " \n", + " for i in 1:N\n", + " μᵢ = α + β * height[i]\n", + " weight[i] ~ LogNormal(μᵢ, 1)\n", + " end\n", + "end\n", + "\n", + "\n", + "c1 = sample(model1(h, w), NUTS(), 1000)\n", + "c1" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "plot(c1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This model is very hard to converge and throws alot of numerical errors:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Warning: The current proposal will be rejected due to numerical error(s).\n", + "│ isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)\n", + "└ @ AdvancedHMC /home/user/.julia/packages/AdvancedHMC/P9wqk/src/hamiltonian.jl:47\n", + "┌ Warning: The current proposal will be rejected due to numerical error(s).\n", + "│ isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)\n", + "└ @ AdvancedHMC /home/user/.julia/packages/AdvancedHMC/P9wqk/src/hamiltonian.jl:47\n", + "┌ Warning: The current proposal will be rejected due to numerical error(s).\n", + "│ isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)\n", + "└ @ AdvancedHMC /home/user/.julia/packages/AdvancedHMC/P9wqk/src/hamiltonian.jl:47\n", + "┌ Warning: The current proposal will be rejected due to numerical error(s).\n", + "│ isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)\n", + "└ @ AdvancedHMC /home/user/.julia/packages/AdvancedHMC/P9wqk/src/hamiltonian.jl:47\n", + "┌ Warning: The current proposal will be rejected due to numerical error(s).\n", + "│ isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)\n", + "└ @ AdvancedHMC /home/user/.julia/packages/AdvancedHMC/P9wqk/src/hamiltonian.jl:47\n", + "┌ Warning: The current proposal will be rejected due to numerical error(s).\n", + "│ isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)\n", + "└ @ AdvancedHMC /home/user/.julia/packages/AdvancedHMC/P9wqk/src/hamiltonian.jl:47\n", + "┌ Warning: The current proposal will be rejected due to numerical error(s).\n", + "│ isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)\n", + "└ @ AdvancedHMC /home/user/.julia/packages/AdvancedHMC/P9wqk/src/hamiltonian.jl:47\n", + "┌ Info: Found initial step size\n", + "│ ϵ = 9.765625e-5\n", + "└ @ Turing.Inference /home/user/.julia/packages/Turing/G7n2S/src/inference/hmc.jl:625\n", + "\u001b[32mSampling: 100%|█████████████████████████████████████████| Time: 0:02:09\u001b[39m\n" + ] + }, + { + "data": { + "text/plain": [ + "Chains MCMC chain (500×15×1 Array{Float64,3}):\n", + "\n", + "Iterations = 1:500\n", + "Thinning interval = 1\n", + "Chains = 1\n", + "Samples per chain = 500\n", + "parameters = α, β, σ\n", + "internals = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth\n", + "\n", + "Summary Statistics\n", + " \u001b[0m\u001b[1m parameters \u001b[0m \u001b[0m\u001b[1m mean \u001b[0m \u001b[0m\u001b[1m std \u001b[0m \u001b[0m\u001b[1m naive_se \u001b[0m \u001b[0m\u001b[1m mcse \u001b[0m \u001b[0m\u001b[1m ess \u001b[0m \u001b[0m\u001b[1m rhat \u001b[0m \u001b[0m\n", + " \u001b[0m\u001b[90m Symbol \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\n", + " \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m\n", + " \u001b[0m α \u001b[0m \u001b[0m 5.0000 \u001b[0m \u001b[0m 0.0000 \u001b[0m \u001b[0m 0.0000 \u001b[0m \u001b[0m 0.0000 \u001b[0m \u001b[0m 2.4406 \u001b[0m \u001b[0m 1.3190 \u001b[0m \u001b[0m\n", + " \u001b[0m β \u001b[0m \u001b[0m 2.0000 \u001b[0m \u001b[0m 0.0000 \u001b[0m \u001b[0m 0.0000 \u001b[0m \u001b[0m 0.0000 \u001b[0m \u001b[0m 25.4092 \u001b[0m \u001b[0m 1.0736 \u001b[0m \u001b[0m\n", + " \u001b[0m σ \u001b[0m \u001b[0m 0.0011 \u001b[0m \u001b[0m 0.0000 \u001b[0m \u001b[0m 0.0000 \u001b[0m \u001b[0m 0.0000 \u001b[0m \u001b[0m 3.3592 \u001b[0m \u001b[0m 1.1943 \u001b[0m \u001b[0m\n", + "\n", + "Quantiles\n", + " \u001b[0m\u001b[1m parameters \u001b[0m \u001b[0m\u001b[1m 2.5% \u001b[0m \u001b[0m\u001b[1m 25.0% \u001b[0m \u001b[0m\u001b[1m 50.0% \u001b[0m \u001b[0m\u001b[1m 75.0% \u001b[0m \u001b[0m\u001b[1m 97.5% \u001b[0m \u001b[0m\n", + " \u001b[0m\u001b[90m Symbol \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\u001b[90m Float64 \u001b[0m \u001b[0m\n", + " \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m \u001b[0m\n", + " \u001b[0m α \u001b[0m \u001b[0m 4.9999 \u001b[0m \u001b[0m 5.0000 \u001b[0m \u001b[0m 5.0000 \u001b[0m \u001b[0m 5.0000 \u001b[0m \u001b[0m 5.0000 \u001b[0m \u001b[0m\n", + " \u001b[0m β \u001b[0m \u001b[0m 2.0000 \u001b[0m \u001b[0m 2.0000 \u001b[0m \u001b[0m 2.0000 \u001b[0m \u001b[0m 2.0000 \u001b[0m \u001b[0m 2.0000 \u001b[0m \u001b[0m\n", + " \u001b[0m σ \u001b[0m \u001b[0m 0.0011 \u001b[0m \u001b[0m 0.0011 \u001b[0m \u001b[0m 0.0011 \u001b[0m \u001b[0m 0.0011 \u001b[0m \u001b[0m 0.0011 \u001b[0m \u001b[0m\n" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@model function model2(height, weight)\n", + " N = length(height)\n", + " \n", + " σ ~ Exponential(1)\n", + " α ~ Normal(0, 10)\n", + " β ~ Normal(0, 10)\n", + " \n", + " for i in 1:N\n", + " μᵢ = α + β * height[i]\n", + " weight[i] ~ LogNormal(μᵢ, σ)\n", + " end\n", + "end\n", + "\n", + "\n", + "c2 = sample(model2(h, w), NUTS(), 1000)\n", + "c2" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "plot(c2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "How to set the prior of $\\sigma$ so that chain converges to the correct answer quickly?" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.5.1", + "language": "julia", + "name": "julia-1.5" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.5.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}