%cd ..
%load_ext autoreload
%autoreload 2
/home/runner/work/numpyro-doing-bayesian/numpyro-doing-bayesian
import arviz as az
from functools import reduce
import jax.random as random
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs
import numpyro_glm
import numpyro_glm.metric.models as glm_metric
from scipy.stats import pearsonr

Chapter 18: Metric Predicted Variable with Multiple Metric Predictors

Multiple Linear Regression

# Load the Guber (1999) data set and print its column names, dtypes and
# non-null counts.
df_SAT = pd.read_csv('datasets/Guber1999data.csv')
df_SAT.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 50 entries, 0 to 49
Data columns (total 8 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   State      50 non-null     object 
 1   Spend      50 non-null     float64
 2   StuTeaRat  50 non-null     float64
 3   Salary     50 non-null     float64
 4   PrcntTake  50 non-null     int64  
 5   SATV       50 non-null     int64  
 6   SATM       50 non-null     int64  
 7   SATT       50 non-null     int64  
dtypes: float64(3), int64(4), object(1)
memory usage: 3.2+ KB
# Summary statistics (count, mean, std, quartiles) of the numeric columns.
df_SAT.describe()
Spend StuTeaRat Salary PrcntTake SATV SATM SATT
count 50.000000 50.000000 50.000000 50.000000 50.000000 50.000000 50.000000
mean 5.905260 16.858000 34.828920 35.240000 457.140000 508.780000 965.920000
std 1.362807 2.266355 5.941265 26.762417 35.175948 40.204726 74.820558
min 3.656000 13.800000 25.994000 4.000000 401.000000 443.000000 844.000000
25% 4.881750 15.225000 30.977500 9.000000 427.250000 474.750000 897.250000
50% 5.767500 16.600000 33.287500 28.000000 448.000000 497.500000 945.500000
75% 6.434000 17.575000 38.545750 63.000000 490.250000 539.500000 1032.000000
max 9.774000 24.300000 50.045000 81.000000 516.000000 592.000000 1107.000000
# Predicted variable: the `SATT` column. Metric predictors: `Spend` and
# `PrcntTake`, extracted as a 2-column array in that order.
y_SAT = df_SAT.SATT.values
x_SAT_names = ['Spend', 'PrcntTake']
x_SAT = df_SAT[x_SAT_names].values

# Correlation between the two predictors (displayed as the cell output).
df_SAT[x_SAT_names].corr()
Spend PrcntTake
Spend 1.000000 0.592627
PrcntTake 0.592627 1.000000
# Sample `multi_metric_predictors_robust` with NUTS using a fixed PRNG key:
# 1000 warm-up iterations followed by 20000 retained draws.
key = random.PRNGKey(0)
model = NUTS(glm_metric.multi_metric_predictors_robust)
mcmc = MCMC(model, num_warmup=1000, num_samples=20000)
mcmc.run(key, y_SAT, x_SAT)
mcmc.print_summary()
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        nu     34.12     29.86     25.01      2.05     72.35  17681.07      1.00
     zb[0]      0.24      0.08      0.24      0.11      0.36  15246.60      1.00
     zb[1]     -1.03      0.08     -1.03     -1.16     -0.90  15847.48      1.00
       zb0     -0.00      0.06     -0.00     -0.10      0.10  18559.54      1.00
    zsigma      0.43      0.05      0.42      0.34      0.51  15840.06      1.00

Number of divergences: 0
# Convert the sampler output to InferenceData; the axis of `b_` is labelled
# `predictors` with integer coordinates 0 and 1.
idata_SAT = az.from_numpyro(
    mcmc,
    coords={'predictors': [0, 1]},
    dims={'b_': ['predictors']},
)
posterior_SAT = idata_SAT.posterior

# Every entry is a flat 1-D array of draws; insertion order is the panel order.
to_plot_posteriors_SAT = {
    'Intercept': posterior_SAT['b0'].values.flatten(),
    **{
        f'{name} Coeff': posterior_SAT['b_'].sel(predictors=idx).values.flatten()
        for idx, name in enumerate(('Spend', 'PrcntTake'))
    },
    'Scale': posterior_SAT['sigma'].values.flatten(),
    'Log Normality': np.log(posterior_SAT['nu'].values.flatten()),
}

# One panel per quantity, annotated with the mode and the 95% HDI.
fig_SAT_posteriors, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 8))
for ax, title in zip(axes.flat, to_plot_posteriors_SAT):
    az.plot_posterior(
        to_plot_posteriors_SAT[title],
        point_estimate='mode',
        hdi_prob=0.95,
        ax=ax)
    ax.set_title(title)

