%cd ..
%load_ext autoreload
%autoreload 2
/home/runner/work/numpyro-doing-bayesian/numpyro-doing-bayesian
from __future__ import annotations

import arviz as az
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer.initialization import init_to_median
import numpyro_glm.utils.dist as dist_utils
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs
import pandas as pd
import seaborn as sns

# Request 4 host devices from NumPyro.
# NOTE(review): presumably this is so the 4 chains requested by the MCMC run
# further down can each get their own CPU device -- confirm against the
# NumPyro documentation for `set_host_device_count`.
numpyro.set_host_device_count(4)

Chapter 12: Bayesian Approaches to Testing a Point ("Null") Hypothesis

The Estimation Approach

Region of Practical Equivalence

A parameter value is declared to be not credible, or rejected,

if its entire ROPE lies outside of the 95% highest density interval (HDI) of the posterior distribution of that parameter.

A parameter value is declared to be accepted for practical purposes

if that value's ROPE completely contains the 95% HDI of the posterior of that parameter.

The Model Comparison Approach

Are Different Groups Equal or Not?

Data are downloaded from JWarmenhoven's implementation.

# Load the per-subject recall data. `CondOfSubj` is parsed as a categorical
# column, so its categories are the strings '1'..'4' as they appear in the CSV.
music_df: pd.DataFrame = pd.read_csv(
    'datasets/background_music.csv',
    dtype={'CondOfSubj': 'category'},
)

# Replace the numeric condition labels with the music genre names.
music_df['CondOfSubj'] = music_df['CondOfSubj'].cat.rename_categories({
    '1': 'Das Kruschke',
    '2': 'Mozart',
    '3': 'Bach',
    '4': 'Beethoven',
})

music_df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 80 entries, 0 to 79
Data columns (total 3 columns):
 #   Column       Non-Null Count  Dtype   
---  ------       --------------  -----   
 0   CondOfSubj   80 non-null     category
 1   nTrlOfSubj   80 non-null     int64   
 2   nCorrOfSubj  80 non-null     int64   
dtypes: category(1), int64(2)
memory usage: 1.7 KB
# Naive variant, kept for reference; it is not the one passed to the sampler.
# NOTE(review): the likely reason this variant "gives wrong results" is
# sampling efficiency rather than the model definition. Here the parameters of
# whichever model is currently NOT selected are always drawn from their broad
# Beta(1, 1) priors, whereas the `_1` variant below switches them to narrow
# pseudo-priors depending on the model index. Presumably that makes jumps of
# the model index rare in this variant -- confirm.
def words_recall_in_different_music_genres_model_comparison(
        genre: jnp.ndarray, nb_trials: jnp.ndarray, nb_corrects: jnp.ndarray, nb_genres: int):
    """Model comparison between genre-specific and shared recall modes.

    A discrete site ``model`` (prior probabilities 0.5 / 0.5) selects which
    mode feeds the per-observation Beta distribution:

    * ``model == 0``: the genre-specific ``omega`` (one entry per genre),
    * otherwise: the single ``omega1`` shared by all genres.

    Both ``omega`` and ``omega1`` always have a Beta(1, 1) prior here,
    regardless of the value of ``model``.

    Parameters
    ----------
    genre
        Integer genre index of each observation; used to index arrays of
        length ``nb_genres``.
    nb_trials
        Number of trials of each observation (Binomial total count).
    nb_corrects
        Observed number of correct answers of each observation.
    nb_genres
        Number of distinct genres.

    Raises
    ------
    AssertionError
        If the three arrays do not share the same leading dimension.
    """
    nb_obs = genre.shape[0]
    assert nb_obs == nb_trials.shape[0] and nb_trials.shape[0] == nb_corrects.shape[0]

    # Model index with equal prior probability for both models.
    model = numpyro.sample(
        'model', dist.Categorical(jnp.array([0.5, 0.5])))

    # Shape parameters of the Beta prior used for every mode.
    prior_a, prior_b = 1., 1.

    # Genre-specific modes (used when model == 0).
    omega = numpyro.sample(
        'omega',
        dist.Beta(jnp.repeat(prior_a, nb_genres), jnp.repeat(prior_b, nb_genres)))

    # Shared mode (used when model != 0); corresponds to omega0 in the book.
    omega1 = numpyro.sample('omega1', dist.Beta(prior_a, prior_b))

    # One concentration per genre. The helper is project-local; judging by its
    # name it presumably builds a Gamma from mode=20 and std=20 -- TODO confirm.
    kappa_prior = dist_utils.gammaDistFromModeStd(20, 20).expand([nb_genres])
    kappa = numpyro.deterministic(
        'kappa', numpyro.sample('_kappa-2', kappa_prior) + 2)

    # Pick the mode according to the model index, then convert
    # (mode, concentration) into the two Beta shape parameters.
    mode = jnp.where(model == 0, omega, omega1)
    shape_a = mode * (kappa - 2) + 1
    shape_b = (1 - mode) * (kappa - 2) + 1

    # Likelihood: one latent success probability per observation.
    with numpyro.plate('obs', nb_obs) as idx:
        obs_genre = genre[idx]
        theta = numpyro.sample(
            'theta', dist.Beta(shape_a[obs_genre], shape_b[obs_genre]))
        numpyro.sample(
            'correct',
            dist.Binomial(nb_trials[idx], theta),
            obs=nb_corrects[idx])


