Source code for bayesflow.diagnostics

# Copyright (c) 2022 The BayesFlow Developers

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from scipy.stats import binom
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import pandas as pd
import seaborn as sns

import logging
logging.basicConfig()

from bayesflow.computational_utilities import expected_calibration_error, simultaneous_ecdf_bands
from bayesflow.helper_classes import LossHistory
from bayesflow.helper_functions import check_posterior_prior_shapes


def plot_recovery(post_samples, prior_samples, point_agg=np.mean, uncertainty_agg=np.std,
                  param_names=None, fig_size=None, label_fontsize=14, title_fontsize=16,
                  metric_fontsize=16, add_corr=True, add_r2=True, color='#8f2727',
                  n_col=None, n_row=None):
    """Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty.
    The point estimate can be controlled with the `point_agg` argument, and the uncertainty estimate
    can be controlled with the `uncertainty_agg` argument.

    This plot yields the same information as the "posterior z-score":

    https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html

    Important: Posterior aggregates play no special role in Bayesian inference and should only
    be used heuristically. For instance, in the case of multi-modal posteriors, common point
    estimates, such as mean, (geometric) median, or maximum a posteriori (MAP) mean nothing.

    Parameters
    ----------
    post_samples      : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws obtained from n_data_sets
    prior_samples     : np.ndarray of shape (n_data_sets, n_params)
        The prior draws (true parameters) obtained for generating the n_data_sets
    point_agg         : callable, optional, default: np.mean
        The function to apply to the posterior draws to get a point estimate for each marginal.
        It is called as `point_agg(post_samples, axis=1)`.
    uncertainty_agg   : callable or None, optional, default: np.std
        The function to apply to the posterior draws to get an uncertainty estimate.
        It is called as `uncertainty_agg(post_samples, axis=1)`.
        If `None` provided, a simple scatter will be plotted.
    param_names       : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None
    fig_size          : tuple or None, optional, default : None
        The figure size passed to the matplotlib constructor. Inferred if None.
    label_fontsize    : int, optional, default: 14
        The font size of the x-label and y-label texts
    title_fontsize    : int, optional, default: 16
        The font size of the title text
    metric_fontsize   : int, optional, default: 16
        The font size of the goodness-of-fit metric (if provided)
    add_corr          : boolean, optional, default: True
        A flag for adding correlation between true and estimates to the plot.
    add_r2            : boolean, optional, default: True
        A flag for adding R^2 between true and estimates to the plot.
    color             : str, optional, default: '#8f2727'
        The color for the true vs. estimated scatter points and errorbars.
    n_col             : int or None, optional, default: None
        The number of subplot columns. If only `n_row` is given, it is derived from
        `n_params` and `n_row`. If both are None, rows of at most 6 subplots are used.
    n_row             : int or None, optional, default: None
        The number of subplot rows. If only `n_col` is given, it is derived from
        `n_params` and `n_col`. If both `n_row` and `n_col` are given and the grid has
        fewer cells than parameters, only the first `n_row * n_col` parameters are plotted.

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of `post_samples` and `prior_samples`.
    """

    # Sanity check
    check_posterior_prior_shapes(post_samples, prior_samples)

    # Compute point estimates and (optionally) uncertainties over the draws axis
    est = point_agg(post_samples, axis=1)
    if uncertainty_agg is not None:
        u = uncertainty_agg(post_samples, axis=1)

    # Determine n params and param names if None given
    n_params = prior_samples.shape[-1]
    if param_names is None:
        param_names = [f'$p_{i}$' for i in range(1, n_params+1)]

    # Determine number of rows and columns for subplots based on inputs
    if n_row is None and n_col is None:
        n_row = int(np.ceil(n_params / 6))
        n_col = int(np.ceil(n_params / n_row))
    elif n_row is None and n_col is not None:
        n_row = int(np.ceil(n_params / n_col))
    elif n_row is not None and n_col is None:
        n_col = int(np.ceil(n_params / n_row))

    # Initialize figure
    if fig_size is None:
        fig_size = (int(4 * n_col), int(4 * n_row))
    f, axarr = plt.subplots(n_row, n_col, figsize=fig_size)

    # plt.subplots returns a bare Axes for a 1x1 grid and an array otherwise;
    # normalize both cases to a flat 1D array of Axes
    axes = np.atleast_1d(axarr).ravel()

    for i, ax in enumerate(axes[:n_params]):

        # Add scatter and errorbars
        if uncertainty_agg is not None:
            ax.errorbar(prior_samples[:, i], est[:, i], yerr=u[:, i],
                        fmt='o', alpha=0.5, color=color)
        else:
            ax.scatter(prior_samples[:, i], est[:, i], alpha=0.5, color=color)

        # Make plots quadratic to avoid visual illusions
        lower = min(prior_samples[:, i].min(), est[:, i].min())
        upper = max(prior_samples[:, i].max(), est[:, i].max())
        eps = (upper - lower) * 0.1
        ax.set_xlim([lower - eps, upper + eps])
        ax.set_ylim([lower - eps, upper + eps])

        # Identity line: points on it are perfectly recovered
        ax.plot([ax.get_xlim()[0], ax.get_xlim()[1]],
                [ax.get_ylim()[0], ax.get_ylim()[1]],
                color='black', alpha=0.9, linestyle='dashed')

        # Add labels, optional metrics and title
        ax.set_xlabel('Ground truth', fontsize=label_fontsize)
        ax.set_ylabel('Estimated', fontsize=label_fontsize)
        if add_r2:
            r2 = r2_score(prior_samples[:, i], est[:, i])
            ax.text(0.1, 0.9, '$R^2$ = {:.3f}'.format(r2),
                    horizontalalignment='left',
                    verticalalignment='center',
                    transform=ax.transAxes,
                    size=metric_fontsize)
        if add_corr:
            corr = np.corrcoef(prior_samples[:, i], est[:, i])[0, 1]
            ax.text(0.1, 0.8, '$r$ = {:.3f}'.format(corr),
                    horizontalalignment='left',
                    verticalalignment='center',
                    transform=ax.transAxes,
                    size=metric_fontsize)
        ax.set_title(param_names[i], fontsize=title_fontsize)

        # Prettify
        sns.despine(ax=ax)
        ax.grid(alpha=0.5)

    # Grid cells without a parameter would otherwise show up as empty framed axes
    for ax in axes[n_params:]:
        ax.remove()

    f.tight_layout()
    return f