fig_SAT_posteriors.tight_layout()
# Pairwise view of the posterior draws: variable names on the diagonal,
# Pearson correlations below it and scatter plots above it.
fig_SAT_pairwise_scatter, axes = plt.subplots(
    nrows=5, ncols=5, figsize=(20, 20))

panel_names = list(to_plot_posteriors_SAT)

for col, x_name in enumerate(panel_names):
    x_draws = to_plot_posteriors_SAT[x_name]

    for row, y_name in enumerate(panel_names):
        ax = axes[row, col]

        # Diagonal cell: just the variable name.
        if row == col:
            numpyro_glm.plot_text(x_name, ax)
            continue

        y_draws = to_plot_posteriors_SAT[y_name]

        if col < row:
            # Below the diagonal: correlation coefficient as text.
            corr, _ = pearsonr(x_draws, y_draws)
            numpyro_glm.plot_text(f'{corr:.2f}', ax)
        else:
            # Above the diagonal: the draws themselves.
            ax.scatter(x_draws, y_draws)

fig_SAT_pairwise_scatter.tight_layout()

Multiplicative Interaction of Metric Predictors

# Add the product of the two predictors as a third predictor column, so the
# same model can be fitted with a multiplicative interaction term.
df_SAT['Spend_Prcnt'] = df_SAT['Spend'] * df_SAT['PrcntTake']
x_SAT_names = ['Spend', 'PrcntTake', 'Spend_Prcnt']
x_SAT_multiplicative = df_SAT[x_SAT_names].values

# Correlations among the three predictors (displayed as the cell output).
df_SAT[x_SAT_names].corr()
Spend PrcntTake Spend_Prcnt
Spend 1.000000 0.592627 0.775025
PrcntTake 0.592627 1.000000 0.951146
Spend_Prcnt 0.775025 0.951146 1.000000
# Same model and sampler settings as before, now run on the three-column
# predictor matrix that includes the product term.
key = random.PRNGKey(0)
model = NUTS(glm_metric.multi_metric_predictors_robust)
mcmc = MCMC(model, num_warmup=1000, num_samples=20000)
mcmc.run(key, y_SAT, x_SAT_multiplicative)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        nu     35.85     30.13     27.08      2.37     75.30  14716.68      1.00
     zb[0]      0.03      0.15      0.03     -0.21      0.26   7784.35      1.00
     zb[1]     -1.51      0.30     -1.51     -2.00     -1.03   7354.59      1.00
     zb[2]      0.63      0.38      0.63      0.03      1.27   6902.05      1.00
       zb0     -0.00      0.06     -0.00     -0.11      0.10  14873.94      1.00
    zsigma      0.42      0.05      0.42      0.34      0.50  11498.02      1.00

Number of divergences: 0
# Convert the sampler output to InferenceData; the axis of `b_` is labelled
# `predictors` with integer coordinates 0, 1 and 2.
idata_SAT_multiplicative = az.from_numpyro(
    mcmc,
    coords={'predictors': [0, 1, 2]},
    dims={'b_': ['predictors']},
)
posterior_SAT_multiplicative = idata_SAT_multiplicative.posterior

# Every entry is a flat 1-D array of draws; insertion order is the panel order.
to_plot_posteriors_SAT_multiplicative = {
    'Intercept': posterior_SAT_multiplicative['b0'].values.flatten(),
    **{
        f'{name} Coeff': posterior_SAT_multiplicative['b_'].sel(
            predictors=idx).values.flatten()
        for idx, name in enumerate(('Spend', 'PrcntTake', 'Spend_Prcnt'))
    },
    'Scale': posterior_SAT_multiplicative['sigma'].values.flatten(),
    'Log Normality': np.log(posterior_SAT_multiplicative['nu'].values.flatten()),
}

# One panel per quantity, annotated with the mode and the 95% HDI.
fig_SAT_multiplicative_posteriors, axes = plt.subplots(
    nrows=2, ncols=3, figsize=(12, 8))
