Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add n_row and n_col argument where applicable #109

Merged
merged 1 commit into from
Nov 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
add n_row and n_col argument where applicable
  • Loading branch information
LuSchumacher committed Nov 9, 2023
commit d22a996efc3ac7dbc3dbb8472799d5106e57e18f
63 changes: 53 additions & 10 deletions bayesflow/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def plot_recovery(
A flag for adding R^2 between true and estimates to the plot
color : str, optional, default: '#8f2727'
The color for the true vs. estimated scatter points and error bars
n_row : int, optional, default: None
The number of rows for the subplots. Dynamically determined if None.
n_col : int, optional, default: None
The number of columns for the subplots. Dynamically determined if None.
xlabel : str, optional, default: 'Ground truth'
The label on the x-axis of the plot
ylabel : str, optional, default: 'Estimated'
Expand Down Expand Up @@ -232,7 +236,7 @@ def plot_z_score_contraction(
tick_fontsize=12,
color="#8f2727",
n_col=None,
n_row=None,
n_row=None
):
"""Implements a graphical check for global model sensitivity by plotting the posterior
z-score over the posterior contraction for each set of posterior samples in ``post_samples``
Expand Down Expand Up @@ -279,6 +283,10 @@ def plot_z_score_contraction(
The font size of the axis ticklabels
color : str, optional, default: '#8f2727'
The color for the true vs. estimated scatter points and error bars
n_row : int, optional, default: None
The number of rows for the subplots. Dynamically determined if None.
n_col : int, optional, default: None
The number of columns for the subplots. Dynamically determined if None.

Returns
-------
Expand Down Expand Up @@ -379,6 +387,8 @@ def plot_sbc_ecdf(
tick_fontsize=12,
rank_ecdf_color="#a34f4f",
fill_color="grey",
n_row=None,
n_col=None,
**kwargs,
):
"""Creates the empirical CDFs for each marginal rank distribution and plots it against
Expand Down Expand Up @@ -419,6 +429,10 @@ def plot_sbc_ecdf(
The color to use for the rank ECDFs
fill_color : str, optional, default: 'grey'
The color of the fill arguments.
n_row : int, optional, default: None
The number of rows for the subplots. Dynamically determined if None.
n_col : int, optional, default: None
The number of columns for the subplots. Dynamically determined if None.
**kwargs : dict, optional, default: {}
Keyword arguments can be passed to control the behavior of ECDF simultaneous band computation
through the ``ecdf_bands_kwargs`` dictionary. See `simultaneous_ecdf_bands` for keyword arguments
Expand Down Expand Up @@ -447,9 +461,14 @@ def plot_sbc_ecdf(
n_row, n_col = 1, 1
f, ax = plt.subplots(1, 1, figsize=fig_size)
else:
# Determine n_subplots dynamically
n_row = int(np.ceil(n_params / 6))
n_col = int(np.ceil(n_params / n_row))
# Determine number of rows and columns for subplots based on inputs
if n_row is None and n_col is None:
n_row = int(np.ceil(n_params / 6))
n_col = int(np.ceil(n_params / n_row))
elif n_row is None and n_col is not None:
n_row = int(np.ceil(n_params / n_col))
elif n_row is not None and n_col is None:
n_col = int(np.ceil(n_params / n_row))

# Determine fig_size dynamically, if None
if fig_size is None:
Expand Down Expand Up @@ -543,6 +562,8 @@ def plot_sbc_histograms(
title_fontsize=18,
tick_fontsize=12,
hist_color="#a34f4f",
n_row=None,
n_col=None
):
"""Creates and plots publication-ready histograms of rank statistics for simulation-based calibration
(SBC) checks according to [1].
Expand Down Expand Up @@ -576,6 +597,10 @@ def plot_sbc_histograms(
The font size of the axis ticklabels
hist_color : str, optional, default '#a34f4f'
The color to use for the histogram body
n_row : int, optional, default: None
The number of rows for the subplots. Dynamically determined if None.
n_col : int, optional, default: None
The number of columns for the subplots. Dynamically determined if None.

Returns
-------
Expand Down Expand Up @@ -615,9 +640,14 @@ def plot_sbc_histograms(
if param_names is None:
param_names = [f"$\\theta_{{{i}}}$" for i in range(1, n_params + 1)]

# Determine n_subplots dynamically
n_row = int(np.ceil(n_params / 6))
n_col = int(np.ceil(n_params / n_row))
# Determine number of rows and columns for subplots based on inputs
if n_row is None and n_col is None:
n_row = int(np.ceil(n_params / 6))
n_col = int(np.ceil(n_params / n_row))
elif n_row is None and n_col is not None:
n_row = int(np.ceil(n_params / n_col))
elif n_row is not None and n_col is None:
n_col = int(np.ceil(n_params / n_row))

# Initialize figure
if fig_size is None:
Expand Down Expand Up @@ -1026,6 +1056,8 @@ def plot_calibration_curves(
epsilon=0.02,
fig_size=None,
color="#8f2727",
n_row=None,
n_col=None
):
"""Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities
for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin.
Expand Down Expand Up @@ -1055,6 +1087,10 @@ def plot_calibration_curves(
The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
color : str, optional, default: '#8f2727'
The color of the calibration curves
n_row : int, optional, default: None
The number of rows for the subplots. Dynamically determined if None.
n_col : int, optional, default: None
The number of columns for the subplots. Dynamically determined if None.

Returns
-------
Expand All @@ -1065,9 +1101,15 @@ def plot_calibration_curves(
if model_names is None:
model_names = [rf"$M_{{{m}}}$" for m in range(1, num_models + 1)]

# Determine n_subplots dynamically
n_row = int(np.ceil(num_models / 6))
n_col = int(np.ceil(num_models / n_row))
# Determine number of rows and columns for subplots based on inputs
if n_row is None and n_col is None:
n_row = int(np.ceil(num_models / 6))
n_col = int(np.ceil(num_models / n_row))
elif n_row is None and n_col is not None:
n_row = int(np.ceil(num_models / n_col))
elif n_row is not None and n_col is None:
n_col = int(np.ceil(num_models / n_row))


# Compute calibration
cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins)
Expand Down Expand Up @@ -1233,6 +1275,7 @@ def plot_confusion_matrix(
ax.set_title("Confusion Matrix", fontsize=title_fontsize)
return fig


def plot_mmd_hypothesis_test(
mmd_null,
mmd_observed=None,
Expand Down
Loading