def plot_sbc_ecdf(post_samples, prior_samples, difference=False, stacked=False, fig_size=None,
                  param_names=None, label_fontsize=14, legend_fontsize=14, title_fontsize=16,
                  rank_ecdf_color='#a34f4f', fill_color='grey', **kwargs):
    """Creates the empirical CDFs for each marginal rank distribution and plots it against
    a uniform ECDF. ECDF simultaneous bands are drawn using simulations from the uniform.
    Inspired by:

    [1] Säilynoja, T., Bürkner, P. C., & Vehtari, A. (2022). Graphical test for discrete uniformity
    and its applications in goodness-of-fit evaluation and multiple sample comparison.
    Statistics and Computing, 32(2), 1-21. https://arxiv.org/abs/2103.10522

    For models with many parameters, use `stacked=True` to obtain an idea of the overall
    calibration of a posterior approximator.

    Parameters
    ----------
    post_samples      : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws obtained from n_data_sets
    prior_samples     : np.ndarray of shape (n_data_sets, n_params)
        The prior draws obtained for generating n_data_sets
    difference        : boolean, optional, default: False
        If `True`, plots the ECDF difference. Enables a more dynamic visualization range.
    stacked           : boolean, optional, default: False
        If `True`, all ECDFs will be plotted on the same plot. If `False`, each ECDF will
        have its own subplot, similar to the behavior of `plot_sbc_histograms`.
    fig_size          : tuple or None, optional, default: None
        The figure size passed to the matplotlib constructor. Inferred if None
        (only when `stacked=False`; otherwise None is passed through to matplotlib).
    param_names       : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None.
        Only relevant if `stacked=False`.
    label_fontsize    : int, optional, default: 14
        The font size of the x-label and y-label texts
    legend_fontsize   : int, optional, default: 14
        The font size of the legend text
    title_fontsize    : int, optional, default: 16
        The font size of the title text. Only relevant if `stacked=False`
    rank_ecdf_color   : str, optional, default: '#a34f4f'
        The color to use for the rank ECDFs
    fill_color        : str, optional, default: 'grey'
        The color of the fill arguments.
    **kwargs          : dict, optional, default: {}
        Keyword arguments can be passed to control the behavior of ECDF simultaneous band
        computation through the `ecdf_bands_kwargs` dictionary. See `simultaneous_ecdf_bands`
        for keyword arguments. Other keyword arguments are ignored.

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of `post_samples` and `prior_samples`.
    """

    # Sanity checks
    check_posterior_prior_shapes(post_samples, prior_samples)

    # Store reference to number of parameters
    n_params = post_samples.shape[-1]

    # Compute fractional ranks (using broadcasting): share of posterior draws below the true value
    ranks = np.sum(post_samples < prior_samples[:, np.newaxis, :], axis=1) / post_samples.shape[1]

    # Prepare figure
    if stacked:
        f, ax = plt.subplots(1, 1, figsize=fig_size)
        axes = [ax]
    else:
        # Determine n_subplots dynamically
        n_row = int(np.ceil(n_params / 6))
        n_col = int(np.ceil(n_params / n_row))

        # Determine fig_size dynamically, if None
        if fig_size is None:
            fig_size = (int(5*n_col), int(5*n_row))

        # Initialize figure; plt.subplots returns a bare Axes for a 1x1 grid,
        # so normalize to a flat 1D array to make the single-parameter case work
        f, ax = plt.subplots(n_row, n_col, figsize=fig_size)
        axes = np.atleast_1d(ax).ravel()

    # Plot individual ecdf of parameters
    for j in range(ranks.shape[-1]):

        xx = np.sort(ranks[:, j])
        yy = np.arange(1, xx.shape[-1]+1)/float(xx.shape[-1])

        # Difference, if specified
        if difference:
            yy -= xx

        if stacked:
            # Only label the first curve so the legend has a single ECDF entry
            if j == 0:
                axes[0].plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label='Rank ECDFs')
            else:
                axes[0].plot(xx, yy, color=rank_ecdf_color, alpha=0.95)
        else:
            axes[j].plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label='Rank ECDF')

    # Compute uniform ECDF and bands
    alpha, z, L, H = simultaneous_ecdf_bands(post_samples.shape[0], **kwargs.pop('ecdf_bands_kwargs', {}))

    # Difference, if specified
    if difference:
        L -= z
        H -= z

    # Determine titles (a single untitled plot if stacked)
    if stacked:
        titles = [None]
    elif param_names is None:
        titles = [f'$p_{i}$' for i in range(1, n_params+1)]
    else:
        titles = param_names

    # Add simultaneous bounds and prettify
    for _ax, title in zip(axes, titles):
        _ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=fr'{int((1-alpha) * 100)}$\%$ Confidence Bands')

        # Prettify plot
        sns.despine(ax=_ax)
        _ax.grid(alpha=0.35)
        _ax.legend(fontsize=legend_fontsize)
        _ax.set_xlabel('Fractional rank statistic', fontsize=label_fontsize)
        if difference:
            ylab = 'ECDF difference'
        else:
            ylab = 'ECDF'
        _ax.set_ylabel(ylab, fontsize=label_fontsize)
        _ax.set_title(title, fontsize=title_fontsize)

    # Grid cells without a parameter would otherwise show up as empty framed axes
    if not stacked:
        for _ax in axes[n_params:]:
            _ax.remove()

    f.tight_layout()
    return f