for ax, title in zip(axes.flat, to_plot_posteriors_SAT_multiplicative):
    az.plot_posterior(
        to_plot_posteriors_SAT_multiplicative[title],
        point_estimate='mode',
        hdi_prob=0.95,
        ax=ax)
    ax.set_title(title)

fig_SAT_multiplicative_posteriors.tight_layout()
# With an interaction term, the slope on one predictor depends on the value of
# the other. For a grid of values of the other predictor, compute the slope for
# every posterior draw and show its median with the 95% HDI as error bars.
fig_SAT_slopes, axes = plt.subplots(nrows=2, figsize=(12, 8))

# Spend as a function of Percent Take.
ax = axes[0]
percent_take = np.linspace(4, 80, 20)
# Shape (n_draws, n_grid): one row per posterior draw, one column per grid value.
spend_slopes = (to_plot_posteriors_SAT_multiplicative['Spend Coeff'].reshape(-1, 1)
                + to_plot_posteriors_SAT_multiplicative['Spend_Prcnt Coeff'].reshape(-1, 1) * percent_take.reshape(1, -1))
# Compute the HDI column by column on 1-D arrays. This avoids depending on how
# `az.hdi` interprets 2-D input, which ArviZ warns will change.
spend_hdis = np.array([az.hdi(column, hdi_prob=0.95)
                       for column in spend_slopes.T])
spend_medians = np.median(spend_slopes, axis=0)

# `yerr` is measured relative to the plotted point, so pass the distances from
# the median down to the lower HDI bound and up to the upper HDI bound.
ax.errorbar(percent_take, spend_medians,
            yerr=[spend_medians - spend_hdis[:, 0],
                  spend_hdis[:, 1] - spend_medians],
            fmt='o', label='Spend Slope Median')
ax.set_title('Spend Slope vs. Percent Take')
ax.set_xlabel('Percent Take')
ax.set_ylabel('Spend Slope')
ax.legend()

# Percent Take as a function of Spend.
ax = axes[1]
spend = np.linspace(3, 10, 20)
prcnt_slopes = (to_plot_posteriors_SAT_multiplicative['PrcntTake Coeff'].reshape(-1, 1)
                + to_plot_posteriors_SAT_multiplicative['Spend_Prcnt Coeff'].reshape(-1, 1) * spend.reshape(1, -1))
prcnt_hdis = np.array([az.hdi(column, hdi_prob=0.95)
                       for column in prcnt_slopes.T])
prcnt_medians = np.median(prcnt_slopes, axis=0)

ax.errorbar(spend, prcnt_medians,
            yerr=[prcnt_medians - prcnt_hdis[:, 0],
                  prcnt_hdis[:, 1] - prcnt_medians],
            fmt='o', label='Percent Take Slope Median')
ax.set_title('Percent Take Slope vs. Spend')
ax.set_xlabel('Spend')
ax.set_ylabel('Percent Take Slope')
ax.legend()

fig_SAT_slopes.tight_layout()
/tmp/ipykernel_5622/1493046308.py:8: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  spend_hdis = az.hdi(spend_slopes, hdi_prob=0.95)
/tmp/ipykernel_5622/1493046308.py:23: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  prcnt_hdis = az.hdi(prcnt_slopes, hdi_prob=0.95)

Shrinkage of Regression Coefficients

Without Shrinkage Model

def normalize(values):
    """Standardize `values`: subtract their mean and divide by their standard deviation.

    The standard deviation is `np.std` with its default settings (population
    form, ddof=0).
    """
    center = np.mean(values)
    spread = np.std(values)
    return (values - center) / spread


# Number of pure-noise predictors appended to the real data.
nb_random_preds = 12

# Seeded generator: the sampler runs in this notebook all use a fixed PRNG key,
# so the synthetic predictors must be reproducible too, otherwise the fits
# below change on every execution.
rng = np.random.default_rng(0)

# Work on a copy so the original frame is not polluted with the noise columns.
df_SAT_random = df_SAT.copy()
for i in range(nb_random_preds):
    # Standard-normal noise, standardized within the sample.
    df_SAT_random[f'xRand{i}'] = normalize(
        rng.normal(0, 1, size=len(df_SAT)))

