[1]:
import os, sys
sys.path.append(os.path.abspath(os.path.join('../../..')))
import numpy as np
import pandas as pd
import datetime
from functools import partial
import tensorflow as tf
[2]:
from bayesflow.forward_inference import *
from bayesflow.amortized_inference import AmortizedPosterior
from bayesflow.networks import InvertibleNetwork, InvariantNetwork
from bayesflow.trainers import Trainer

Introduction

In our previous tutorial, we saw how to perform amortized posterior estimation on the parameters of a compartmental model for disease outbreaks. In this tutorial, we will perform a prior sensitivity analysis on our posterior inference, using the prior_batchable_context quantity for the first time.

Assessing a model’s sensitivity can be quite challenging in Bayesian analysis, since it typically requires re-estimating the model multiple times on the same data. Naturally, such an approach scales poorly in scenarios where estimating a model even once is prohibitively slow. While problems of efficiency can be somewhat alleviated in the context of likelihood-based Bayesian inference using Markov chain Monte Carlo (MCMC), the computational burden of refitting models becomes even heavier in simulation-based scenarios.

Both core quantities in Bayesian inference, namely, likelihood and prior, depend on implicit context variables which are considered fixed and thus not further specified as explicit conditioning variables. Typical examples of prior context variables might be as simple as the prior scale (i.e., dispersion) of admissible parameter values or as complex as discrete sets of qualitative expert knowledge. Typical examples of likelihood context variables include structural decisions regarding data transformations or exogenous experimental factors, such as design matrices, indicator variables, or time scales. In the BayesFlow library, such factors are considered context and should be implemented via ContextGenerator instances. Both Prior and Simulator instances support context generators.

Why is sensitivity analysis important? Modelers might disagree regarding the right set of context variables. For instance, when modeling emerging pandemics, different experts might choose to incorporate different prior knowledge about unknown disease parameters in their analysis and thereby arrive at conflicting substantive conclusions based on the same data. Certainly, the sensitivity of Bayesian inference will depend on the interplay between various factors, such as the model family or the amount of available data. However, in most real-world applications, this interplay is too complex to allow for a straightforward (analytical) calculation of sensitivity. Thus, it is often desirable to explicitly assess the sensitivity of Bayesian analysis to possible choices of context variables by varying the latter along reasonable dimensions and then quantifying the consequences for the resulting inference. It is this assessment that we will illustrate in this notebook.

Defining the Generative Model

We will use exactly the same compartmental model as implemented in our previous tutorial. Our underlying model distinguishes between susceptible, \(S\), infected, \(I\), and recovered, \(R\), individuals with infection and recovery occurring at a constant transmission rate \(\lambda\) and constant recovery rate \(\mu\), respectively. The model dynamics are governed by the following system of ODEs:

\[\begin{split}\begin{align} \frac{dS}{dt} &= -\lambda\,\left(\frac{S\,I}{N}\right) \\ \frac{dI}{dt} &= \lambda\,\left(\frac{S\,I}{N}\right) - \mu\,I \\ \frac{dR}{dt} &= \mu\,I, \end{align}\end{split}\]

with \(N = S + I + R\) denoting the total population size. For the purpose of forward inference (simulation), we will use a time step of \(dt = 1\), corresponding to daily case reports. In addition to the ODE parameters \(\lambda\) and \(\mu\), we consider a reporting delay parameter \(D\) and a dispersion parameter \(\psi\), which affect the number of reported infected individuals via a negative binomial distribution (https://en.wikipedia.org/wiki/Negative_binomial_distribution).

Let’s first define a global singleton for accessing the numpy random number generator:

[3]:
RNG = np.random.default_rng(2022)

Prior

Our prior is the same as in our previous notebook, except for a new argument called alpha, which scales each marginal prior \(j\) as follows: \(p(\theta_j) \rightarrow p(\theta_j)^{\alpha_j}\).

The above transformation is called power scaling and is described in more detail in the paper by Noa Kallioinen, Topi Paananen, Paul-Christian Bürkner, and Aki Vehtari:

Detecting and diagnosing prior and likelihood sensitivity with power-scaling (https://arxiv.org/abs/2107.14054)

Essentially, the power scaling method expands or shrinks the prior pdf, thus assuming a sort of hierarchical prior with adaptive sharpness. However, unlike in hierarchical Bayesian models, we will not estimate the hyperparameter \(\alpha\), but probe our results with different values of \(\alpha\) in some meaningful range (here \(\alpha \in [0.5, 2]\)). The neat trick is this: we will perform amortized prior sensitivity analysis. Instead of training a separate AmortizedPosterior for each value of \(\alpha\), we will include \(\alpha\) in the generative process and treat it as a context variable over which to amortize. This means that the scaling factors will simply be concatenated with the outputs of the summary network and passed on to the inference network. In this way, upon convergence, we end up with an \(\alpha\)-aware inference network, which can yield a posterior for each \(\alpha\) we supply during inference.
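To build intuition, here is a quick numerical sanity check (an addition, not part of the original notebook) that power-scaling a normal pdf by \(\alpha\) yields another normal with its variance divided by \(\alpha\) - the same identity the prior function below exploits on the log scale:

from scipy import stats

alpha = 2.0
x = np.linspace(-4, 4, 1000)

# Power-scale a standard normal pdf and renormalize numerically
scaled = stats.norm(0, 1).pdf(x) ** alpha
scaled /= np.trapz(scaled, x)

# Analytic result: a normal with standard deviation 1 / sqrt(alpha)
analytic = stats.norm(0, 1 / np.sqrt(alpha)).pdf(x)
print(np.max(np.abs(scaled - analytic)))  # ~ 0 up to discretization error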

Context Generator Function

During training, we will generate the scaling parameters from a continuous uniform distribution:

\[\alpha_j \sim \mathcal{U}(0.5, 2)\]

In this way, the inference network will see simulations with both a narrower prior (\(\alpha > 1\)) and a wider prior (\(\alpha < 1\)) than our original prior choice (\(\alpha = 1\)).

[32]:
def alpha_gen():
    """ Generates power-scaling parameters from a uniform distribution.
    """
    return RNG.uniform(0.5, 2, size=5)

Prior Generator Function

We can now also define the function which will generate random draws from the prior. Notice the additional argument specifying the vector of \(\alpha\)-scaling parameters. Note also that power scaling affects the parameters of different pdfs differently. For instance, power-scaling an exponential pdf simply amounts to dividing its scale parameter by \(\alpha\). Further examples are given in https://arxiv.org/abs/2107.14054.

[33]:
def model_prior(alpha):
    """ Generates random draws from the prior given an alpha scaling parameter. Each marginal
    prior has its own scaling parameter. See the paper linked below for details:

    https://arxiv.org/abs/2107.14054
    """

    # Lognormal priors: power scaling by alpha divides sigma by sqrt(alpha)
    lambd = RNG.lognormal(mean=np.log(0.4), sigma=0.5 / np.sqrt(alpha[0]))
    mu = RNG.lognormal(mean=np.log(1/8), sigma=0.2 / np.sqrt(alpha[1]))
    D = RNG.lognormal(mean=np.log(8), sigma=0.2 / np.sqrt(alpha[2]))

    # Gamma prior: power scaling maps shape k -> alpha*k - alpha + 1 (here k = 2) and scale -> scale / alpha
    I0 = RNG.gamma(shape=2*alpha[3] - alpha[3] + 1, scale=20/alpha[3])

    # Exponential prior: power scaling divides the scale by alpha
    psi = RNG.exponential(5/alpha[4])
    return np.array([lambd, mu, D, I0, psi])

Great! We can now connect the context generating function with the prior sampler using BayesFlow wrappers:

[12]:
prior_context = ContextGenerator(batchable_context_fun=alpha_gen)
prior = Prior(model_prior, prior_context)

When we now call the Prior object, we can see that the entry batchable_context in the output dictionary is filled with the scaling parameters used to generate each random draw. The context is dubbed batchable, because for each random draw from the prior in a batch of simulations, there is a corresponding context variable. In contrast, a non_batchable_context would be shared across all simulated quantities within a batch.

[16]:
prior(2)
[16]:
{'prior_draws': array([[ 0.33610512,  0.12001227,  7.13881526, 60.35291708,  3.23956216],
        [ 0.50102922,  0.12348521, 10.08823068, 44.44636353,  0.85205942]]),
 'batchable_context': [array([0.62513734, 1.17675062, 0.74885141, 0.50594303, 0.94169066]),
  array([1.33514958, 1.7993603 , 0.81417343, 1.2814965 , 1.25550107])],
 'non_batchable_context': None}

Notice that the context variables are returned as a list. This is done in order to deal with cases where the context variables cannot be coerced to a rectangular array or tensor. Our configurator should take care of this!
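For contrast, here is a hypothetical sketch of a non-batchable context (assuming the same ContextGenerator interface used above): a single value generated once per batch and shared by all simulations in it, e.g., a random time-series length:

# Hypothetical: one context value per batch, shared by all simulations in it
nb_context = ContextGenerator(non_batchable_context_fun=lambda: RNG.integers(10, 15))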

During training, we will also standardize the prior draws, that is, ensure zero means and unit scale. We do this purely for technical reasons - neural networks like scaled values. In addition, our current prior ranges differ vastly, so each parameter would otherwise contribute disproportionately to the loss function. Here, we will use the estimate_means_and_stds() method of a Prior instance, which will estimate the prior means and standard deviations from random draws. We could have also just taken the analytic marginal means and standard deviations, but these may not be available in all settings (e.g., implicit priors).

[38]:
prior_means, prior_stds = prior.estimate_means_and_stds()
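Keep in mind that a network trained on standardized parameters produces posterior draws on the standardized scale; inverting the transform is a one-liner. A quick round-trip check (added here for illustration):

# Round-trip: standardize a batch of prior draws and invert the transform
draws = prior(5)['prior_draws']
draws_std = (draws - prior_means) / prior_stds
assert np.allclose(draws_std * prior_stds + prior_means, draws)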

Simulator (Implicit Likelihood Function)

Our simulator remains unchanged from the one specified in our previous notebook:

[39]:
from scipy.stats import nbinom

def convert_params(mu, phi):
    """ Helper function to convert mean/dispersion parameterization of a negative binomial to N and p,
    as expected by numpy.

    See https://en.wikipedia.org/wiki/Negative_binomial_distribution#Alternative_formulations
    """

    r = phi
    var = mu + 1 / r * mu ** 2
    p = (var - mu) / var
    return r, 1 - p

def stationary_SIR(params, N, T, eps=1e-5):
    """Performs a forward simulation from the stationary SIR model given a random draw from the prior,
    """

    # Extract parameters and round I0 and D
    lambd, mu, D, I0, psi = params
    I0 = np.ceil(I0)
    D = int(round(D))

    # Initial conditions
    S, I, R = [N-I0], [I0], [0]

    # Reported new cases
    C = [I0]

    # Simulate T + D - 1 time steps (D extra days to account for the reporting delay)
    for t in range(1, T+D):

        # Calculate new cases
        I_new = lambd * (I[-1]*S[-1]/N)

        # SIR equations
        S_t = S[-1] - I_new
        I_t = np.clip(I[-1] + I_new - mu*I[-1], 0., N)
        R_t = np.clip(R[-1] + mu*I[-1], 0., N)

        # Track
        S.append(S_t)
        I.append(I_t)
        R.append(R_t)
        C.append(I_new)

    reparam = convert_params(np.clip(np.array(C[D:]), 0, N) + eps, psi)
    C_obs = RNG.negative_binomial(reparam[0], reparam[1])
    return C_obs[:, np.newaxis]
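Two quick sanity checks, added here for illustration: first, that convert_params reproduces the intended negative binomial mean \(\mu\) and variance \(\mu + \mu^2 / \phi\); second, that a single forward simulation with \(\alpha = 1\) (which recovers the original prior) yields the expected output shape:

# Check the negative binomial conversion: mean ~ mu, variance ~ mu + mu^2 / phi
mu_check, phi_check = 50.0, 5.0
r, p = convert_params(mu_check, phi_check)
draws = RNG.negative_binomial(r, p, size=200_000)
print(draws.mean(), draws.var())  # approximately 50 and 550

# Smoke-test the simulator with one draw from the unscaled prior
x_test = stationary_SIR(model_prior(np.ones(5)), N=83e6, T=14)
print(x_test.shape)  # (14, 1)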

Loading Real Data

We will define a simple helper function to load the actual reported cases for the first two weeks of the COVID-19 outbreak in Germany.

[40]:
def load_data():
    """Helper function to load cumulative cases and transform them to new cases."""

    confirmed_cases_url = 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv'
    confirmed_cases = pd.read_csv(confirmed_cases_url, sep=',')

    date_data_begin = datetime.date(2020,3,1)
    date_data_end = datetime.date(2020,3,15)
    format_date = lambda date_py: '{}/{}/{}'.format(date_py.month, date_py.day,
                                                     str(date_py.year)[2:4])
    date_formatted_begin = format_date(date_data_begin)
    date_formatted_end = format_date(date_data_end)

    cases_obs =  np.array(
        confirmed_cases.loc[confirmed_cases["Country/Region"] == "Germany",
                            date_formatted_begin:date_formatted_end])[0]
    new_cases_obs = np.diff(cases_obs)
    return new_cases_obs

Again, we can collect the simulator settings, together with the observed data, into a dictionary:

[41]:
config = {
    'T': 14,
    'N': 83e6,
    'obs_data': load_data()
}
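Before moving on, a minimal consistency check (not part of the original notebook): the number of observed daily counts should match the simulation length \(T\):

# March 1-15 yields 15 cumulative counts, hence 14 daily increments
assert len(config['obs_data']) == config['T']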

And we now define our simulator as before, using fixed \(T\) and \(N\). As you may have guessed, varying these variables would simply amount to introducing further context variables…

[42]:
simulator = Simulator(
    simulator_fun=partial(stationary_SIR, T=config['T'], N=config['N'])
)

Generative Model

[43]:
model = GenerativeModel(prior, simulator, name='alpha_covid_simulator')
INFO:root:Performing 2 pilot runs with the alpha_covid_simulator model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 5)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 14, 1)
INFO:root:No optional prior non-batchable context provided.
INFO:root:Could not determine shape of prior batchable context. Type appears to be non-array: <class 'list'>, so make sure your input configurator takes cares of that!
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional simulation batchable context provided.
[44]:
model(32)
[44]:
{'prior_non_batchable_context': None,
 'prior_batchable_context': [array([0.79273277, 1.19422818, 1.8356426 , 1.65536612, 1.66737236]),
  array([1.63899447, 0.78282046, 1.93784751, 0.52241938, 0.51631685]),
  array([1.57341413, 0.53056712, 1.95345   , 0.62165196, 0.88358885]),
  array([1.7735128 , 1.93760023, 1.26856752, 1.82266109, 1.35856702]),
  array([1.7642493 , 0.50893892, 1.36037249, 0.73453132, 1.4185381 ]),
  array([1.31400376, 0.52861479, 1.03784459, 0.78415361, 0.9818053 ]),
  array([0.9626923 , 1.79112886, 1.83302327, 0.51940978, 1.68399211]),
  array([1.38447241, 0.96412365, 1.95212989, 0.73187025, 1.39411552]),
  array([1.73864575, 0.86433807, 0.58017974, 1.66084845, 1.88251074]),
  array([0.50834923, 1.78449442, 1.10429682, 1.83873072, 1.28765027]),
  array([1.4612476 , 1.44064022, 1.05470228, 0.74205375, 1.89993871]),
  array([0.53692873, 1.7255384 , 1.34814899, 1.33890243, 1.66728212]),
  array([0.55507505, 1.91271914, 1.7601826 , 1.4366108 , 0.99143708]),
  array([1.7363727 , 1.94404746, 1.24606924, 1.64436361, 1.32605177]),
  array([1.19001474, 1.72364784, 0.7571364 , 0.51278269, 0.98517823]),
  array([0.70912082, 0.97846952, 0.97789085, 1.25335047, 1.71318018]),
  array([1.561133  , 1.81205325, 0.62117482, 0.76486556, 1.25085506]),
  array([0.9153956 , 1.34123268, 1.66536784, 1.18155245, 1.31185419]),
  array([1.58072607, 1.49153524, 1.02329835, 0.60158568, 1.26436594]),
  array([0.54885802, 1.93349388, 1.26269767, 0.80025799, 0.67018478]),
  array([0.65062283, 0.75510436, 1.56262499, 0.51539649, 1.50702123]),
  array([0.72655852, 1.60896135, 1.61539548, 1.35312166, 0.68479443]),
  array([0.82312277, 1.13805441, 1.4357591 , 0.60865695, 1.50138526]),
  array([0.67371253, 1.68551019, 0.50195836, 0.98952698, 1.10839255]),
  array([0.59616488, 0.87993009, 1.16696925, 1.1575394 , 1.52839044]),
  array([1.40925058, 0.76338648, 1.84697278, 1.79234631, 0.72216377]),
  array([0.71679982, 1.62121116, 1.98863156, 1.19901313, 0.68982972]),
  array([0.76622924, 1.09068599, 1.74265555, 0.85182953, 0.97659312]),
  array([1.23821346, 0.5883582 , 1.44357305, 1.62515295, 1.90535686]),
  array([1.50328286, 1.869089  , 0.54773477, 1.63725377, 0.88249361]),
  array([1.91966603, 0.84828301, 1.61642387, 0.73959121, 0.69404266]),
  array([1.61294399, 1.92870438, 0.56384108, 1.54334376, 1.44812709])],
 'prior_draws': array([[4.64013552e-01, 1.86108103e-01, 6.40624525e+00, 2.60712680e+01,
         2.28771795e+00],
        [3.57316965e-01, 1.04070080e-01, 9.82143804e+00, 9.13290881e+00,
         2.78264920e+01],
        [3.98988532e-01, 1.47156776e-01, 7.97633507e+00, 5.86530716e+01,
         1.46262521e+01],
        [4.11281481e-01, 1.58254351e-01, 7.02923570e+00, 2.86856991e+01,
         1.29861493e+00],
        [2.66347542e-01, 1.11065110e-01, 8.85287956e+00, 1.60891672e+01,
         2.38097779e-01],
        [5.44546164e-01, 1.52254816e-01, 9.92608495e+00, 4.91941643e+01,
         3.98480767e+00],
        [2.42984414e-01, 1.36995535e-01, 7.18819983e+00, 1.17901976e+00,
         2.20802689e+00],
        [3.49357591e-01, 1.47552742e-01, 7.33676256e+00, 6.49177183e+01,
         1.69501168e+00],
        [4.60733215e-01, 1.26578025e-01, 6.83697731e+00, 1.08608620e+01,
         5.32095974e+00],
        [2.07239336e-01, 1.11499139e-01, 9.49131592e+00, 3.17556703e+01,
         4.00099726e+00],
        [4.57865951e-01, 1.43670491e-01, 8.19443426e+00, 2.76158553e+01,
         7.48902308e+00],
        [3.48766034e-01, 1.13697390e-01, 6.23146281e+00, 2.04618107e+01,
         1.96676654e+00],
        [4.47920298e-01, 9.33645322e-02, 6.73633681e+00, 2.34573019e+01,
         8.00400048e-01],
        [2.34712524e-01, 1.46690029e-01, 6.87645506e+00, 2.55977245e+01,
         5.62580242e-01],
        [4.20477921e-01, 1.65193158e-01, 7.33094183e+00, 9.54935679e+01,
         7.62615631e+00],
        [2.06968702e-01, 1.96726791e-01, 1.12803453e+01, 5.52843739e+01,
         7.16787017e-01],
        [2.57031500e-01, 1.11479484e-01, 6.04365037e+00, 8.21914859e+00,
         2.12034243e+00],
        [3.52434023e-01, 1.30163985e-01, 6.70672868e+00, 4.75099992e+01,
         2.47319798e+00],
        [4.39683711e-01, 1.18464173e-01, 9.06789380e+00, 1.07887112e+02,
         9.71049887e+00],
        [8.09018257e-01, 1.38853914e-01, 1.10543283e+01, 1.22983853e+01,
         2.03830325e+00],
        [3.71972370e-01, 1.18878471e-01, 9.78834669e+00, 4.90503186e+01,
         2.26918440e-01],
        [1.03141926e+00, 1.23849343e-01, 8.39898327e+00, 2.53016756e+01,
         2.44119811e-01],
        [2.78335657e-01, 1.26974497e-01, 8.19208776e+00, 2.86954568e+01,
         4.55377379e+00],
        [2.98353491e-01, 1.32482395e-01, 5.34847923e+00, 2.53719043e+01,
         5.31371441e+00],
        [2.56852258e-01, 1.20998145e-01, 8.10951131e+00, 5.35923178e+01,
         1.16220817e-01],
        [4.79468038e-01, 9.20171142e-02, 6.94595852e+00, 4.07514076e+01,
         6.86768545e+00],
        [7.56452310e-01, 1.09943554e-01, 7.41290310e+00, 3.47291918e+01,
         2.77781661e+00],
        [6.06987661e-01, 9.53063253e-02, 4.80580765e+00, 1.15499895e+01,
         8.86260599e-01],
        [7.60495314e-01, 1.02037761e-01, 6.60685571e+00, 3.98080057e+01,
         1.47510115e+00],
        [3.62318011e-01, 1.02132464e-01, 1.08624619e+01, 1.98415414e+01,
         5.09614999e-01],
        [3.12722391e-01, 1.04109665e-01, 8.27064447e+00, 1.36605262e+01,
         7.00399372e-01],
        [3.71681539e-01, 1.35283854e-01, 1.05783098e+01, 5.68456311e+01,
         6.71377591e+00]]),
 'sim_non_batchable_context': None,
 'sim_batchable_context': None,
 'sim_data': array([[[     65],
         [     59],
         [     41],
         [     77],
         [    431],
         [     53],
         [     63],
         [     81],
         [    276],
         [    603],
         [    815],
         [    309],
         [    587],
         [    393]],

        [[     19],
         [     31],
         [     68],
         [     47],
         [     71],
         [     96],
         [    118],
         [    101],
         [    147],
         [    224],
         [    242],
         [    271],
         [    431],
         [    552]],

        [[    196],
         [    170],
         [    237],
         [    220],
         [    316],
         [    274],
         [    359],
         [    514],
         [    499],
         [   1070],
         [    854],
         [    854],
         [   1607],
         [   1125]],

        [[     34],
         [      7],
         [     99],
         [    155],
         [     65],
         [     42],
         [    286],
         [    446],
         [    458],
         [    341],
         [     62],
         [    845],
         [    939],
         [   1160]],

        [[    199],
         [      2],
         [      0],
         [      0],
         [      0],
         [     16],
         [      0],
         [     35],
         [    146],
         [      0],
         [     11],
         [    192],
         [      3],
         [     59]],

        [[    631],
         [    788],
         [   1591],
         [   1537],
         [   1728],
         [   2078],
         [   4238],
         [   7347],
         [   8031],
         [   4575],
         [  10637],
         [  15657],
         [  21789],
         [  54596]],

        [[      1],
         [      1],
         [      0],
         [      1],
         [      1],
         [      0],
         [      1],
         [      0],
         [      0],
         [      2],
         [      4],
         [      5],
         [      0],
         [      5]],

        [[     18],
         [     73],
         [    139],
         [     50],
         [    261],
         [    178],
         [    124],
         [     22],
         [    119],
         [    402],
         [     35],
         [    452],
         [     54],
         [   1438]],

        [[     14],
         [     37],
         [     49],
         [     81],
         [    135],
         [    132],
         [    190],
         [    155],
         [    271],
         [    520],
         [    274],
         [    646],
         [    301],
         [   1594]],

        [[     20],
         [      6],
         [      9],
         [     14],
         [     18],
         [     14],
         [      9],
         [     14],
         [     16],
         [      6],
         [     17],
         [     64],
         [     35],
         [     43]],

        [[     61],
         [    146],
         [    171],
         [    190],
         [    208],
         [    178],
         [    584],
         [    402],
         [    579],
         [    538],
         [   1105],
         [   2786],
         [   1428],
         [   2553]],

        [[     33],
         [     48],
         [     90],
         [     56],
         [     37],
         [     28],
         [    129],
         [    205],
         [    136],
         [     75],
         [    191],
         [    345],
         [    201],
         [    523]],

        [[    189],
         [    209],
         [    325],
         [     61],
         [    351],
         [    971],
         [     30],
         [    839],
         [    368],
         [   2671],
         [     75],
         [     82],
         [     18],
         [    383]],

        [[      0],
         [      1],
         [      0],
         [      5],
         [      3],
         [      4],
         [     40],
         [      6],
         [     10],
         [      0],
         [     75],
         [     17],
         [     32],
         [      4]],

        [[    260],
         [    141],
         [    295],
         [    224],
         [    220],
         [    731],
         [    952],
         [    783],
         [    676],
         [   1175],
         [   2159],
         [   3436],
         [   2690],
         [   2090]],

        [[      2],
         [      4],
         [      5],
         [      7],
         [     67],
         [      4],
         [      1],
         [     21],
         [      7],
         [      1],
         [      4],
         [      1],
         [      1],
         [    103]],

        [[      4],
         [      4],
         [      2],
         [      7],
         [      1],
         [      5],
         [      7],
         [     14],
         [      5],
         [     10],
         [     23],
         [     10],
         [      9],
         [      7]],

        [[    106],
         [    124],
         [    103],
         [    261],
         [    135],
         [    172],
         [    358],
         [    486],
         [    226],
         [    634],
         [    683],
         [    220],
         [    539],
         [    836]],

        [[    537],
         [    428],
         [   1667],
         [    724],
         [   1068],
         [   2440],
         [   1994],
         [   2713],
         [   2312],
         [   3351],
         [  10179],
         [   6637],
         [  11109],
         [   8889]],

        [[   1209],
         [    946],
         [   1687],
         [   4500],
         [   2527],
         [  51677],
         [  60373],
         [ 172697],
         [  59939],
         [  54312],
         [ 172934],
         [ 351120],
         [ 211354],
         [1688718]],

        [[      0],
         [    201],
         [      0],
         [     54],
         [    747],
         [      1],
         [     29],
         [   1284],
         [    105],
         [      6],
         [  26763],
         [      0],
         [   9145],
         [    796]],

        [[      0],
         [    637],
         [   2028],
         [    355],
         [      5],
         [     74],
         [    111],
         [     98],
         [ 412235],
         [ 505812],
         [ 104365],
         [ 642601],
         [2283709],
         [      9]],

        [[     22],
         [     31],
         [     13],
         [     42],
         [      8],
         [     22],
         [     19],
         [     49],
         [     68],
         [     70],
         [    151],
         [    151],
         [    177],
         [    186]],

        [[     23],
         [     30],
         [     16],
         [     20],
         [      8],
         [     35],
         [     55],
         [     37],
         [     30],
         [     45],
         [     65],
         [     52],
         [     50],
         [     57]],

        [[      0],
         [      0],
         [     17],
         [      2],
         [      0],
         [      0],
         [      0],
         [      0],
         [      0],
         [     28],
         [     58],
         [      1],
         [      0],
         [      0]],

        [[     93],
         [    110],
         [    371],
         [    511],
         [    569],
         [    401],
         [    521],
         [   1553],
         [   2403],
         [   4062],
         [   6370],
         [   7030],
         [   5142],
         [   8303]],

        [[    235],
         [    732],
         [   2095],
         [   2204],
         [   2815],
         [   5885],
         [   8649],
         [   9178],
         [  34872],
         [  24155],
         [ 222351],
         [ 154588],
         [ 190584],
         [ 429498]],

        [[     37],
         [     12],
         [    109],
         [      2],
         [    276],
         [     44],
         [    796],
         [    376],
         [    987],
         [    554],
         [    708],
         [   7028],
         [   2105],
         [  11029]],

        [[    103],
         [     97],
         [   2024],
         [   8968],
         [  14304],
         [  16022],
         [   3503],
         [  45105],
         [  64522],
         [  12040],
         [ 403525],
         [  96459],
         [ 252234],
         [2110304]],

        [[     47],
         [     28],
         [      7],
         [     23],
         [    430],
         [     18],
         [    471],
         [     38],
         [     85],
         [     12],
         [    560],
         [      2],
         [    599],
         [   1054]],

        [[     21],
         [     32],
         [      3],
         [      8],
         [      1],
         [     69],
         [    165],
         [      8],
         [     74],
         [     11],
         [     54],
         [     17],
         [     34],
         [    138]],

        [[    214],
         [     90],
         [    193],
         [    539],
         [    266],
         [    474],
         [    643],
         [    912],
         [   1060],
         [    672],
         [   1553],
         [   1603],
         [   3160],
         [   2078]]], dtype=int64)}

Defining the Configurator

As you know by now, the configurator is the part that connects your simulator outputs to the neural inference architecture. Notice how this time the return dictionary has three keys:

1. summary_conditions - these will be passed through the summary network before going into the inference network;
2. direct_conditions - these will be passed directly to the inference network;
3. parameters - these are the quantities we wish to perform posterior inference on.

[45]:
def configure_input(forward_dict):
    """ Function to configure the simulated quantities (i.e., simulator outputs)
    into a neural network-friendly (BayesFlow) format.
    """

    # Prepare placeholder dict
    out_dict = {}

    # Convert data to logscale
    logdata = np.log1p(forward_dict['sim_data']).astype(np.float32)

    # Extract scaling parameter and convert to array
    alphas = np.array(forward_dict['prior_batchable_context']).astype(np.float32)

    # Extract prior draws and z-standardize with previously computed means
    params = forward_dict['prior_draws'].astype(np.float32)
    params = (params - prior_means) / prior_stds

    # Remove individual simulations containing nan, inf, or -inf values from the batch
    idx_keep = np.all(np.isfinite(logdata), axis=(1, 2))
    if not np.all(idx_keep):
        print('Invalid value encountered...removing from batch')

    # Add to keys
    out_dict['summary_conditions'] = logdata[idx_keep]
    out_dict['direct_conditions'] = alphas[idx_keep]
    out_dict['parameters'] = params[idx_keep]

    return out_dict

We can always do a quick check that our configurator works well with the outputs of the generative model:

[46]:
_ = configure_input(model(batch_size = 2))
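It can also help to print the shapes the networks will actually consume (an optional check, added here):

# Inspect the configured shapes for a small batch
conf_out = configure_input(model(batch_size=2))
for k, v in conf_out.items():
    print(k, v.shape)
# Expected (up to dropped invalid simulations):
# summary_conditions (2, 14, 1), direct_conditions (2, 5), parameters (2, 5)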

Defining the Neural Approximator

[163]:
# TODO: define the summary network (a hypothetical sketch follows below)
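The cell above was left as a TODO in the original notebook. Since the next cell expects a SummaryNet class mapping each \((T, 1)\) time series to n_summary learned features, here is a minimal hypothetical sketch (the previous tutorial may define it differently), using a causal convolution followed by a GRU:

class SummaryNet(tf.keras.Model):
    """Hypothetical stand-in: compresses each (T, 1) time series into
    n_summary learned summary statistics."""

    def __init__(self, n_summary=16):
        super().__init__()
        self.conv = tf.keras.layers.Conv1D(32, kernel_size=3, padding='causal', activation='relu')
        self.gru = tf.keras.layers.GRU(32)
        self.out = tf.keras.layers.Dense(n_summary)

    def call(self, x, **kwargs):
        # x: (batch_size, T, 1) -> (batch_size, n_summary)
        return self.out(self.gru(self.conv(x)))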
[165]:
summary_net = SummaryNet(n_summary=16)
inference_net = InvertibleNetwork({'n_params': 5,
                                   'n_coupling_layers': 3})
amortizer = AmortizedPosterior(inference_net, summary_net)

Defining the Trainer

[166]:
trainer = Trainer(amortizer=amortizer,
                  generative_model=model,
                  configurator=configure_input,
                  checkpoint_path='Alpha_Scaling_SIR')
INFO:root:Initializing networks from scratch.
INFO:root:Performing a consistency check with provided components...
INFO:root:Done.

Training Phase

[30]:
# h = trainer.train_online(epochs=20, iterations_per_epoch=1000, batch_size=32)

Validation

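The original notebook leaves this section empty. One lightweight option, sketched here under the assumption that the amortizer has been trained, is simulation-based calibration (SBC): if the amortized posteriors are well calibrated, the rank of each ground-truth parameter among its posterior draws should be uniformly distributed:

import matplotlib.pyplot as plt

# Hypothetical SBC sketch: flat rank histograms indicate calibration
n_sbc, n_post = 200, 100
ranks = np.zeros((n_sbc, 5))
for i in range(n_sbc):
    sim = configure_input(model(batch_size=1))
    while sim['parameters'].shape[0] == 0:  # re-draw if the simulation was invalid
        sim = configure_input(model(batch_size=1))
    post = amortizer.sample(
        {'summary_conditions': sim['summary_conditions'],
         'direct_conditions': sim['direct_conditions']},
        n_samples=n_post, to_numpy=True
    )
    # Rank of the true parameter among its posterior draws
    ranks[i] = np.sum(post < sim['parameters'], axis=0)

f, axarr = plt.subplots(1, 5, figsize=(20, 4))
for j, ax in enumerate(axarr):
    ax.hist(ranks[:, j], bins=10)
    ax.set_title(f'Parameter {j + 1}')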

Inference Phase
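The cells below use obs_data and the helper plot_median_predictions, both of which come from the previous tutorial and are not defined in this notebook. A minimal hypothetical stand-in is sketched here; note in particular the un-standardization step, which assumes the posterior draws live on the z-standardized scale used during training:

import matplotlib.pyplot as plt

obs_data = config['obs_data']

def plot_median_predictions(samples, obs_data):
    """Hypothetical stand-in: re-simulates trajectories from posterior
    draws and plots their pointwise median against the reported cases."""

    # Assumption: draws are z-standardized, so map them back first
    params = samples * prior_stds + prior_means
    sims = np.array([stationary_SIR(p, config['N'], config['T']).squeeze()
                     for p in params])
    f, ax = plt.subplots(figsize=(8, 4))
    ax.plot(np.median(sims, axis=0), label='Median predicted cases')
    ax.plot(obs_data, linestyle='dashed', label='Reported cases')
    ax.set_xlabel('Days since March 1, 2020')
    ax.set_ylabel('New reported cases')
    ax.legend()
    return f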

[168]:
# Posterior under the original prior (alpha = 1)
samples = amortizer.sample(
    {'summary_conditions': np.log1p(obs_data).astype(np.float32)[np.newaxis, :, np.newaxis],
     'direct_conditions' : np.ones((1, 5)).astype(np.float32)},
    n_samples=3000, to_numpy=True
)
samples = samples[np.sum(samples < 0, axis=1) == 0]  # discard draws with negative parameters
f = plot_median_predictions(samples, obs_data)
[Figure: median predicted new cases vs. reported cases, alpha = 1]
[114]:
# Posterior under a sharpened prior (alpha = 2)
samples = amortizer.sample(
    {'summary_conditions': np.log1p(obs_data).astype(np.float32)[np.newaxis, :, np.newaxis],
     'direct_conditions' : 2 * np.ones((1, 5)).astype(np.float32)},
    n_samples=1000, to_numpy=True
)
samples = samples[np.sum(samples < 0, axis=1) == 0]
f = plot_median_predictions(samples, obs_data)
[Figure: median predicted new cases vs. reported cases, alpha = 2]
[115]:
# Posterior under a widened prior (alpha = 0.5)
samples = amortizer.sample(
    {'summary_conditions': np.log1p(obs_data).astype(np.float32)[np.newaxis, :, np.newaxis],
     'direct_conditions' : 0.5 * np.ones((1, 5)).astype(np.float32)},
    n_samples=1000, to_numpy=True
)
samples = samples[np.sum(samples < 0, axis=1) == 0]
f = plot_median_predictions(samples, obs_data)
[Figure: median predicted new cases vs. reported cases, alpha = 0.5]
[169]:
param_names = [r'$\lambda$', r'$\mu$', r'$D$', r'$I_0$', r'$\psi$']
[225]:
samples_alpha = []
for alpha in [0.5, 1, 2]:
    samples = amortizer.sample(
        {'summary_conditions': np.log1p(obs_data).astype(np.float32)[np.newaxis, :, np.newaxis],
         'direct_conditions' : alpha * np.ones((1, 5)).astype(np.float32)},
        n_samples=5000, to_numpy=True
    )
    samples = pd.DataFrame(samples[np.sum(samples < 0, axis=1) == 0], columns=param_names)
    samples_alpha.append(samples)
[226]:
df_all = pd.concat(samples_alpha, axis=0)
[227]:
df_all['Scaling factor'] = [r'$\alpha = 0.5$'] * len(samples_alpha[0]) \
            + [r'$\alpha = 1.0$'] * len(samples_alpha[1]) \
            + [r'$\alpha = 2.0$'] * len(samples_alpha[2])
[228]:
import seaborn as sns

def build_viridis_palette(n, n_total=20, base_palette="viridis"):
    """
    Builds a palette of n evenly spaced colors drawn from the given base palette.
    """
    color_palette = np.array(sns.color_palette(base_palette, n_colors=n_total))
    indices = np.array(np.floor(np.linspace(0, n_total-1, n)), dtype=np.int32)
    color_palette = color_palette[indices]
    return [tuple(c) for c in color_palette]
[229]:
colors = build_viridis_palette(21, base_palette="plasma")
[230]:
color_codes = {
    r'$\alpha = 0.5$' : colors[0],
    r'$\alpha = 1.0$' : colors[8],
    r'$\alpha = 2.0$' : colors[16],

}
sns.palplot(color_codes.values())
[Figure: preview of the three selected palette colors]
[262]:
import matplotlib.pyplot as plt
from scipy import stats

def corrfunc(x, y, **kws):
    """Annotates the current pair-plot panel with the Pearson correlation."""
    r, _ = stats.pearsonr(x, y)
    ax = plt.gca()
    ax.annotate(r"$\rho$ = {:.2f}".format(r),
                xy=(.25, .5), xycoords=ax.transAxes, fontsize=20)


plt.rcParams['font.size'] = 18
grid = sns.PairGrid(df_all, height = 2, hue='Scaling factor', palette=color_codes)
grid = grid.map_diag(sns.histplot, alpha=0.9, fill=True)
grid = grid.map_lower(sns.kdeplot, n_levels=10, cut=0, bw_method='silverman', alpha=0.8)
# grid = grid.map_upper(corrfunc)

for i, j in zip(*np.triu_indices_from(grid.axes, 1)):
    grid.axes[i, j].axis('off')

for i in range(len(param_names)):
    grid.axes[i, i].axvline(df_all.iloc[:, i].mean(), color='black', linestyle='dashed')

grid.add_legend()
plt.setp(grid._legend.get_title(), fontsize=20)
plt.setp(grid._legend.get_texts(), fontsize=18)
[262]:
[None, None, None, None, None, None]
[Figure: pair plot of the joint posterior across the three scaling factors]
[263]:
grid.savefig("Initial_Pairs.png", dpi=300)
[267]:
colors = ['#8c87d6', '#d68787', '#d6d387']
legends = [r'$\alpha = 0.5$', r'$\alpha = 1.0$', r'$\alpha = 2.0$']
f, axarr = plt.subplots(1, 5, figsize=(25, 5))
for samples, legend, color in zip(samples_alpha, legends, colors):
    for i, ax in enumerate(axarr):
        sns.kdeplot(samples.values[:, i], ax=ax, alpha=0.7, label=legend, fill=True, color=color)
        sns.despine(ax=ax)

for i, ax in enumerate(axarr):
    if i == 0:
        ax.legend()
    ax.set_title(param_names[i])
    ax.set_xlabel('Parameter value')
f.tight_layout()
[Figure: marginal posterior densities per parameter across the three scaling factors]
[268]:
f.savefig("Initial_Marginal.png", dpi=300)