Merge pull request #71 from elseml/Development
Add model comparison tutorial notebooks
stefanradev93 committed May 9, 2023
2 parents 168d961 + 4473772 commit a0dc5de
Showing 7 changed files with 2,160 additions and 3 deletions.
77 changes: 76 additions & 1 deletion README.md
@@ -11,6 +11,8 @@ For starters, check out some of our walk-through notebooks:
2. [Principled Bayesian workflow for cognitive models](docs/source/tutorial_notebooks/LCA_Model_Posterior_Estimation.ipynb)
3. [Posterior estimation for ODEs](docs/source/tutorial_notebooks/Linear_ODE_system.ipynb)
4. [Posterior estimation for SIR-like models](docs/source/tutorial_notebooks/Covid19_Initial_Posterior_Estimation.ipynb)
5. [Model comparison for cognitive models](docs/source/tutorial_notebooks/Model_Comparison_MPT.ipynb)
6. [Hierarchical model comparison for cognitive models](docs/source/tutorial_notebooks/Hierarchical_Model_Comparison_MPT.ipynb)

## Project Documentation

@@ -177,7 +179,76 @@ preprint</em>, available for free at: https://arxiv.org/abs/2112.08866

## Model Comparison

BayesFlow can be used not only for parameter estimation, but also for Bayesian model comparison, approximating posterior model probabilities or Bayes factors.
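
For a data set $x$ and candidate models $M_1, \dots, M_J$, the posterior model probabilities follow from Bayes' rule, and Bayes factors are ratios of marginal likelihoods:

$$
p(M_j \mid x) = \frac{p(x \mid M_j)\,p(M_j)}{\sum_{k=1}^{J} p(x \mid M_k)\,p(M_k)}, \qquad \mathrm{BF}_{12} = \frac{p(x \mid M_1)}{p(x \mid M_2)}
$$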

Let's extend the minimal example from before with a second model $M_2$ that we want to compare with our original model $M_1$:

```python
import numpy as np

def simulator(theta, n_obs=50, scale=1.0):
    # Generate n_obs noisy Gaussian observations around the parameter vector theta
    return np.random.default_rng().normal(loc=theta, scale=scale, size=(n_obs, theta.shape[0]))

def prior_m1(D=2, mu=0., sigma=1.0):
    # Model 1: Gaussian prior over theta centered at mu = 0
    return np.random.default_rng().normal(loc=mu, scale=sigma, size=D)

def prior_m2(D=2, mu=2., sigma=1.0):
    # Model 2: identical to model 1, except the prior is centered at mu = 2
    return np.random.default_rng().normal(loc=mu, scale=sigma, size=D)
```
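
To see the moving parts at a glance, here is a quick sanity draw from model 1 (the shapes follow from the defaults above):

```python
theta = prior_m1()    # parameter draw of shape (2,)
x = simulator(theta)  # simulated data set of shape (50, 2)
print(theta.shape, x.shape)
```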

We create both models as before and use a `MultiGenerativeModel` wrapper to combine them into a `meta_model`:

```python
model_m1 = bf.simulation.GenerativeModel(prior_m1, simulator, simulator_is_batched=False)
model_m2 = bf.simulation.GenerativeModel(prior_m2, simulator, simulator_is_batched=False)
meta_model = bf.simulation.MultiGenerativeModel([model_m1, model_m2])
```

Next, we build our neural approximator: a `DeepSet` summary network feeding into a `PMPNetwork` that outputs posterior model probabilities:

```python
summary_net = bf.networks.DeepSet()
probability_net = bf.networks.PMPNetwork(num_models=2)
amortizer = bf.amortizers.AmortizedModelComparison(probability_net, summary_net)
```

We combine all previous steps with a `Trainer` instance and train the neural approximator:

```python
trainer = bf.trainers.Trainer(amortizer=amortizer, generative_model=meta_model)
losses = trainer.train_online(epochs=3, iterations_per_epoch=100, batch_size=32)
```
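
To check that training has converged, we can inspect the loss trajectory; a quick sketch, assuming the `plot_losses` helper in `bf.diagnostics`:

```python
fig = bf.diagnostics.plot_losses(losses)
```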

Let's simulate data sets from our models to check our network's performance:

```python
sim_data = trainer.configurator(meta_model(5000))
sim_indices = sim_data["model_indices"]
```

When feeding the data to our trained network, we almost immediately obtain posterior model probabilities for each of the 5000 data sets:

```python
sim_preds = amortizer(sim_data)
```
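
Assuming equal prior model probabilities (the uniform model prior we rely on here), posterior model probabilities translate directly into Bayes factors. A minimal sketch, assuming `sim_preds` holds one probability row per data set with columns ordered as $(M_1, M_2)$:

```python
import numpy as np

probs = np.asarray(sim_preds)
# With equal prior model probabilities, the Bayes factor reduces to the
# ratio of the posterior model probabilities:
bf_12 = probs[:, 0] / probs[:, 1]
```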

How good are these predicted probabilities? We can have a look at the calibration:

```python
cal_curves = bf.diagnostics.plot_calibration_curves(sim_indices, sim_preds)
```

<img src="img/showcase_calibration_curves.png" width=65% height=65%>

Our approximator shows excellent calibration: the calibration curve aligns closely with the diagonal, the expected calibration error (ECE) is near 0, and most predicted probabilities decisively identify the model underlying a data set.
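
For intuition, the ECE summarizes how far predicted confidence deviates from observed accuracy. Here is a minimal sketch of the standard binned estimator, assuming integer true model indices and a NumPy matrix of predicted probabilities (the library's own implementation may differ):

```python
import numpy as np

def expected_calibration_error(m_true, m_prob, num_bins=10):
    """Average |accuracy - confidence| over equally spaced confidence bins."""
    m_true, m_prob = np.asarray(m_true), np.asarray(m_prob)
    conf = m_prob.max(axis=1)                                  # confidence in the predicted model
    correct = (m_prob.argmax(axis=1) == m_true).astype(float)  # 1 if prediction matches truth
    bins = np.linspace(0.0, 1.0, num_bins + 1)
    ece = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        mask = (conf > lo) & (conf <= hi)                      # data sets falling into this bin
        if mask.any():
            ece += mask.mean() * abs(correct[mask].mean() - conf[mask].mean())
    return ece

# If model indices are one-hot encoded, convert them first:
# ece = expected_calibration_error(sim_indices.argmax(axis=1), sim_preds)
```

We can further assess patterns of misclassification with a confusion matrix: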

```python
conf_matrix = bf.diagnostics.plot_confusion_matrix(sim_indices, sim_preds)
```

<img src="img/showcase_confusion_matrix.png" width=44% height=44%>

For the vast majority of simulated data sets, the generating model is correctly detected. With these diagnostic results backing us up, we can safely apply our trained network to empirical data.

BayesFlow is also able to conduct model comparison for hierarchical models. See this [tutorial notebook](docs/source/tutorial_notebooks/Hierarchical_Model_Comparison_MPT.ipynb) for an introduction to the associated workflow.

### References and Further Reading

@@ -190,6 +261,10 @@ doi:10.1109/TNNLS.2021.3124052 available for free at: https://arxiv.org/abs/2004
Bayesian Model Comparison. <em>ArXiv preprint</em>, available for free at:
https://arxiv.org/abs/2210.07278

- Elsemüller, L., Schnuerch, M., Bürkner, P. C., & Radev, S. T. (2023). A Deep
Learning Method for Comparing Bayesian Hierarchical Models. <em>ArXiv preprint</em>,
available for free at: https://arxiv.org/abs/2301.11873

## Likelihood emulation

Example coming soon...
14 changes: 12 additions & 2 deletions bayesflow/diagnostics.py
@@ -1107,6 +1107,8 @@ def plot_confusion_matrix(
     fig_size=(5, 5),
     title_fontsize=18,
     tick_fontsize=12,
+    xtick_rotation=None,
+    ytick_rotation=None,
     normalize=True,
     cmap=None,
     title=True,
@@ -1124,9 +1126,13 @@
     fig_size : tuple or None, optional, default: (5, 5)
         The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
     title_fontsize : int, optional, default: 18
-        The font size of the axis label texts.
+        The font size of the title text.
     tick_fontsize : int, optional, default: 12
-        The font size of the axis label texts.
+        The font size of the axis label and model name texts.
+    xtick_rotation: int, optional, default: None
+        Rotation of x-axis tick labels (helps with long model names).
+    ytick_rotation: int, optional, default: None
+        Rotation of y-axis tick labels (helps with long model names).
     normalize : bool, optional, default: True
         A flag for normalization of the confusion matrix.
         If True, each row of the confusion matrix is normalized to sum to 1.
@@ -1165,7 +1171,11 @@

     ax.set(xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]))
     ax.set_xticklabels(model_names, fontsize=tick_fontsize)
+    if xtick_rotation:
+        plt.xticks(rotation=xtick_rotation, ha="right")
     ax.set_yticklabels(model_names, fontsize=tick_fontsize)
+    if ytick_rotation:
+        plt.yticks(rotation=ytick_rotation)
     ax.set_xlabel("Predicted model", fontsize=tick_fontsize)
     ax.set_ylabel("True model", fontsize=tick_fontsize)

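The new rotation arguments help when model names are long; a quick usage sketch, reusing the simulated quantities from the README example above (the model names below are hypothetical):

```python
import bayesflow as bf

fig = bf.diagnostics.plot_confusion_matrix(
    sim_indices,
    sim_preds,
    model_names=["Hierarchical model (M1)", "Non-hierarchical model (M2)"],  # hypothetical labels
    xtick_rotation=45,  # rotates and right-aligns long x tick labels
    ytick_rotation=45,
)
```
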
958 changes: 958 additions & 0 deletions docs/source/tutorial_notebooks/Hierarchical_Model_Comparison_MPT.ipynb

Large diffs are not rendered by default.

1,114 changes: 1,114 additions & 0 deletions docs/source/tutorial_notebooks/Model_Comparison_MPT.ipynb

Large diffs are not rendered by default.

Binary file added docs/source/tutorial_notebooks/img/1HT2HT.png
Binary file added img/showcase_calibration_curves.png
Binary file added img/showcase_confusion_matrix.png