df_SAT_random.describe()
Spend StuTeaRat Salary PrcntTake SATV SATM SATT Spend_Prcnt xRand0 xRand1 xRand2 xRand3 xRand4 xRand5 xRand6 xRand7 xRand8 xRand9 xRand10 xRand11
count 50.000000 50.000000 50.000000 50.000000 50.000000 50.000000 50.000000 50.000000 5.000000e+01 5.000000e+01 5.000000e+01 5.000000e+01 5.000000e+01 5.000000e+01 5.000000e+01 5.000000e+01 5.000000e+01 5.000000e+01 5.000000e+01 5.000000e+01
mean 5.905260 16.858000 34.828920 35.240000 457.140000 508.780000 965.920000 229.283380 9.992007e-18 -1.332268e-17 2.164935e-17 8.881784e-18 6.661338e-18 5.329071e-17 -1.110223e-17 -2.248202e-17 -1.554312e-17 6.217249e-17 6.064593e-17 1.720846e-17
std 1.362807 2.266355 5.941265 26.762417 35.175948 40.204726 74.820558 206.179695 1.010153e+00 1.010153e+00 1.010153e+00 1.010153e+00 1.010153e+00 1.010153e+00 1.010153e+00 1.010153e+00 1.010153e+00 1.010153e+00 1.010153e+00 1.010153e+00
min 3.656000 13.800000 25.994000 4.000000 401.000000 443.000000 844.000000 14.624000 -2.474430e+00 -1.817330e+00 -2.297824e+00 -2.144337e+00 -1.961650e+00 -1.926837e+00 -1.865986e+00 -2.059300e+00 -2.350770e+00 -2.058238e+00 -2.039108e+00 -2.460494e+00
25% 4.881750 15.225000 30.977500 9.000000 427.250000 474.750000 897.250000 52.845750 -5.298635e-01 -8.647588e-01 -4.904526e-01 -7.310847e-01 -7.388813e-01 -8.112377e-01 -9.310771e-01 -6.392889e-01 -6.329673e-01 -6.910836e-01 -6.363541e-01 -6.939741e-01
50% 5.767500 16.600000 33.287500 28.000000 448.000000 497.500000 945.500000 148.263000 -1.510935e-02 -1.399603e-01 -2.512563e-02 -1.625857e-02 4.272073e-02 4.995051e-02 -2.949387e-02 -1.435093e-02 -1.938106e-02 5.835320e-02 -3.873375e-02 -4.798673e-02
75% 6.434000 17.575000 38.545750 63.000000 490.250000 539.500000 1032.000000 346.398250 8.011627e-01 8.752898e-01 4.987085e-01 7.955121e-01 7.149399e-01 6.678859e-01 9.008077e-01 5.360744e-01 6.130298e-01 6.779854e-01 6.900932e-01 8.649452e-01
max 9.774000 24.300000 50.045000 81.000000 516.000000 592.000000 1107.000000 714.177000 1.884783e+00 2.050287e+00 2.200799e+00 2.041537e+00 2.117672e+00 2.432630e+00 1.858883e+00 2.807752e+00 2.491902e+00 2.129984e+00 1.926482e+00 1.764810e+00
# Design matrix: the two real predictors followed by all the random ones.
x_SAT_random_cols = ['Spend', 'PrcntTake',
                     *(f'xRand{i}' for i in range(nb_random_preds))]

x_SAT_random = df_SAT_random[x_SAT_random_cols].values
y_SAT_random = df_SAT_random['SATT'].values