def plot_sbc_histograms(post_samples, prior_samples, param_names=None, fig_size=None,
                        num_bins=None, binomial_interval=0.99, label_fontsize=14,
                        title_fontsize=16, hist_color='#a34f4f'):
    """Creates and plots publication-ready histograms of rank statistics for simulation-based
    calibration (SBC) checks according to:

    [1] Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018).
    Validating Bayesian inference algorithms with simulation-based calibration.
    arXiv preprint arXiv:1804.06788.

    Any deviation from uniformity indicates miscalibration and thus poor convergence
    of the networks or poor combination between generative model / networks.

    Parameters
    ----------
    post_samples      : np.ndarray of shape (n_data_sets, n_post_draws, n_params)
        The posterior draws obtained from n_data_sets
    prior_samples     : np.ndarray of shape (n_data_sets, n_params)
        The prior draws obtained for generating n_data_sets
    param_names       : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None
    fig_size          : tuple or None, optional, default : None
        The figure size passed to the matplotlib constructor. Inferred if None.
    num_bins          : int or None, optional, default: None
        The number of bins to use for each marginal histogram. If None, it is set to half
        the (integer) ratio of simulations to posterior draws, falling back to 5 bins
        when that heuristic yields fewer than two bins.
    binomial_interval : float in (0, 1), optional, default: 0.99
        The width of the confidence interval for the binomial distribution
    label_fontsize    : int, optional, default: 14
        The font size of the x-label text
    title_fontsize    : int, optional, default: 16
        The font size of the title text
    hist_color        : str, optional, default '#a34f4f'
        The color to use for the histogram body

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of `post_samples` and `prior_samples`.
    """

    # Sanity check
    check_posterior_prior_shapes(post_samples, prior_samples)

    # Determine the ratio of simulations to posterior draws
    n_sim, n_draws, n_params = post_samples.shape
    ratio = int(n_sim / n_draws)

    # Log a warning if N/B ratio recommended by Talts et al. (2018) < 20
    if ratio < 20:
        # NOTE: this raises the root logger to INFO so that the message is actually emitted
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        logger.info(f'The ratio of simulations / posterior draws should be > 20 ' +
                    f'for reliable variance reduction, but your ratio is {ratio}. ' +
                    f'Confidence intervals might be unreliable!')

    # Set n_bins automatically, if nothing provided
    if num_bins is None:
        num_bins = int(ratio / 2)
        # Attempt a fix if zero or a single bin is determined so the plot still makes sense
        # (zero bins would also cause a division by zero in the binomial probability below)
        if num_bins <= 1:
            num_bins = 5

    # Determine n params and param names if None given
    if param_names is None:
        param_names = [f'$p_{i}$' for i in range(1, n_params+1)]

    # Determine n_subplots dynamically
    n_row = int(np.ceil(n_params / 6))
    n_col = int(np.ceil(n_params / n_row))

    # Initialize figure
    if fig_size is None:
        fig_size = (int(5 * n_col), int(5 * n_row))
    f, axarr = plt.subplots(n_row, n_col, figsize=fig_size)

    # Compute ranks (using broadcasting): number of posterior draws below the true value
    ranks = np.sum(post_samples < prior_samples[:, np.newaxis, :], axis=1)

    # Compute confidence interval and mean
    N = int(prior_samples.shape[0])
    # uniform distribution expected -> for all bins: equal probability
    # p = 1 / num_bins that a rank lands in that bin
    endpoints = binom.interval(binomial_interval, N, 1 / num_bins)
    mean = N / num_bins  # corresponds to binom.mean(N, 1 / num_bins)

    # plt.subplots returns a bare Axes for a 1x1 grid and an array otherwise;
    # normalize both cases to a flat 1D array of Axes
    ax = np.atleast_1d(axarr).ravel()

    # Plot marginal histograms in a loop
    for j, name in enumerate(param_names):
        ax[j].axhspan(endpoints[0], endpoints[1], facecolor='gray', alpha=0.3)
        ax[j].axhline(mean, color='gray', zorder=0, alpha=0.5)
        sns.histplot(ranks[:, j], kde=False, ax=ax[j], color=hist_color, bins=num_bins, alpha=0.95)
        ax[j].set_title(name, fontsize=title_fontsize)
        ax[j].spines['right'].set_visible(False)
        ax[j].spines['top'].set_visible(False)
        ax[j].set_xlabel('Rank statistic', fontsize=label_fontsize)
        ax[j].get_yaxis().set_ticks([])
        ax[j].set_ylabel('')

    # Grid cells without a histogram would otherwise show up as empty framed axes
    for _ax in ax[len(param_names):]:
        _ax.remove()

    f.tight_layout()
    return f
