Skip to content

Commit

Permalink
Update defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanradev93 committed May 22, 2023
1 parent 31a6251 commit 9768a4d
Show file tree
Hide file tree
Showing 3 changed files with 508 additions and 7 deletions.
173 changes: 173 additions & 0 deletions bayesflow/amortizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,179 @@ def _determine_loss(self, loss_fun):
)


class TwoLevelAmortizedPosterior(tf.keras.Model, AmortizedTarget):
"""An interface for estimating arbitrary two level hierarchical Bayesian models."""

def __init__(self, local_amortizer, global_amortizer, summary_net=None, **kwargs):
"""Creates an wrapper for estimating two-level hierarchical Bayesian models.
Parameters
----------
local_amortizer : bayesflow.amortizers.AmortizedPosterior
A posterior amortizer without a summary network which will estimate
the full conditional of the (varying numbers of) local parameter vectors.
global_amortizer : bayesflow.amortizers.AmortizedPosterior
A posterior amortizer without a summary network which will estimate the joint
posterior of hyperparameters and optional shared parameters given a representation
of an entire hierarchical data set. If both hyper- and shared parameters are present,
the first dimensions correspond to the hyperparameters and the remaining ones correspond
to the shared parameters.
summary_net : tf.keras.Model or None, optional, default: None
An optional summary network to compress non-vector data structures.
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the ``__init__`` method of a ``tf.keras.Model`` instance.
"""

super().__init__(**kwargs)

self.local_amortizer = local_amortizer
self.global_amortizer = global_amortizer
self.summary_net = summary_net

def call(self, input_dict, **kwargs):
"""Forward pass through the hierarchical amortized posterior."""

local_summaries, global_summaries = self._compute_condition(input_dict, **kwargs)
local_inputs, global_inputs = self._prepare_inputs(input_dict, local_summaries, global_summaries)
local_out = self.local_amortizer(local_inputs, **kwargs)
global_out = self.global_amortizer(global_inputs, **kwargs)
return local_out, global_out

def compute_loss(self, input_dict, **kwargs):
"""Compute loss of all amortizers."""

local_summaries, global_summaries = self._compute_condition(input_dict, **kwargs)
local_inputs, global_inputs = self._prepare_inputs(input_dict, local_summaries, global_summaries)
local_loss = self.local_amortizer.compute_loss(local_inputs, **kwargs)
global_loss = self.global_amortizer.compute_loss(global_inputs, **kwargs)
return {"Local.Loss": local_loss, "Global.Loss": global_loss}

def sample(self, input_dict, n_samples, to_numpy=True, **kwargs):
"""Obtains samples from the joint hierarchical posterior given observations.
Important: Currently works only for single hierarchical data sets!
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if DEFAULT_KEYS unchanged:
`summary_conditions` - the hierarchical data set (to be embedded by the summary net)
As well as optional keys:
`direct_local_conditions` - (Context) variables used to condition the local posterior
`direct_global_conditions` - (Context) variables used to condition the global posterior
n_samples : int
The number of posterior draws (samples) to obtain from the approximate posterior
to_numpy : bool, optional, default: True
Flag indicating whether to return the samples as a `np.array` or a `tf.Tensor`
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the summary network as the amortizers
Returns:
--------
samples_dict : dict
A dictionary with keys `global_samples` and `local_samples`
Global samples will hold an array-like of shape (num_samples, num_replicas, num_local)
and local samples will hold an array-like of shape (num_samples, num_hyper + num_shared),
if optional shared patameters are present, otherwise (num_samples, num_hyper),
"""

# Returned shapes will be
# local_summaries.shape = (1, num_groups, summary_dim_local)
# global_summaries.shape = (1, summary_dim_global)
local_summaries, global_summaries = self._get_local_global(input_dict, **kwargs)
num_groups = local_summaries.shape[1]

if local_summaries.shape[0] != 1 or global_summaries.shape[0] != 1:
raise NotImplementedError("Method currently supports only single hierarchical data sets!")

# Obtain samples from p(global | all_data)
inp_global = {DEFAULT_KEYS["direct_conditions"]: global_summaries}

# New, shape will be (n_samples, num_globals)
global_samples = self.global_amortizer.sample(inp_global, n_samples, **kwargs, to_numpy=False)

# Repeat local conditions for n_samples
# New shape -> (num_groups, n_samples, summary_dim_local)
local_summaries = tf.stack([tf.squeeze(local_summaries, axis=0)] * n_samples, axis=1)