# Fit `multi_metric_predictors_robust` on the augmented predictors, with the
# same sampler settings as the earlier runs.
key = random.PRNGKey(0)
kernel = NUTS(glm_metric.multi_metric_predictors_robust)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=20000)
mcmc.run(key, y_SAT_random, x_SAT_random)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        nu     39.17     31.16     30.42      2.92     80.52  19622.08      1.00
     zb[0]      0.11      0.09      0.11     -0.03      0.26  13750.71      1.00
     zb[1]     -1.12      0.09     -1.12     -1.26     -0.97  13247.61      1.00
     zb[2]      0.08      0.07      0.08     -0.03      0.19  17569.61      1.00
     zb[3]      0.18      0.07      0.18      0.07      0.30  14116.48      1.00
     zb[4]      0.11      0.06      0.11      0.01      0.21  19537.87      1.00
     zb[5]      0.02      0.06      0.02     -0.08      0.11  19015.19      1.00
     zb[6]     -0.07      0.06     -0.07     -0.17      0.04  17415.77      1.00
     zb[7]      0.08      0.07      0.08     -0.03      0.19  15395.65      1.00
     zb[8]     -0.20      0.06     -0.20     -0.30     -0.10  17794.77      1.00
     zb[9]      0.11      0.07      0.12      0.01      0.23  17386.80      1.00
    zb[10]     -0.09      0.07     -0.09     -0.20      0.03  14826.41      1.00
    zb[11]     -0.02      0.06     -0.02     -0.13      0.08  18492.38      1.00
    zb[12]     -0.14      0.07     -0.14     -0.25     -0.03  16583.01      1.00
    zb[13]     -0.00      0.07     -0.00     -0.11      0.11  15223.49      1.00
       zb0      0.00      0.06      0.00     -0.09      0.09  24996.96      1.00
    zsigma      0.38      0.05      0.37      0.30      0.46  12303.03      1.00

Number of divergences: 0
# Convert the sampler output to InferenceData. The length of the `predictors`
# coordinate is taken from the design matrix rather than hard-coded, so it
# stays correct if `nb_random_preds` changes.
idata_SAT_random = az.from_numpyro(
    mcmc,
    coords=dict(predictors=list(range(x_SAT_random.shape[1]))),
    dims=dict(b_=['predictors']))
posterior_SAT_random = idata_SAT_random.posterior

# Only the first three and last three random predictors are shown
# (indices 0, 1, 2, 9, 10, 11 when there are 12 of them).
shown_random_preds = [
    i for i in (0, 1, 2,
                nb_random_preds - 3, nb_random_preds - 2, nb_random_preds - 1)
    if 0 <= i < nb_random_preds]

fig_SAT_random_posteriors, axes = plt.subplots(
    nrows=4, ncols=3, figsize=(12, 8))
to_plot_posteriors_SAT_random = {
    'Intercept': posterior_SAT_random['b0'].values.flatten(),
    'Spend Coeff': posterior_SAT_random['b_'].sel(dict(predictors=0)).values.flatten(),
    'PrcntTake Coeff': posterior_SAT_random['b_'].sel(dict(predictors=1)).values.flatten(),
    # Random predictor i sits at column i + 2, after Spend and PrcntTake.
    **{f'xRand{i} Coeff': posterior_SAT_random['b_'].sel(dict(predictors=i + 2)).values.flatten()
        for i in shown_random_preds},
    'Scale': posterior_SAT_random['sigma'].values.flatten(),
    'Log Normality': np.log(posterior_SAT_random['nu'].values.flatten()),
}

# One panel per quantity, annotated with the mode and the 95% HDI.
for ax, (title, values) in zip(axes.flatten(), to_plot_posteriors_SAT_random.items()):
    az.plot_posterior(values, point_estimate='mode', hdi_prob=0.95, ax=ax)
    ax.set_title(title)

fig_SAT_random_posteriors.tight_layout()

Shrinkage Model

# Fit `multi_metric_predictors_robust_with_shrinkage` on the same augmented
# data, with the same PRNG key and sampler settings as the run without
# shrinkage.
key = random.PRNGKey(0)
kernel = NUTS(glm_metric.multi_metric_predictors_robust_with_shrinkage)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=20000)
mcmc.run(key, y_SAT_random, x_SAT_random)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        nu     36.16     31.33     26.67      1.93     77.28  21658.70      1.00
  sigma_b_      0.06      0.03      0.05      0.01      0.10   7546.81      1.00
       zb0      0.00      0.06      0.01     -0.09      0.10  20701.41      1.00
    zb_[0]      0.08      0.07      0.07     -0.03      0.20  11401.46      1.00
    zb_[1]     -1.00      0.08     -1.00     -1.15     -0.87   9731.16      1.00
    zb_[2]      0.02      0.05      0.01     -0.05      0.10  14435.66      1.00
    zb_[3]      0.11      0.07      0.11     -0.01      0.21  10257.99      1.00
    zb_[4]      0.06      0.05      0.05     -0.03      0.15  12437.73      1.00
    zb_[5]      0.00      0.04      0.00     -0.06      0.07  19218.33      1.00
    zb_[6]     -0.02      0.04     -0.01     -0.09      0.05  18700.10      1.00
    zb_[7]      0.02      0.05      0.01     -0.06      0.10  15767.57      1.00
    zb_[8]     -0.11      0.07     -0.11     -0.21      0.00   9147.13      1.00
    zb_[9]      0.04      0.05      0.03     -0.04      0.12  10957.62      1.00
   zb_[10]     -0.02      0.05     -0.01     -0.10      0.06  12886.44      1.00
   zb_[11]     -0.01      0.04     -0.01     -0.08      0.06  20353.49      1.00
   zb_[12]     -0.09      0.06     -0.08     -0.19      0.01  11701.68      1.00
   zb_[13]      0.02      0.05      0.02     -0.05      0.10  15251.18      1.00
    zsigma      0.38      0.05      0.38      0.30      0.47  11712.16      1.00