def plot_posterior_2d(posterior_draws, prior=None, prior_draws=None, param_names=None, height=3,
                      legend_fontsize=14, post_color='#8f2727', prior_color='gray',
                      post_alpha=0.9, prior_alpha=0.7):
    """Generates a bivariate pairplot given posterior draws and optional prior or prior draws.

    Parameters
    ----------
    posterior_draws   : np.ndarray of shape (n_post_draws, n_params)
        The posterior draws obtained for a SINGLE observed data set.
    prior             : bayesflow.forward_inference.Prior instance or None, optional, default: None
        The optional prior object. It is called with the number of posterior draws and may
        return either the draws directly or a dict holding them under the key 'prior_draws'.
    prior_draws       : np.ndarray of shape (n_prior_draws, n_params) or None, optional, default: None
        The optional prior draws obtained from the prior. If both prior and prior_draws
        are provided, prior_draws will be used.
    param_names       : list or None, optional, default: None
        The parameter names for nice plot titles. If None, `prior.param_names` is used
        when available, otherwise generic names are generated.
    height            : float, optional, default: 3.
        The height of the pairplot.
    legend_fontsize   : int, optional, default: 14
        The font size of the legend text.
    post_color        : str, optional, default: '#8f2727'
        The color for the posterior histograms and KDEs.
    prior_color       : str, optional, default: gray
        The color for the optional prior histograms and KDEs.
    post_alpha        : float in [0, 1], optional, default: 0.9
        The opacity of the posterior plots.
    prior_alpha       : float in [0, 1], optional, default: 0.7
        The opacity of the prior plots.

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    AssertionError
        If the shape of posterior_draws is not 2-dimensional.
    """

    # Draws for a single data set must form a (draws, params) matrix
    assert len(posterior_draws.shape) == 2, 'Shape of `posterior_samples` for a single data set should be 2 dimensional!'

    n_draws, n_params = posterior_draws.shape

    # Explicit prior draws take precedence; only sample from the prior when none were passed
    if prior_draws is None and prior is not None:
        sampled = prior(n_draws)
        prior_draws = sampled['prior_draws'] if type(sampled) is dict else sampled

    # Resolve names: explicit argument > names stored on the prior > generic fallback
    if param_names is None:
        param_names = getattr(prior, 'param_names', None)
        if param_names is None:
            param_names = [f'$p_{i}$' for i in range(1, n_params+1)]

    # Posterior layer
    post_frame = pd.DataFrame(posterior_draws, columns=param_names)
    g = sns.PairGrid(post_frame, height=height)
    g.map_diag(sns.histplot, fill=True, color=post_color, alpha=post_alpha, kde=True)
    g.map_lower(sns.kdeplot, fill=True, color=post_color, alpha=post_alpha)

    # Prior layer, drawn behind the posterior by swapping the grid's data
    if prior_draws is not None:
        prior_frame = pd.DataFrame(prior_draws, columns=param_names)
        g.data = prior_frame
        g.map_diag(sns.histplot, fill=True, color=prior_color, alpha=prior_alpha, kde=True, zorder=-1)
        g.map_lower(sns.kdeplot, fill=True, color=prior_color, alpha=prior_alpha, zorder=-1)

    # Legend is only meaningful when both posterior and prior are shown
    if not (prior_draws is None and prior is None):
        post_handle = Line2D(xdata=[], ydata=[], color=post_color, lw=3, alpha=post_alpha)
        prior_handle = Line2D(xdata=[], ydata=[], color=prior_color, lw=3, alpha=prior_alpha)
        g.fig.legend([post_handle, prior_handle], ['Posterior', 'Prior'],
                     fontsize=legend_fontsize, loc='center right')

    # Hide the (unused) upper triangle
    upper_rows, upper_cols = np.triu_indices_from(g.axes, 1)
    for row, col in zip(upper_rows, upper_cols):
        g.axes[row, col].axis('off')

    # Add grids to every panel
    for row in range(n_params):
        for col in range(n_params):
            g.axes[row, col].grid(alpha=0.5)

    g.tight_layout()
    return g.fig