# Variant with model-dependent priors; this is the one passed to the sampler.
def words_recall_in_different_music_genres_model_comparison_1(
        genre: jnp.ndarray, nb_trials: jnp.ndarray, nb_corrects: jnp.ndarray, nb_genres: int):
    """Model comparison between genre-specific and shared recall modes.

    A discrete site ``model`` (prior probabilities 0.5 / 0.5) selects which
    mode feeds the per-observation Beta distribution:

    * ``model == 0``: the genre-specific ``omega`` (one entry per genre),
    * otherwise: the single ``omega1`` shared by all genres.

    The Beta prior of each mode depends on ``model``: the mode that is in use
    gets Beta(1, 1), while the mode that is not in use gets the hard-coded
    alternative shape parameters defined below.

    Parameters
    ----------
    genre
        Integer genre index of each observation; used to index arrays of
        length ``nb_genres``.
    nb_trials
        Number of trials of each observation (Binomial total count).
    nb_corrects
        Observed number of correct answers of each observation.
    nb_genres
        Number of distinct genres. The alternative shape parameters for
        ``omega`` are hard-coded with four entries.

    Raises
    ------
    AssertionError
        If the three arrays do not share the same leading dimension.
    """
    nb_obs = genre.shape[0]
    assert nb_obs == nb_trials.shape[0] and nb_trials.shape[0] == nb_corrects.shape[0]

    # Model index with equal prior probability for both models.
    model = numpyro.sample(
        'model', dist.Categorical(jnp.array([0.5, 0.5])))

    # Shape parameters of the Beta prior for whichever mode is in use.
    prior_a, prior_b = 1., 1.

    # Genre-specific modes. Column 0 (selected when model == 0) holds the
    # Beta(1, 1) prior; column 1 (selected when model == 1) holds the
    # alternative per-genre shape parameters.
    alt_omega_a = [.40 * 125, .50 * 125, .51 * 125, .52 * 125]
    alt_omega_b = [.60 * 125, .50 * 125, .49 * 125, .48 * 125]
    omega_a = jnp.c_[jnp.repeat(prior_a, nb_genres), alt_omega_a]
    omega_b = jnp.c_[jnp.repeat(prior_b, nb_genres), alt_omega_b]
    omega = numpyro.sample(
        'omega', dist.Beta(omega_a[:, model], omega_b[:, model]))

    # Shared mode (corresponds to omega0 in the book). Entry 0 (model == 0)
    # holds the alternative shape parameters; entry 1 (model == 1) holds the
    # Beta(1, 1) prior.
    omega1_a = jnp.array([.48 * 500, prior_a])
    omega1_b = jnp.array([.52 * 500, prior_b])
    omega1 = numpyro.sample(
        'omega1', dist.Beta(omega1_a[model], omega1_b[model]))

    # One concentration per genre. The helper is project-local; judging by its
    # name it presumably builds a Gamma from mode=20 and std=20 -- TODO confirm.
    kappa_prior = dist_utils.gammaDistFromModeStd(20, 20).expand([nb_genres])
    kappa = numpyro.deterministic(
        'kappa', numpyro.sample('_kappa-2', kappa_prior) + 2)

    # Pick the mode according to the model index, then convert
    # (mode, concentration) into the two Beta shape parameters.
    mode = jnp.where(model == 0, omega, omega1)
    shape_a = mode * (kappa - 2) + 1
    shape_b = (1 - mode) * (kappa - 2) + 1

    # Likelihood: one latent success probability per observation.
    with numpyro.plate('obs', nb_obs) as idx:
        obs_genre = genre[idx]
        theta = numpyro.sample(
            'theta', dist.Beta(shape_a[obs_genre], shape_b[obs_genre]))
        numpyro.sample(
            'correct',
            dist.Binomial(nb_trials[idx], theta),
            obs=nb_corrects[idx])