Number of divergences: 0
# Convert the sampler output to InferenceData. The length of the `predictors`
# coordinate is taken from the design matrix rather than hard-coded, so it
# stays correct if `nb_random_preds` changes.
idata_SAT_random_shrinkage = az.from_numpyro(
    mcmc,
    coords=dict(predictors=list(range(x_SAT_random.shape[1]))),
    dims=dict(b_=['predictors']))
posterior_SAT_random_shrinkage = idata_SAT_random_shrinkage.posterior

# Only the first three and last three random predictors are shown
# (indices 0, 1, 2, 9, 10, 11 when there are 12 of them).
shown_random_preds = [
    i for i in (0, 1, 2,
                nb_random_preds - 3, nb_random_preds - 2, nb_random_preds - 1)
    if 0 <= i < nb_random_preds]

fig_SAT_random_posteriors_shrinkage, axes = plt.subplots(
    nrows=4, ncols=3, figsize=(12, 8))
to_plot_posteriors_SAT_random_shrinkage = {
    'Intercept': posterior_SAT_random_shrinkage['b0'].values.flatten(),
    'Spend Coeff': posterior_SAT_random_shrinkage['b_'].sel(dict(predictors=0)).values.flatten(),
    'PrcntTake Coeff': posterior_SAT_random_shrinkage['b_'].sel(dict(predictors=1)).values.flatten(),
    # Random predictor i sits at column i + 2, after Spend and PrcntTake.
    **{f'xRand{i} Coeff': posterior_SAT_random_shrinkage['b_'].sel(dict(predictors=i + 2)).values.flatten()
        for i in shown_random_preds},
    'Scale': posterior_SAT_random_shrinkage['sigma'].values.flatten(),
    'Log Normality': np.log(posterior_SAT_random_shrinkage['nu'].values.flatten()),
}

# One panel per quantity, annotated with the mode and the 95% HDI.
for ax, (title, values) in zip(axes.flatten(), to_plot_posteriors_SAT_random_shrinkage.items()):
    az.plot_posterior(values, point_estimate='mode', hdi_prob=0.95, ax=ax)
    ax.set_title(title)

fig_SAT_random_posteriors_shrinkage.tight_layout()

Variable Selection

# Candidate predictors for variable selection, in design-matrix column order.
x_SAT_all_names = ['Spend', 'StuTeaRat', 'Salary', 'PrcntTake']
x_SAT_all = df_SAT[x_SAT_all_names].values

# Sample `multi_metric_predictors_robust_with_selection` with a NUTS kernel
# wrapped in DiscreteHMCGibbs (modified=True); 1000 warm-up iterations and
# 40000 retained draws.
key = random.PRNGKey(0)
kernel = DiscreteHMCGibbs(
    NUTS(glm_metric.multi_metric_predictors_robust_with_selection), modified=True)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=40000)