def plot_losses(history, fig_size=None, color='#8f2727', label_fontsize=14, title_fontsize=16):
    """A generic helper function to plot the losses of a series of training epochs and runs.
    One subplot row is created per column of the history.

    Parameters
    ----------
    history         : pd.DataFrame or bayesflow.LossHistory object
        The (plottable) history as returned by a train_[...] method of a `Trainer` instance.
        A `LossHistory` is converted via its `get_plottable()` method.
    fig_size        : tuple or None, optional, default: None
        The figure size passed to the matplotlib constructor. Inferred if None.
    color           : str, optional, default: '#8f2727'
        The color of the loss trajectories.
    label_fontsize  : int, optional, default: 14
        The font size of the x-label and y-label texts
    title_fontsize  : int, optional, default: 16
        The font size of the subplot titles

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving
    """

    # Handle non-pd.DataFrame type
    if type(history) is LossHistory:
        history = history.get_plottable()

    # One row of the figure per loss column
    loss_names = history.columns
    n_row = len(loss_names)

    if fig_size is None:
        fig_size = (16, int(4 * n_row))
    f, axarr = plt.subplots(n_row, 1, figsize=fig_size)

    # x-axis: 1-based training step counter
    steps = np.arange(1, len(history)+1)

    # A single row yields a bare Axes instead of an array
    if n_row == 1:
        panels = [axarr]
    else:
        panels = axarr.flat

    for col_idx, panel in enumerate(panels):
        panel.plot(steps, history.iloc[:, col_idx], color=color, lw=2)
        panel.set_xlabel('Training step #', fontsize=label_fontsize)
        panel.set_ylabel('Loss value', fontsize=label_fontsize)
        sns.despine(ax=panel)
        panel.grid(alpha=0.5)
        panel.set_title(loss_names[col_idx], fontsize=title_fontsize)

    f.tight_layout()
    return f
