Skip to content

Commit

Permalink
Merge pull request #68 from vpratz/hierarchical
Browse files Browse the repository at this point in the history
Small adaptations to hierarchical, merged Development branch
  • Loading branch information
stefanradev93 committed Apr 18, 2023
2 parents 6482c52 + f06f279 commit cc191d5
Show file tree
Hide file tree
Showing 13 changed files with 1,643 additions and 54 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,5 @@ General Improvements:
1. Improved docstrings and consistent use of keyword arguments vs. configuration dictionaries
2. Increased focus on transformer-based architectures as summary networks
3. Figures resulting ``diagnostics.py`` have been improved and prettified
4. Multiple bugfixes
4. Added a module ``sensitivity.py`` for testing the sensitivity of neural approximators to model misspecification
5. Multiple bugfixes
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ The amortizer knows how to combine its losses.

### References and Further Reading

- Schmitt, M., Bürkner P. C., Köthe U., & Radev S. T. (2022). Detecting Model
- Schmitt, M., Bürkner P. C., Köthe U., & Radev S. T. (2021). Detecting Model
Misspecification in Amortized Bayesian Inference with Neural Networks. <em>ArXiv
preprint</em>, available for free at: https://arxiv.org/abs/2112.08866

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

# Add easy access imports
try:
from bayesflow import amortizers, default_settings, diagnostics, losses, networks, trainers
from bayesflow import amortizers, default_settings, diagnostics, losses, networks, trainers, sensitivity
except ImportError as err:
logger = logging.getLogger()
logger.setLevel(logging.WARNING)
Expand Down
85 changes: 85 additions & 0 deletions bayesflow/computational_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sklearn.calibration import calibration_curve

from bayesflow.default_settings import MMD_BANDWIDTH_LIST
from bayesflow.exceptions import ShapeError


