Source code for bayesflow.coupling_networks

# Copyright (c) 2022 The BayesFlow Developers

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import tensorflow as tf

from numpy import pi as PI_CONST

from bayesflow import default_settings
from bayesflow.helper_networks import Permutation, ActNorm, DenseCouplingNet
from bayesflow.helper_functions import build_meta_dict
from bayesflow.exceptions import ConfigurationError


[docs]class AffineCouplingLayer(tf.keras.Model): """Implements a conditional version of the INN coupling layer.""" def __init__(self, meta): """Creates an affine coupling layer, optionally conditional. Parameters ---------- meta : list(dict) A list of dictionaries, wherein each dictionary holds parameter-value pairs for a single :class:`tf.keras.Dense` layer. All coupling nets are assumed to be equal. """ super().__init__() # Coupling net hyperparams self.alpha = meta['alpha'] self.latent_dim = meta['latent_dim'] self.n_out1 = self.latent_dim // 2 self.n_out2 = self.latent_dim // 2 if self.latent_dim % 2 == 0 else self.latent_dim // 2 + 1 # Custom coupling net and settings if callable(meta['coupling_design']): coupling_type = meta['coupling_design'] if meta.get('coupling_net_settings') is None: raise ConfigurationError("You need to provide 'coupling_net_settings' for a custom coupling type!") coupling_net_settings = meta['coupling_net_settings'] # String type of dense or attention elif type(meta['coupling_design']) is str: # Settings type if meta.get('coupling_net_settings') is None: user_dict = {} elif type(meta.get('coupling_net_settings')) is dict: user_dict = meta.get('coupling_net_settings') else: raise ConfigurationError("coupling_net_settings not understood") # Dense if meta['coupling_design'] == 'dense': coupling_type = DenseCouplingNet coupling_net_settings = build_meta_dict( user_dict=user_dict, default_setting=default_settings.DEFAULT_SETTING_DENSE_COUPLING) else: raise NotImplementedError('String coupling_design should be one of ["dense"].') else: raise NotImplementedError('coupling_design argument not understood. Should either be a callable generator or ' + 'a string in ["dense"].') self.s1 = coupling_type(coupling_net_settings['s_args'], self.n_out1) self.t1 = coupling_type(coupling_net_settings['t_args'], self.n_out1) self.s2 = coupling_type(coupling_net_settings['s_args'], self.n_out2) self.t2 = coupling_type(coupling_net_settings['t_args'], self.n_out2) # Optional permutation if meta['use_permutation']: self.permutation = Permutation(self.latent_dim) self.permutation.trainable = False else: self.permutation = None # Optional activation normalization if meta['use_act_norm']: self.act_norm = ActNorm(meta) else: self.act_norm = None
[docs] def call(self, target_or_z, condition, inverse=False, **kwargs): """Performs one pass through a the affine coupling layer (either inverse or forward). Parameters ---------- target_or_z : tf.Tensor The estimation quantites of interest or latent representations z ~ p(z), shape (batch_size, ...) condition : tf.Tensor or None The conditioning data of interest, for instance, x = summary_fun(x), shape (batch_size, ...). If `condition is None`, then the layer recuces to an unconditional ACL. inverse : bool, optional, default: False Flag indicating whether to run the block forward or backward. Returns ------- (z, log_det_J) : tuple(tf.Tensor, tf.Tensor) If inverse=False: The transformed input and the corresponding Jacobian of the transformation, z shape: (batch_size, inp_dim), log_det_J shape: (batch_size, ) target : tf.Tensor If inverse=True: The back-transformed z, shape (batch_size, inp_dim) Important --------- If ``inverse=False``, the return is ``(z, log_det_J)``.\n If ``inverse=True``, the return is ``target`` """ if not inverse: return self.forward(target_or_z, condition, **kwargs) return self.inverse(target_or_z, condition, **kwargs)
[docs] @tf.function def forward(self, target, condition, **kwargs): """Performs a forward pass through a coupling layer with an optinal `Permutation` and `ActNorm` layers. Parameters ---------- target : tf.Tensor The estimation quantities of interest, for instance, parameter vector of shape (batch_size, theta_dim) condition : tf.Tensor or None The conditioning vector of interest, for instance, x = summary(x), shape (batch_size, summary_dim) If `None`, transformation amounts to unconditional estimation. Returns ------- (z, log_det_J) : tuple(tf.Tensor, tf.Tensor) The transformed input and the corresponding Jacobian of the transformation. """ # Initialize log_det_Js accumulator log_det_Js = tf.zeros(1) # Normalize activation, if specified if self.act_norm is not None: target, log_det_J_act = self.act_norm(target) log_det_Js += log_det_J_act # Permute, if indicated if self.permutation is not None: target = self.permutation(target) # Pass through coupling layer z, log_det_J_c = self._forward(target, condition, **kwargs) log_det_Js += log_det_J_c return z, log_det_Js
[docs] @tf.function def inverse(self, z, condition, **kwargs): """Performs an inverse pass through a coupling layer with an optinal `Permutation` and `ActNorm` layers. Parameters ---------- z : tf.Tensor latent variables z ~ p(z), shape (batch_size, theta_dim) condition : tf.Tensor or None The conditioning vector of interest, for instance, x = summary(x), shape (batch_size, summary_dim). If `None`, transformation amounts to unconditional estimation. Returns ------- target : tf.Tensor The back-transformed latent variable z. """ # Pass through coupling layer target = self._inverse(z, condition, **kwargs) # Pass through optional permutation if self.permutation is not None: target = self.permutation(target, inverse=True) # Pass through activation normalization if self.act_norm is not None: target = self.act_norm(target, inverse=True) return target
@tf.function def _forward(self, target, condition, **kwargs): """Performs a forward pass through the coupling layer. Used internally by the instance. Parameters ---------- target : tf.Tensor The estimation quantities of interest, for instance, parameter vector of shape (batch_size, theta_dim) condition : tf.Tensor or None The conditioning vector of interest, for instance, x = summary(x), shape (batch_size, summary_dim) If `None`, transformation amounts to unconditional estimation. Returns ------- (v, log_det_J) : tuple(tf.Tensor, tf.Tensor) The transformed input and the corresponding Jacobian of the transformation. """ # Split parameter vector u1, u2 = tf.split(target, [self.n_out1, self.n_out2], axis=-1) # Pre-compute network outputs for v1 s1 = self.s1(u2, condition, **kwargs) # Clamp s1 if specified if self.alpha is not None: s1 = (2. * self.alpha / PI_CONST) * tf.math.atan(s1 / self.alpha) t1 = self.t1(u2, condition, **kwargs) v1 = u1 * tf.exp(s1) + t1 # Pre-compute network outputs for v2 s2 = self.s2(v1, condition, **kwargs) # Clamp s2 if specified if self.alpha is not None: s2 = (2. * self.alpha / PI_CONST) * tf.math.atan(s2 / self.alpha) t2 = self.t2(v1, condition, **kwargs) v2 = u2 * tf.exp(s2) + t2 v = tf.concat((v1, v2), axis=-1) # Compute ldj, # log|J| = log(prod(diag(J))) -> according to inv architecture log_det_J = tf.reduce_sum(s1, axis=-1) + tf.reduce_sum(s2, axis=-1) return v, log_det_J @tf.function def _inverse(self, z, condition, **kwargs): """Performs an inverse pass through the coupling block. Used internally by the instance. Parameters ---------- z : tf.Tensor latent variables z ~ p(z), shape (batch_size, theta_dim) condition : tf.Tensor or None The conditioning vector of interest, for instance, x = summary(x), shape (batch_size, summary_dim). If `None`, transformation amounts to unconditional estimation. Returns ------- u : tf.Tensor The back-transformed input. """ v1, v2 = tf.split(z, [self.n_out1, self.n_out2], axis=-1) # Pre-Compute s2 s2 = self.s2(v1, condition, **kwargs) # Clamp s2 if specified if self.alpha is not None: s2 = (2. * self.alpha / PI_CONST) * tf.math.atan(s2 / self.alpha) u2 = (v2 - self.t2(v1, condition, **kwargs)) * tf.exp(-s2) # Pre-Compute s1 s1 = self.s1(u2, condition, **kwargs) # Clamp s1 if specified if self.alpha is not None: s1 = (2. * self.alpha / PI_CONST) * tf.math.atan(s1 / self.alpha) u1 = (v1 - self.t1(u2, condition, **kwargs)) * tf.exp(-s1) u = tf.concat((u1, u2), axis=-1) return u