mcmc.run(key, y_SAT, x_SAT_all)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
  delta[0]      0.62      0.48      1.00      0.00      1.00   1435.46      1.00
  delta[1]      0.17      0.38      0.00      0.00      1.00   4008.72      1.00
  delta[2]      0.21      0.41      0.00      0.00      1.00   2580.31      1.00
  delta[3]      1.00      0.00      1.00      1.00      1.00       nan       nan
        nu     35.98     30.75     26.95      2.22     76.60  17600.66      1.00
    sigmab      2.28      3.37      1.15      0.05      5.36   2251.35      1.00
       zb0     -0.00      0.07     -0.00     -0.11      0.10  18422.70      1.00
    zb_[0]     -0.15     14.88      0.21     -5.46      5.47   1200.77      1.00
    zb_[1]     -0.16     19.46     -0.08    -10.81     11.48   2142.64      1.00
    zb_[2]     -1.11     23.37      0.09    -10.97     11.72    726.87      1.00
    zb_[3]     -0.99      0.10     -0.99     -1.14     -0.82   3466.52      1.00
    zsigma      0.44      0.06      0.44      0.35      0.53   5578.80      1.00

# Convert the sampler output to InferenceData and label the per-predictor axis.
# The sampled coefficient site is reported as `zb_` in the summary, so it is
# listed here too; `zb` is kept in case the model also exposes a site by that
# name. The coordinate length follows the number of candidate predictors.
idata_SAT_all_with_selection = az.from_numpyro(
    mcmc,
    coords=dict(predictors=list(range(len(x_SAT_all_names)))),
    dims=dict(
        b_=['predictors'],
        zb=['predictors'],
        zb_=['predictors'],
        delta=['predictors'],
    ),
)
posterior_SAT_all_with_selection = idata_SAT_all_with_selection.posterior


def plot_selected_model_posteriors(predictors):
    """Plot the posteriors of one candidate model from the variable-selection run.

    A candidate model is identified by the set of predictors that are switched
    on. Only the posterior draws whose `delta` indicators match that pattern
    exactly (1 for every listed predictor, 0 for every other one) are used.
    The fraction of such draws is shown in the figure title as the model's
    probability.

    Reads the notebook globals `posterior_SAT_all_with_selection` and
    `x_SAT_all_names`.

    Parameters
    ----------
    predictors : collection of int
        Indices into `x_SAT_all_names` of the predictors included in the model.

    Returns
    -------
    None
        A new matplotlib figure is created as a side effect: one panel for the
        intercept followed by one panel per included predictor. If no draw
        matches the model, the figure carries only the title and a short
        message.
    """
    predictor_names = dict(enumerate(x_SAT_all_names))
    delta = posterior_SAT_all_with_selection['delta']

    # A draw belongs to this model only if every indicator has the expected
    # value: 1 for included predictors and 0 for excluded ones.
    mask = np.logical_and.reduce([
        delta.sel(predictors=p).values == (1 if p in predictors else 0)
        for p in predictor_names
    ])

    # Calculate the model's probability: the share of draws spent in it.
    model_prob = mask.mean()

    # Create figure.
    fig, axes = plt.subplots(
        ncols=len(predictor_names) + 1, figsize=(15, 4))
    fig.suptitle(f'Model Prob = {model_prob:.3f}')
    axes = axes.flatten()

    # Without matching draws there is nothing to summarise, and a posterior
    # plot of an empty array would fail.
    if not mask.any():
        for ax in axes:
            ax.remove()
        fig.text(0.5, 0.5, 'No posterior draws for this model',
                 ha='center', va='center')
        return

    # Plot the posterior of the intercept.
    ax = axes[0]
    az.plot_posterior(
        posterior_SAT_all_with_selection['b0'].values[mask],
        point_estimate='mode',
        hdi_prob=0.95,
        ax=ax)
    ax.set_title('Intercept')

    # Plot the posterior of the coefficients. Excluded predictors have their
    # panel removed, so each predictor keeps the same column across figures.
    for predictor, ax in zip(predictor_names, axes[1:]):
        if predictor in predictors:
            az.plot_posterior(
                posterior_SAT_all_with_selection['b_'].sel(
                    predictors=predictor).values[mask],
                point_estimate='mode',
                hdi_prob=0.95,
                ax=ax)
            ax.set_title(predictor_names[predictor])
        else:
            ax.remove()

    fig.tight_layout()


# Candidate models to display. Each entry lists the indices (into
# `x_SAT_all_names`) of the predictors included in that model.
models_to_plot = [
    [0, 3],
    [3],
    [2, 3],
    [1, 2, 3],
    [0, 2, 3],
    [0, 1, 3],
    [1, 3],
    [0, 1, 2, 3],
]

# One figure per candidate model.
for model in models_to_plot:
    plot_selected_model_posteriors(model)