# The model has one discrete site ('model') next to its continuous sites, so
# the NUTS kernel is wrapped in DiscreteHMCGibbs.
kernel = DiscreteHMCGibbs(NUTS(
    words_recall_in_different_music_genres_model_comparison_1,
    init_strategy=init_to_median,
))

# 4 chains, each with 1000 warm-up steps followed by 20000 retained draws.
mcmc = MCMC(kernel, num_warmup=1000, num_samples=20000, num_chains=4)

# Genres are passed as the integer codes of the categorical column.
mcmc.run(
    random.PRNGKey(0),
    genre=jnp.array(music_df['CondOfSubj'].cat.codes.values),
    nb_trials=jnp.array(music_df['nTrlOfSubj'].values),
    nb_corrects=jnp.array(music_df['nCorrOfSubj'].values),
    nb_genres=len(music_df['CondOfSubj'].cat.categories),
)

mcmc.print_summary()
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
                 mean       std    median      5.0%     95.0%     n_eff     r_hat
_kappa-2[0]     49.56     22.33     45.88     15.52     82.53  27463.09      1.00
_kappa-2[1]     41.19     19.66     37.48     11.65     69.38  38810.96      1.00
_kappa-2[2]     49.30     21.80     45.66     16.17     81.33  44666.67      1.00
_kappa-2[3]     42.53     20.15     38.90     12.70     72.20  39571.61      1.00
      model      0.84      0.37      1.00      0.00      1.00   1128.91      1.00
   omega[0]      0.40      0.04      0.40      0.33      0.47 145561.25      1.00
   omega[1]      0.50      0.04      0.50      0.43      0.57 153513.15      1.00
   omega[2]      0.51      0.04      0.51      0.44      0.58 139553.02      1.00
   omega[3]      0.52      0.04      0.52      0.45      0.59 149503.36      1.00
     omega1      0.48      0.02      0.48      0.45      0.51  60299.67      1.00
   theta[0]      0.45      0.07      0.45      0.34      0.55  10890.76      1.00
   theta[1]      0.43      0.07      0.43      0.32      0.54  12034.24      1.00
   theta[2]      0.45      0.07      0.45      0.34      0.56  11103.20      1.00
   theta[3]      0.42      0.07      0.42      0.31      0.53  12907.55      1.00
   theta[4]      0.43      0.07      0.43      0.33      0.54  11422.34      1.00
   theta[5]      0.46      0.07      0.46      0.36      0.57  10524.79      1.00
   theta[6]      0.45      0.07      0.45      0.34      0.55  11632.65      1.00
   theta[7]      0.46      0.07      0.46      0.35      0.57  10061.44      1.00
   theta[8]      0.46      0.07      0.46      0.35      0.57  10542.69      1.00
   theta[9]      0.45      0.07      0.45      0.34      0.55  10746.85      1.00
  theta[10]      0.40      0.07      0.40      0.29      0.51  13807.26      1.00
  theta[11]      0.45      0.07      0.45      0.34      0.56  11168.99      1.00
  theta[12]      0.49      0.07      0.49      0.38      0.60   9467.17      1.00
  theta[13]      0.46      0.07      0.46      0.35      0.57  10451.71      1.00
  theta[14]      0.45      0.07      0.45      0.34      0.55  11181.41      1.00
  theta[15]      0.43      0.07      0.43      0.32      0.54  12156.63      1.00
  theta[16]      0.45      0.07      0.45      0.34      0.56  11141.99      1.00
  theta[17]      0.45      0.07      0.45      0.34      0.55  11097.46      1.00
  theta[18]      0.46      0.07      0.46      0.35      0.57  10037.01      1.00
  theta[19]      0.45      0.07      0.45      0.34      0.55  11361.81      1.00
  theta[20]      0.47      0.07      0.47      0.37      0.58 120686.74      1.00
  theta[21]      0.52      0.07      0.52      0.41      0.64 114740.27      1.00
  theta[22]      0.47      0.07      0.47      0.37      0.58 129171.48      1.00
  theta[23]      0.44      0.07      0.44      0.33      0.55 110378.77      1.00
  theta[24]      0.46      0.07      0.46      0.35      0.56 115183.91      1.00
  theta[25]      0.49      0.07      0.49      0.38      0.60 125971.77      1.00
  theta[26]      0.47      0.07      0.47      0.36      0.58 124348.24      1.00
  theta[27]      0.51      0.07      0.51      0.40      0.62 120238.40      1.00
  theta[28]      0.54      0.07      0.54      0.43      0.65 112978.59      1.00
  theta[29]      0.44      0.07      0.44      0.33      0.55 107098.79      1.00
  theta[30]      0.52      0.07      0.52      0.42      0.64 116779.44      1.00
  theta[31]      0.56      0.07      0.56      0.45      0.67  99170.36      1.00
  theta[32]      0.52      0.07      0.52      0.41      0.63 121903.15      1.00
  theta[33]      0.56      0.07      0.56      0.45      0.67 106557.34      1.00
  theta[34]      0.51      0.07      0.51      0.40      0.62 123257.62      1.00
  theta[35]      0.47      0.07      0.47      0.36      0.58 128602.75      1.00
  theta[36]      0.44      0.07      0.44      0.33      0.55 108448.02      1.00
  theta[37]      0.49      0.07      0.49      0.39      0.60 121485.76      1.00
  theta[38]      0.51      0.07      0.51      0.40      0.62 123399.26      1.00
  theta[39]      0.40      0.07      0.41      0.29      0.52  92206.40      1.00
  theta[40]      0.51      0.06      0.51      0.40      0.61  70763.74      1.00
  theta[41]      0.43      0.06      0.43      0.32      0.53  65575.69      1.00
  theta[42]      0.51      0.06      0.50      0.40      0.61  68948.68      1.00
  theta[43]      0.51      0.06      0.51      0.40      0.61  65938.41      1.00
  theta[44]      0.52      0.06      0.52      0.42      0.63  73484.39      1.00
  theta[45]      0.51      0.06      0.51      0.40      0.61  76788.76      1.00
  theta[46]      0.51      0.06      0.51      0.40      0.61  68323.47      1.00
  theta[47]      0.49      0.06      0.49      0.39      0.59  68172.99      1.00
  theta[48]      0.51      0.06      0.51      0.40      0.61  75168.43      1.00
  theta[49]      0.43      0.06      0.43      0.32      0.54  62342.94      1.00
  theta[50]      0.51      0.06      0.51      0.40      0.61  71609.83      1.00
  theta[51]      0.52      0.06      0.52      0.42      0.63  64943.94      1.00
  theta[52]      0.52      0.06      0.52      0.41      0.62  68927.29      1.00
  theta[53]      0.52      0.06      0.52      0.42      0.63  69587.85      1.00
  theta[54]      0.46      0.06      0.46      0.36      0.56  65720.75      1.00
  theta[55]      0.51      0.06      0.51      0.40      0.61  73615.24      1.00
  theta[56]      0.48      0.06      0.48      0.37      0.58  73600.38      1.00
  theta[57]      0.52      0.06      0.52      0.42      0.62  70418.88      1.00
  theta[58]      0.46      0.06      0.46      0.36      0.57  71170.21      1.00
  theta[59]      0.48      0.06      0.48      0.37      0.58  68928.83      1.00
  theta[60]      0.42      0.07      0.43      0.31      0.54  52877.42      1.00
  theta[61]      0.53      0.07      0.52      0.42      0.64  51691.33      1.00
  theta[62]      0.53      0.07      0.52      0.42      0.64  52253.89      1.00
  theta[63]      0.53      0.07      0.52      0.42      0.64  54330.56      1.00
  theta[64]      0.56      0.07      0.56      0.45      0.67  49418.48      1.00
  theta[65]      0.51      0.07      0.51      0.40      0.62  54325.89      1.00
  theta[66]      0.51      0.07      0.51      0.40      0.62  55596.97      1.00
  theta[67]      0.48      0.07      0.48      0.37      0.58  52616.21      1.00
  theta[68]      0.51      0.07      0.51      0.40      0.62  56049.09      1.00
  theta[69]      0.48      0.07      0.48      0.37      0.58  52557.33      1.00
  theta[70]      0.46      0.07      0.46      0.35      0.57  47395.47      1.00
  theta[71]      0.56      0.07      0.56      0.45      0.67  54799.28      1.00
  theta[72]      0.49      0.07      0.49      0.38      0.60  52106.80      1.00
  theta[73]      0.49      0.07      0.49      0.38      0.60  51248.14      1.00
  theta[74]      0.42      0.07      0.43      0.31      0.54  51196.66      1.00
  theta[75]      0.56      0.07      0.56      0.45      0.67  49397.63      1.00
  theta[76]      0.48      0.07      0.48      0.37      0.58  64175.57      1.00
  theta[77]      0.51      0.07      0.51      0.40      0.62  52205.48      1.00
  theta[78]      0.49      0.07      0.49      0.38      0.60  58279.91      1.00
  theta[79]      0.48      0.07      0.48      0.37      0.58  52340.74      1.00