def plot_prior2d(prior, param_names=None, n_samples=2000, height=2.5, color='#8f2727', **kwargs):
    """Creates pairplots for a given joint prior.

    Parameters
    ----------
    prior       : callable
        The prior object which takes a single integer argument and generates random draws.
        It may return the draws directly or a dict holding them under the key 'prior_draws'.
    param_names : list of str or None, optional, default None
        An optional list of strings which are used (prefixed with 'Prior ') as the
        variable labels of the pair plot. Generic labels are generated if None.
    n_samples   : int, optional, default: 2000
        The number of random draws from the joint prior
    height      : float, optional, default: 2.5
        The height of the pair plot
    color       : str, optional, default : '#8f2727'
        The color of the plot
    **kwargs    : dict, optional
        Additional keyword arguments passed to the sns.PairGrid constructor

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving
    """

    # Generate prior draws
    prior_samples = prior(n_samples)

    # Handle dict type
    if isinstance(prior_samples, dict):
        prior_samples = prior_samples['prior_draws']

    # Get latent dimensionality and prepare titles
    dim = prior_samples.shape[-1]

    # Convert samples to a pandas data frame
    if param_names is None:
        titles = [f'Prior Param. {i}' for i in range(1, dim+1)]
    else:
        titles = [f'Prior {p}' for p in param_names]
    data_to_plot = pd.DataFrame(prior_samples, columns=titles)

    # Generate plots
    g = sns.PairGrid(data_to_plot, height=height, **kwargs)
    g.map_diag(sns.histplot, fill=True, color=color, alpha=0.9, kde=True)

    # Kernel density estimation (KDE) may not always be possible
    # (e.g. with parameters whose correlation is close to 1 or -1).
    # In this scenario, a scatter-plot is generated instead.
    try:
        g.map_lower(sns.kdeplot, fill=True, color=color, alpha=0.9)
    except Exception as e:
        # Deliberately broad: any KDE failure should degrade to a scatter plot, not abort plotting
        logging.warning("KDE failed due to the following exception:\n"+repr(e)+"\nSubstituting scatter plot.")
        g.map_lower(plt.scatter, alpha=0.6, s=40, edgecolor='k', color=color)
    g.map_upper(plt.scatter, alpha=0.6, s=40, edgecolor='k', color=color)

    # Add grids
    for i in range(dim):
        for j in range(dim):
            g.axes[i, j].grid(alpha=0.5)
    g.tight_layout()
    return g.fig