def compute_jacobian_trace(function, inputs, **kwargs):
Expand Down Expand Up @@ -350,3 +351,87 @@ def simultaneous_ecdf_bands(
L = stats.binom(N, z).ppf(gamma / 2) / N
U = stats.binom(N, z).ppf(1 - gamma / 2) / N
return alpha, z, L, U


def mean_squared_error(x_true, x_pred):
"""Computes the mean squared error between a single true value and M estimates thereof.
x_true : np.ndarray
true values, shape ()
x_pred : np.ndarray
predicted values, shape (M, )
"""
x_true = np.array(x_true)
x_pred = np.array(x_pred)
try:
return np.mean((x_true[np.newaxis, :] - x_pred) ** 2)
except IndexError:
return np.mean((x_true - x_pred) ** 2)


def root_mean_squared_error(x_true, x_pred):
"""Computes the mean squared error between a single true value and M estimates thereof.
x_true : np.ndarray
true values, shape ()
x_pred : np.ndarray
predicted values, shape (M, )
"""

mse = mean_squared_error(x_true=x_true, x_pred=x_pred)
return np.sqrt(mse)


def aggregated_error(x_true, x_pred, inner_error_fun=root_mean_squared_error, outer_aggregation_fun=np.mean):
"""Computes the aggregated error between a vector of N true values and M estimates of each true value.
x_true : np.ndarray
true values, shape (N)
x_pred : np.ndarray
predicted values, shape (M, N)
inner_error_fun: callable, default: root_mean_squared_error
computes the error between one true value and M estimates thereof
outer_aggregation_fun: callable, default: np.mean
aggregates N errors to a single aggregated error value
"""

x_true, x_pred = np.array(x_true), np.array(x_pred)

N = x_pred.shape[0]
if not N == x_true.shape[0]:
raise ShapeError

errors = np.array([inner_error_fun(x_true=x_true[i], x_pred=x_pred[i]) for i in range(N)])

if not N == errors.shape[0]:
raise ShapeError

return outer_aggregation_fun(errors)


def aggregated_rmse(x_true, x_pred):
"""
Computes the aggregated RMSE for a matrix of predictions.
Parameters
----------
x_true : np.ndarray
true values, shape (N)
x_pred : np.ndarray
predicted values, shape (M, N)
Returns
-------
aggregated RMSE
"""

return aggregated_error(x_true=x_true,
x_pred=x_pred,
inner_error_fun=root_mean_squared_error,
outer_aggregation_fun=np.mean)
27 changes: 17 additions & 10 deletions bayesflow/coupling_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,30 +330,38 @@ def _calculate_spline(self, target, spline_params, inverse=False):
# Extract all learnable parameters
left_edge, bottom_edge, widths, heights, derivatives = spline_params

# Placeholders for results
result = tf.zeros_like(target)
log_jac = tf.zeros_like(target)

total_width = tf.reduce_sum(widths, axis=-1, keepdims=True)
total_height = tf.reduce_sum(heights, axis=-1, keepdims=True)

knots_x = tf.concat([left_edge, left_edge + tf.math.cumsum(widths, axis=-1)], axis=-1)
knots_y = tf.concat([bottom_edge, bottom_edge + tf.math.cumsum(heights, axis=-1)], axis=-1)

log_jac = tf.zeros_like(target)
result = tf.zeros_like(target)

# Determine which targets are in domain and which are not
target_in_domain = tf.logical_and(knots_x[..., 0] < target, target <= knots_x[..., -1])
if not inverse:
target_in_domain = tf.logical_and(knots_x[..., 0] < target, target <= knots_x[..., -1])
higher_indices = tf.searchsorted(knots_x, target[..., None])
else:
target_in_domain = tf.logical_and(knots_y[..., 0] < target, target <= knots_y[..., -1])
higher_indices = tf.searchsorted(knots_y, target[..., None])
target_in = target[target_in_domain]
target_in_idx = tf.where(target_in_domain)
target_out = target[~target_in_domain]
target_out_idx = tf.where(~target_in_domain)
higher_indices = tf.searchsorted(knots_x, target[..., None])

# In-domain computation
if tf.size(target_in_idx) > 0:
# Index crunching
higher_indices = tf.gather_nd(higher_indices, target_in_idx)
higher_indices = tf.cast(higher_indices, tf.int32)
lower_indices = higher_indices - 1
lower_idx_tuples = tf.concat([tf.cast(target_in_idx, tf.int32), lower_indices], axis=-1)
higher_idx_tuples = tf.concat([tf.cast(target_in_idx, tf.int32), higher_indices], axis=-1)

# Spline computation
dk = tf.gather_nd(derivatives, lower_idx_tuples)
dkp = tf.gather_nd(derivatives, higher_idx_tuples)
xk = tf.gather_nd(knots_x, lower_idx_tuples)
Expand All @@ -371,10 +379,9 @@ def _calculate_spline(self, target, spline_params, inverse=False):
numerator = dy * (sk * xi**2 + dk * xi * (1 - xi))
denominator = sk + (dkp + dk - 2 * sk) * xi * (1 - xi)
result_in = yk + numerator / denominator
# Log Jacobian for in-domain
numerator = sk**2 * (dkp * xi**2 + 2 * sk * xi * (1 - xi) + dk * (1 - xi) ** 2)
denominator = (sk + (dkp + dk - 2 * sk) * xi * (1 - xi)) ** 2

# Jacobian for in-domain points
log_jac_in = tf.math.log(numerator + 1e-10) - tf.math.log(denominator + 1e-10)
log_jac = tf.tensor_scatter_nd_update(log_jac, target_in_idx, log_jac_in)
# Inverse pass
Expand All @@ -398,7 +405,7 @@ def _calculate_spline(self, target, spline_params, inverse=False):

if not inverse:
result_out = scale_out * target_out[..., None] + shift_out
# Jacobian for out-of-domain points
# Log Jacobian for out-of-domain points
log_jac_out = tf.math.log(scale_out + 1e-10)
log_jac_out = tf.squeeze(log_jac_out, axis=-1)
log_jac = tf.tensor_scatter_nd_update(log_jac, target_out_idx, log_jac_out)
Expand Down Expand Up @@ -460,8 +467,8 @@ def _constrain_parameters(self, parameters):
bottom_edge = bottom_edge + self.default_domain[2]

# Compute default widths and heights
default_width = self.default_domain[1] - self.default_domain[0]
default_height = self.default_domain[3] - self.default_domain[2]
default_width = (self.default_domain[1] - self.default_domain[0]) / self.bins
default_height = (self.default_domain[3] - self.default_domain[2]) / self.bins

# Compute shifts for softplus function
xshift = tf.math.log(tf.math.exp(default_width) - 1)
Expand Down
4 changes: 2 additions & 2 deletions bayesflow/default_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ def __init__(self, meta_dict: dict, mandatory_fields: list = []):
"mc_dropout": False,
"dropout": True,
"residual": False,
"dropout_prob": 0.1,
"dropout_prob": 0.05,
"bins": 16,
"default_domain": (-10.0, 10.0, -10.0, 10.0),
"default_domain": (-5.0, 5.0, -5.0, 5.0),
},
mandatory_fields=[],
)
Expand Down
Loading

0 comments on commit cc191d5

Please sign in to comment.