Merge remote-tracking branch 'origin/Development' into Development
# Conflicts:
#	bayesflow/diagnostics.py
marvinschmitt committed May 9, 2023
2 parents 717c941 + 7bd3b41 commit 7bfe011
Showing 8 changed files with 2,172 additions and 734 deletions.
83 changes: 78 additions & 5 deletions README.md
@@ -8,9 +8,11 @@ Welcome to our BayesFlow library for efficient simulation-based Bayesian workflo
For starters, check out some of our walk-through notebooks:

1. [Quickstart amortized posterior estimation](docs/source/tutorial_notebooks/Intro_Amortized_Posterior_Estimation.ipynb)
3. [Principled Bayesian workflow for cognitive models](docs/source/tutorial_notebooks/LCA_Model_Posterior_Estimation.ipynb)
4. [Posterior estimation for ODEs](docs/source/tutorial_notebooks/Linear_ODE_system.ipynb)
5. [Posterior estimation for SIR-like models](docs/source/tutorial_notebooks/Covid19_Initial_Posterior_Estimation.ipynb)
6. [Model comparison for cognitive models](docs/source/tutorial_notebooks/Model_Comparison_MPT.ipynb)
7. [Hierarchical model comparison for cognitive models](docs/source/tutorial_notebooks/Hierarchical_Model_Comparison_MPT.ipynb)

## Project Documentation

@@ -167,7 +169,7 @@ to the `AmortizedPosterior` instance:
amortizer = bf.amortizers.AmortizedPosterior(inference_net, summary_net, summary_loss_fun='MMD')
```

The amortizer knows how to combine its losses, and you can inspect the summary space for outliers during inference.
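The MMD criterion regularizes the summary outputs toward a unit Gaussian, so atypical data sets show up as embeddings far from that reference distribution. A minimal sketch of such a check, where `x_sim` and `x_obs` are hypothetical arrays of simulated and observed data sets (not defined above) and `summary_net` is the trained summary network:

```python
import numpy as np

# `x_sim` and `x_obs` are placeholder arrays of shape (n_sets, n_obs, D)
z_sim = np.asarray(summary_net(x_sim))  # (n_sim, summary_dim)
z_obs = np.asarray(summary_net(x_obs))  # (n_obs_sets, summary_dim)

# With the MMD loss, summaries are pushed toward a unit Gaussian, so
# unusually large embedding norms hint at outliers or misspecification.
sim_norms = np.linalg.norm(z_sim, axis=-1)
obs_norms = np.linalg.norm(z_obs, axis=-1)
outlier_cutoff = np.quantile(sim_norms, 0.995)
print("Potential outliers:", np.where(obs_norms > outlier_cutoff)[0])
```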

### References and Further Reading

@@ -177,7 +179,74 @@ preprint</em>, available for free at: https://arxiv.org/abs/2112.08866

## Model Comparison

BayesFlow can be used not only for parameter estimation but also for approximate Bayesian model comparison via posterior model probabilities or Bayes factors.
Let's extend the minimal example from before with a second model $M_2$ that we want to compare with our original model $M_1$:

```python
def simulator(theta, n_obs=50, scale=1.0):
    """Generate n_obs Gaussian observations around the parameter vector theta."""
    return np.random.default_rng().normal(loc=theta, scale=scale, size=(n_obs, theta.shape[0]))

def prior_m1(D=2, mu=0.0, sigma=1.0):
    """Prior of M_1: Gaussian centered at mu = 0."""
    return np.random.default_rng().normal(loc=mu, scale=sigma, size=D)

def prior_m2(D=2, mu=2.0, sigma=1.0):
    """Prior of M_2: Gaussian centered at mu = 2."""
    return np.random.default_rng().normal(loc=mu, scale=sigma, size=D)
```

For the purpose of this illustration, the two toy models differ only with respect to their prior specification ($M_1: \mu = 0$, $M_2: \mu = 2$). We create both models as before and use a `MultiGenerativeModel` wrapper to combine them in a `meta_model`:

```python
model_m1 = bf.simulation.GenerativeModel(prior_m1, simulator, simulator_is_batched=False)
model_m2 = bf.simulation.GenerativeModel(prior_m2, simulator, simulator_is_batched=False)
meta_model = bf.simulation.MultiGenerativeModel([model_m1, model_m2])
```

Next, we construct our neural network with a `PMPNetwork` for approximating posterior model probabilities:

```python
summary_net = bf.networks.DeepSet()
probability_net = bf.networks.PMPNetwork(num_models=2)
amortizer = bf.amortizers.AmortizedModelComparison(probability_net, summary_net)
```

We combine all previous steps with a `Trainer` instance and train the neural approximator:

```python
trainer = bf.trainers.Trainer(amortizer=amortizer, generative_model=meta_model)
losses = trainer.train_online(epochs=3, iterations_per_epoch=100, batch_size=32)
```
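To sanity-check convergence before moving on, you can visualize the loss trajectory. A minimal sketch, assuming the `plot_losses` helper in `bf.diagnostics` accepts the history returned by `train_online` (check your version's signature):

```python
# Assumption: `losses` as returned by train_online can be passed directly;
# otherwise, plot the raw loss values yourself, e.g., via matplotlib.
fig = bf.diagnostics.plot_losses(losses)
```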

Let's simulate data sets from our models to check our networks' performance:

```python
sims = trainer.configurator(meta_model(5000))
```

When feeding the data to our trained network, we almost immediately obtain posterior model probabilities for each of the 5000 data sets:

```python
model_probs = amortizer.posterior_probs(sims)
```
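Posterior model probabilities also translate directly into Bayes factors. Assuming equal prior model probabilities (so that neither model is favored a priori), the Bayes factor of $M_1$ over $M_2$ is simply the ratio of the two posterior probabilities:

```python
import numpy as np

# Assumes equal prior model probabilities; otherwise, multiply by the
# inverse prior odds p(M_2) / p(M_1).
probs = np.asarray(model_probs)
bf_12 = probs[:, 0] / probs[:, 1]  # per-data-set evidence for M_1 over M_2
```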

How good are these predicted probabilities in the closed world, that is, when the true data-generating model is among the candidates? We can have a look at the calibration:

```python
cal_curves = bf.diagnostics.plot_calibration_curves(sims["model_indices"], model_probs)
```

<img src="img/showcase_calibration_curves.png" width=65% height=65%>

Our approximator shows excellent calibration: the calibration curve aligns closely with the diagonal, the expected calibration error (ECE) is near 0, and most predicted probabilities are close to certain about the model underlying a data set. For intuition, the ECE can be approximated by binning the predicted probabilities and comparing each bin's average confidence with its empirical accuracy; a minimal sketch follows (an illustration, not necessarily the library's exact implementation, and assuming one-hot `model_indices`):
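```python
import numpy as np

def expected_calibration_error(true_idx, probs, num_bins=10):
    """Binned ECE of the predicted (argmax) model."""
    probs = np.asarray(probs)
    conf = probs.max(axis=1)                    # confidence in the predicted model
    correct = probs.argmax(axis=1) == true_idx  # prediction matches true model?
    bins = np.linspace(0.0, 1.0, num_bins + 1)
    ece = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return ece

true_idx = np.asarray(sims["model_indices"]).argmax(axis=1)  # one-hot -> index
print(expected_calibration_error(true_idx, model_probs))
```

We can further assess patterns of misclassification with a confusion matrix: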

```python
conf_matrix = bf.diagnostics.plot_confusion_matrix(sims["model_indices"], model_probs)
```

<img src="img/showcase_confusion_matrix.png" width=44% height=44%>
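If your model names are long, the `xtick_rotation` and `ytick_rotation` arguments added to `plot_confusion_matrix` in this commit keep the labels readable (the model names below are purely illustrative):

```python
conf_matrix = bf.diagnostics.plot_confusion_matrix(
    sims["model_indices"],
    model_probs,
    model_names=["Gaussian prior at 0", "Gaussian prior at 2"],  # illustrative
    xtick_rotation=45,
)
```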

For the vast majority of simulated data sets, the "true" data-generating model is correctly identified. With these diagnostic results backing us up, we can proceed and apply our trained network to empirical data.
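A hedged sketch of that final step, assuming an observed data set `x_obs` shaped like the simulated ones and the default configurator key `summary_conditions` (both are assumptions, not shown in this commit):

```python
# `x_obs` and the input key are assumptions; adapt to your configurator.
obs_probs = amortizer.posterior_probs({"summary_conditions": x_obs})
```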

BayesFlow is also able to conduct model comparison for hierarchical models. See this [tutorial notebook](docs/source/tutorial_notebooks/Hierarchical_Model_Comparison_MPT.ipynb) for an introduction to the associated workflow.

### References and Further Reading

@@ -190,6 +259,10 @@ doi:10.1109/TNNLS.2021.3124052 available for free at: https://arxiv.org/abs/2004
Bayesian Model Comparison. <em>ArXiv preprint</em>, available for free at:
https://arxiv.org/abs/2210.07278

- Elsemüller, L., Schnuerch, M., Bürkner, P. C., & Radev, S. T. (2023). A Deep
Learning Method for Comparing Bayesian Hierarchical Models. <em>ArXiv preprint</em>,
available for free at: https://arxiv.org/abs/2301.11873

## Likelihood Emulation

Example coming soon...
29 changes: 22 additions & 7 deletions bayesflow/diagnostics.py
@@ -1036,7 +1036,7 @@ def plot_calibration_curves(
    Returns
    -------
    fig : plt.Figure - the figure instance for optional saving
    """

    num_models = true_models.shape[-1]
@@ -1051,7 +1051,7 @@ def plot_calibration_curves(
    # Initialize figure
    if fig_size is None:
        fig_size = (int(5 * n_col), int(5 * n_row))
    fig, axarr = plt.subplots(n_row, n_col, figsize=fig_size)
    if n_row > 1:
        ax = axarr.flat
@@ -1096,8 +1096,8 @@

        # Set title
        ax[j].set_title(model_names[j], fontsize=title_fontsize)
    fig.tight_layout()
    return fig


def plot_confusion_matrix(
@@ -1107,6 +1107,8 @@
    fig_size=(5, 5),
    title_fontsize=18,
    tick_fontsize=12,
    xtick_rotation=None,
    ytick_rotation=None,
    normalize=True,
    cmap=None,
    title=True,
@@ -1124,9 +1126,13 @@
    fig_size : tuple or None, optional, default: (5, 5)
        The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
    title_fontsize : int, optional, default: 18
        The font size of the title text.
    tick_fontsize : int, optional, default: 12
        The font size of the axis label and model name texts.
    xtick_rotation : int, optional, default: None
        Rotation of x-axis tick labels (helps with long model names).
    ytick_rotation : int, optional, default: None
        Rotation of y-axis tick labels (helps with long model names).
    normalize : bool, optional, default: True
        A flag for normalization of the confusion matrix.
        If True, each row of the confusion matrix is normalized to sum to 1.
@@ -1135,6 +1141,10 @@
        e.g., 'viridis'. Default colormap matches the BayesFlow defaults by ranging from white to red.
    title : bool, optional, default True
        A flag for adding 'Confusion Matrix' above the matrix.

    Returns
    -------
    fig : plt.Figure - the figure instance for optional saving
    """

    if model_names is None:
@@ -1154,14 +1164,18 @@
        cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

    # Initialize figure
    fig, ax = plt.subplots(1, 1, figsize=fig_size)

    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    ax.figure.colorbar(im, ax=ax, shrink=0.7)

    ax.set(xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]))
    ax.set_xticklabels(model_names, fontsize=tick_fontsize)
    if xtick_rotation:
        plt.xticks(rotation=xtick_rotation, ha="right")
    ax.set_yticklabels(model_names, fontsize=tick_fontsize)
    if ytick_rotation:
        plt.yticks(rotation=ytick_rotation)
    ax.set_xlabel("Predicted model", fontsize=tick_fontsize)
    ax.set_ylabel("True model", fontsize=tick_fontsize)

@@ -1175,6 +1189,7 @@
            )
    if title:
        ax.set_title("Confusion Matrix", fontsize=title_fontsize)
    return fig


def plot_mmd_hypothesis_test(mmd_null,