def plot_latent_space_2d(z_samples, height=2.5, color='#8f2727', **kwargs):
    """Creates pairplots for the latent space learned by the inference network. Enables
    visual inspection of the latent space and whether its structure corresponds to the
    one enforced by the optimization criterion.

    Parameters
    ----------
    z_samples   : np.ndarray or tf.Tensor of shape (n_sim, n_params)
        The latent samples computed through a forward pass of the inference network.
        Anything that is not exactly an np.ndarray is converted via its `.numpy()` method.
    height      : float, optional, default: 2.5
        The height of the pair plot.
    color       : str, optional, default : '#8f2727'
        The color of the plot
    **kwargs    : dict, optional
        Additional keyword arguments passed to the sns.PairGrid constructor

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving
    """

    # Non-ndarray inputs (e.g. tensors) are expected to expose .numpy()
    if type(z_samples) is not np.ndarray:
        z_samples = z_samples.numpy()

    # One labelled column per latent dimension
    z_dim = z_samples.shape[-1]
    column_labels = [f'Latent Dim. {k}' for k in range(1, z_dim+1)]
    latent_frame = pd.DataFrame(z_samples, columns=column_labels)

    # Histograms on the diagonal, KDEs below, raw scatter above
    g = sns.PairGrid(latent_frame, height=height, **kwargs)
    g.map_diag(sns.histplot, fill=True, color=color, alpha=0.9, kde=True)
    g.map_lower(sns.kdeplot, fill=True, color=color, alpha=0.9)
    g.map_upper(plt.scatter, alpha=0.6, s=40, edgecolor='k', color=color)

    # Add grids to every panel
    for row in range(z_dim):
        for col in range(z_dim):
            g.axes[row, col].grid(alpha=0.5)

    g.tight_layout()
    return g.fig
def plot_calibration_curves(m_true, m_pred, model_names=None, n_bins=10, font_size=12, fig_size=(12, 4)):
    """Plots the calibration curves and the ECE for a model comparison problem.
    Depends on the `expected_calibration_error` function for computing the ECE.

    Parameters
    ----------
    m_true      : array-like
        The true model indicators, passed through to `expected_calibration_error`.
        Presumably one-hot encoded with the same shape as `m_pred` - TODO confirm.
    m_pred      : array-like whose last axis has length n_models
        The predicted model probabilities, passed through to `expected_calibration_error`.
        Presumably of shape (n_data_sets, n_models) - TODO confirm.
    model_names : list or None, optional, default: None
        The model names for nice plot titles. Inferred if None.
    n_bins      : int, optional, default: 10
        The number of bins passed to `expected_calibration_error`.
    font_size   : int, optional, default: 12
        The font size of the ECE annotation.
    fig_size    : tuple, optional, default: (12, 4)
        The figure size passed to the matplotlib constructor.

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving
    """

    n_models = m_pred.shape[-1]
    if model_names is None:
        model_names = [fr'$M_{{{m}}}$' for m in range(1, n_models+1)]

    # Determine n_subplots dynamically
    n_row = int(np.ceil(n_models / 6))
    n_col = int(np.ceil(n_models / n_row))

    # Per-model calibration errors and curve coordinates
    cal_errs, cal_probs = expected_calibration_error(m_true, m_pred, n_bins)

    # Initialize figure
    f, axarr = plt.subplots(n_row, n_col, figsize=fig_size)

    # plt.subplots returns a bare Axes for a 1x1 grid and an array otherwise;
    # normalize both cases to a flat 1D array of Axes
    ax = np.atleast_1d(axarr).ravel()

    # Plot marginal calibration curves in a loop
    for j in range(n_models):
        # Plot calibration curve
        ax[j].plot(cal_probs[j][0], cal_probs[j][1])

        # Plot the diagonal over the full unit interval; using the autoscaled limits
        # here would truncate the line once the limits are fixed to [0, 1] below
        ax[j].plot([0, 1], [0, 1], '--', color='black')

        # Tweak plot
        ax[j].spines['right'].set_visible(False)
        ax[j].spines['top'].set_visible(False)
        ax[j].set_xlim([0, 1])
        ax[j].set_ylim([0, 1])
        ax[j].set_xlabel('Accuracy')
        ax[j].set_ylabel('Confidence')
        ax[j].set_xticks([0.2, 0.4, 0.6, 0.8, 1.0])
        ax[j].set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
        ax[j].text(0.1, 0.9, r'$\widehat{{ECE}}$ = {0:.3f}'.format(cal_errs[j]),
                   horizontalalignment='left',
                   verticalalignment='center',
                   transform=ax[j].transAxes,
                   size=font_size)

        # Set title
        ax[j].set_title(model_names[j])

    # Grid cells without a model would otherwise show up as empty framed axes
    for _ax in ax[n_models:]:
        _ax.remove()

    f.tight_layout()
    return f