# Convert the MCMC result to ArviZ InferenceData, labelling the per-genre
# axis of `omega` and `kappa` with the genre names.
idata = az.from_numpyro(
    mcmc,
    coords={'genre': music_df['CondOfSubj'].cat.categories},
    dims={'omega': ['genre'], 'kappa': ['genre']},
)

# Trace plots for all sampled sites.
az.plot_trace(idata)
plt.tight_layout()
from itertools import combinations  # noqa

# Posterior of the pairwise differences between the genre-specific modes.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 6))
posterior = idata['posterior']

# Only draws with model == 0 are used. In the sampled model, `omega` enters
# the likelihood only when model == 0; when model == 1 it is drawn from the
# alternative (pseudo-)prior and carries no information from the data, so
# including those draws would mix prior draws into the "posterior" plots.
in_full_model = posterior['model'].values == 0
if not in_full_model.any():
    raise ValueError(
        'No posterior draws with model == 0; cannot plot omega differences.')

for (left, right), ax in zip(combinations(music_df['CondOfSubj'].cat.categories, 2),
                             axes.flatten()):
    # Boolean indexing over the (chain, draw) axes yields a flat array of the
    # retained draws.
    left_omega = posterior['omega'].sel(genre=left).values[in_full_model]
    right_omega = posterior['omega'].sel(genre=right).values[in_full_model]

    diff = left_omega - right_omega
    az.plot_posterior(diff, hdi_prob=.95, ref_val=0.0,
                      point_estimate='mode', ax=ax)
    ax.set_title('')
    ax.set_xlabel(f'$\\omega_{{{left}}} - \\omega_{{{right}}}$')

fig.tight_layout()
# Sampled model index over all draws, with the chains flattened into one
# sequence.
model_idx = posterior['model'].values.flatten()
draw_number = np.arange(model_idx.size)
sns.lineplot(x=draw_number, y=model_idx)
plt.tight_layout()