# Repeat global samples for num_groups
# New shape -> (num_groups, n_samples, num_globals)
global_samples_rep = tf.stack([global_samples] * num_groups, axis=0)

# Concatenate local summaries with global samples
# New shape -> (num_groups, num_samples, summary_dim_local + num_globals)
local_summaries = tf.concat([local_summaries, global_samples_rep], axis=-1)

# Obtain samples from p(local_i | data_i, global_i)
inp_local = {DEFAULT_KEYS["direct_conditions"]: local_summaries}
local_samples = self.local_amortizer.sample(inp_local, n_samples, to_numpy=False, **kwargs)

if to_numpy:
global_samples = global_samples.numpy()
local_samples = local_samples.numpy()

return {"global_samples": global_samples, "local_samples": local_samples}

def log_prob(self, input_dict):
"""Compute normalized log density."""

raise NotImplementedError

def _prepare_inputs(self, input_dict, local_summaries, global_summaries):
"""Prepare input dictionaries for both amortizers."""

# Prepare inputs for local amortizer
local_inputs = {"direct_conditions": local_summaries, "parameters": input_dict["local_parameters"]}

# Prepare inputs for global amortizer
_parameters = input_dict["hyper_parameters"]
if input_dict.get("shared_parameters") is not None:
_parameters = tf.concat([_parameters, input_dict.get("shared_parameters")], axis=-1)
global_inputs = {"direct_conditions": global_summaries, "parameters": _parameters}
return local_inputs, global_inputs

def _compute_condition(self, input_dict, **kwargs):
"""Determines conditionining variables for both amortizers."""

# Obtain needed summaries
local_summaries, global_summaries = self._get_local_global(input_dict, **kwargs)

# At this point, add globals as conditions
num_locals = local_summaries.shape[1]

# Add hyper parameters as conditions:
# p(local_n | data_n, hyper)
if input_dict.get("hyper_parameters") is not None:
_params = input_dict.get("hyper_parameters")
_conds = tf.stack([_params] * num_locals, axis=1)
local_summaries = tf.concat([local_summaries, _conds], axis=-1)
# Add shared parameters as conditions:
# p(local_n | data_n, hyper, shared)
if input_dict.get("shared_parameters") is not None:
_params = input_dict.get("shared_parameters")
_conds = tf.stack([_params] * num_locals, axis=1)
local_summaries = tf.concat([local_summaries, _conds], axis=-1)
return local_summaries, global_summaries

def _get_local_global(self, input_dict, **kwargs):
"""Helper function to obtain local and global condition tensors."""

# Obtain summary conditions
if self.summary_net is not None:
local_summaries, global_summaries = self.summary_net(
input_dict["summary_conditions"], return_all=True, **kwargs
)
if input_dict.get("direct_local_conditions") is not None:
local_summaries = tf.concat([local_summaries, input_dict.get("direct_local_conditions")], axis=-1)
if input_dict.get("direct_global_conditions") is not None:
global_summaries = tf.concat([global_summaries, input_dict.get("direct_global_conditions")], axis=-1)
# If no summary net provided, assume direct conditions exist or fail
else:
local_summaries = input_dict.get("direct_local_conditions")
global_summaries = input_dict.get("direct_global_conditions")
return local_summaries, global_summaries


class SingleModelAmortizer(AmortizedPosterior):
"""Deprecated class for amortizer posterior estimation."""

Expand Down
8 changes: 7 additions & 1 deletion bayesflow/default_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,18 @@ def __init__(self, meta_dict: dict, mandatory_fields: list = []):
"non_batchable_context": "non_batchable_context",
"prior_batchable_context": "prior_batchable_context",
"prior_non_batchable_context": "prior_non_batchable_context",
"prior_context": "prior_context",
"hyper_prior_draws": "hyper_prior_draws",
"shared_prior_draws": "shared_prior_draws",
"local_prior_draws": "local_prior_draws",
"sim_batchable_context": "sim_batchable_context",
"sim_non_batchable_context": "sim_non_batchable_context",
"summary_conditions": "summary_conditions",
"direct_conditions": "direct_conditions",
"parameters": "parameters",
"hyperparameters": "hyperparameters",
"hyper_parameters": "hyper_parameters",
"shared_parameters": "shared_parameters",
"local_parameters": "local_parameters",
"observables": "observables",
"targets": "targets",
"conditions": "conditions",
Expand Down
Loading

0 comments on commit 9768a4d

Please sign in to comment.