Skip to content

Commit

Permalink
Merge pull request #69 from elseml/Development
Browse files Browse the repository at this point in the history
Add plot_confusion_matrix for model comparison
  • Loading branch information
stefanradev93 committed Apr 18, 2023
2 parents fe2a055 + f950a48 commit c1ac1c3
Showing 1 changed file with 88 additions and 7 deletions.
95 changes: 88 additions & 7 deletions bayesflow/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D
from scipy.stats import binom, median_abs_deviation
from sklearn.metrics import r2_score
from sklearn.metrics import confusion_matrix, r2_score

logging.basicConfig()

Expand Down Expand Up @@ -992,10 +993,11 @@ def plot_latent_space_2d(z_samples, height=2.5, color="#8f2727", **kwargs):


def plot_calibration_curves(
true_models, pred_models, model_names=None, num_bins=10, font_size=12, fig_size=(12, 4), color="#8f2727"
true_models, pred_models, model_names=None, num_bins=10, font_size=12, fig_size=None, color="#8f2727"
):
"""Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities
for a model comparison problem. Depends on the ``expected_calibration_error`` function for computing the ECE.
for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin.
Depends on the ``expected_calibration_error`` function for computing the ECE.
Parameters
----------
Expand All @@ -1009,8 +1011,8 @@ def plot_calibration_curves(
The number of bins to use for the calibration curves (and marginal histograms).
font_size : int, optional, default: 12
The font size of the axis label texts.
fig_size : tuple, optional, default: (12, 4)
The figure size passed to the matplotlib constructor.
fig_size : tuple or None, optional, default: None
The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
color : str, optional, default: '#8f2727'
The color of the plot.
Expand All @@ -1026,10 +1028,11 @@ def plot_calibration_curves(
# Determine n_subplots dynamically
n_row = int(np.ceil(num_models / 6))
n_col = int(np.ceil(num_models / n_row))

cal_errs, cal_probs = expected_calibration_error(true_models, pred_models, num_bins)

# Initialize figure
if fig_size is None:
fig_size = (int(5 * n_col), int(5 * n_row))
f, axarr = plt.subplots(n_row, n_col, figsize=fig_size)
if n_row > 1:
ax = axarr.flat
Expand All @@ -1049,7 +1052,7 @@ def plot_calibration_curves(
# Plot PMP distribution over bins
uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
norm_weights = np.ones_like(pred_models) / len(pred_models)
ax[j].hist(pred_models[:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.1)
ax[j].hist(pred_models[:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)

# Tweak plot
ax[j].spines["right"].set_visible(False)
Expand All @@ -1075,3 +1078,81 @@ def plot_calibration_curves(
ax[j].set_title(model_names[j])
f.tight_layout()
return f


def plot_confusion_matrix(
true_models,
pred_models,
model_names=None,
fig_size=(5, 5),
title_fontsize=18,
tick_fontsize=12,
normalize=True,
cmap=None,
title=True,
):
"""Plots a confusion matrix for validating a neural network trained on Bayesian model comparison.
Parameters
----------
true_models : np.ndarray of shape (num_data_sets, num_models)
The one-hot-encoded true model indices per data set.
pred_models : np.ndarray of shape (num_data_sets, num_models)
The predicted posterior model probabilities (PMPs) per data set.
model_names : list or None, optional, default: None
The model names for nice plot titles. Inferred if None.
fig_size : tuple or None, optional, default: (5, 5)
The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
title_fontsize : int, optional, default: 18
The font size of the axis label texts.
tick_fontsize : int, optional, default: 12
The font size of the axis label texts.
normalize : bool, optional, default: True
A flag for normalization of the confusion matrix.
If True, each row of the confusion matrix is normalized to sum to 1.
cmap : matplotlib.colors.Colormap or str, optional, default: None
Colormap to be used for the cells. If a str, it should be the name of a registered colormap,
e.g., 'viridis'. Default colormap matches the BayesFlow defaults by ranging from white to red.
title : bool, optional, default True
A flag for adding 'Confusion Matrix' above the matrix.
"""

if model_names is None:
num_models = true_models.shape[-1]
model_names = [rf"$M_{{{m}}}$" for m in range(1, num_models + 1)]

if cmap is None:
cmap = LinearSegmentedColormap.from_list("", ["white", "#8f2727"])

# Flatten input
true_models = np.argmax(true_models, axis=1)
pred_models = np.argmax(pred_models, axis=1)

# Compute confusion matrix
cm = confusion_matrix(true_models, pred_models)

if normalize:
cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

# Initialize figure
f, ax = plt.subplots(1, 1, figsize=fig_size)

im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
ax.figure.colorbar(im, ax=ax, shrink=0.7)

ax.set(xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]))
ax.set_xticklabels(model_names, fontsize=tick_fontsize)
ax.set_yticklabels(model_names, fontsize=tick_fontsize)
ax.set_xlabel("Predicted model", fontsize=tick_fontsize)
ax.set_ylabel("True model", fontsize=tick_fontsize)

# Loop over data dimensions and create text annotations
fmt = ".2f" if normalize else "d"
thresh = cm.max() / 2.0
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(
j, i, format(cm[i, j], fmt), ha="center", va="center", color="white" if cm[i, j] > thresh else "black"
)
if title:
ax.set_title("Confusion Matrix", fontsize=title_fontsize)

0 comments on commit c1ac1c3

Please sign in to comment.