Skip to content

Commit

Permalink
Format nicely
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanradev93 committed Jul 1, 2024
1 parent 7f7db44 commit 8313c7d
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 15 deletions.
2 changes: 1 addition & 1 deletion bayesflow/benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(self, name: str, **kwargs):

self.name = name
self.module = self.get_module(name)
self.simulator = partial(getattr(self.module, "simulator"), **kwargs)
self.simulator = partial(getattr(self.module, "simulator"), **kwargs.pop("prior_kwargs", {}))

def sample(self, batch_size: int):
    """Draw ``batch_size`` joint samples by batch-calling the benchmark's simulator."""
    return batched_call(self.simulator, (batch_size,))
Expand Down
7 changes: 4 additions & 3 deletions bayesflow/benchmarks/two_moons.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np


def simulator(lower_bound: float = -1.0, upper_bound: float = 1.0, rng: np.random.Generator = None):
prior_draws = prior(lower_bound, upper_bound, rng)
observables = observation_model(prior_draws, rng)
def simulator():
    """Non-configurable two-moons simulator running with default settings.

    Draws parameters from the prior, pushes them through the observation
    model, and returns both in a dict keyed by ``parameters`` and
    ``observables``.
    """
    theta = prior()
    x = observation_model(theta)
    return {"parameters": theta, "observables": x}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@register_keras_serializable(package="bayesflow.networks.coupling_flow")
class AffineTransform(Transform):
def __init__(self, clamp_factor: int | float = 5.0, **kwargs):
def __init__(self, clamp_factor: float | None = 5.0, **kwargs):
    """Affine coupling transform.

    Parameters
    ----------
    clamp_factor : float | None, optional
        Controls soft-clamping of the scale parameter; ``None`` disables
        clamping entirely. Default is 5.0.
    **kwargs
        Forwarded to the base ``Transform``.
    """
    super().__init__(**kwargs)
    # Stored for use when constraining the raw scale parameters.
    self.clamp_factor = clamp_factor

Expand Down Expand Up @@ -36,8 +36,9 @@ def split_parameters(self, parameters: Tensor) -> dict[str, Tensor]:
return {"scale": scale, "shift": shift}

def constrain_parameters(self, parameters: dict[str, Tensor]) -> dict[str, Tensor]:
    """Soft-clamp the raw scale parameter in place when clamping is enabled.

    The shift parameter passes through unchanged; when ``self.clamp_factor``
    is ``None`` the dict is returned untouched.
    """
    if self.clamp_factor is not None:
        raw_scale = parameters["scale"]
        sigmoid = 1 / (1 + ops.exp(-raw_scale))
        parameters["scale"] = sigmoid * ops.sqrt(1 + ops.abs(raw_scale + self.clamp_factor))

    return parameters

Expand Down
4 changes: 2 additions & 2 deletions bayesflow/networks/deep_set/deep_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
super().__init__(**keras_kwargs(kwargs))

# Stack of equivariant modules for a many-to-many learnable transformation
self.equivariant_modules = keras.Sequential(name="EquivariantStack")
self.equivariant_modules = keras.Sequential()
for i in range(depth):
equivariant_module = EquivariantModule(
num_dense_equivariant=num_dense_equivariant,
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
)

# Output linear layer to project set representation down to "summary_dim" learned summary statistics
self.output_projector = layers.Dense(summary_dim, activation="linear", name="OutputLayer")
self.output_projector = layers.Dense(summary_dim, activation="linear")
self.summary_dim = summary_dim

def call(self, input_set: Tensor, **kwargs) -> Tensor:
Expand Down
4 changes: 2 additions & 2 deletions bayesflow/networks/inference_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def sample(self, num_samples: int, conditions: Tensor = None, **kwargs) -> Tenso
samples = self(samples, conditions=conditions, inverse=True, density=False, **kwargs)
return samples

def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
_, log_density = self(samples, conditions=conditions, inverse=False, density=True, **kwargs)
def log_prob(self, targets: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
    """Evaluate the log density of ``targets``, optionally conditioned on ``conditions``."""
    # Forward pass with density=True yields (transformed_targets, log_density);
    # only the density component is of interest here.
    outputs = self(targets, conditions=conditions, inverse=False, density=True, **kwargs)
    _, log_density = outputs
    return log_density

def compute_metrics(self, data: dict[str, Tensor], stage: str = "training") -> dict[str, Tensor]:
Expand Down
8 changes: 6 additions & 2 deletions bayesflow/networks/summary_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@


class SummaryNetwork(keras.Layer):
def call(self, data: dict[str, Tensor], stage: str = "training") -> Tensor:
def call(self, data: Tensor, **kwargs) -> Tensor:
    """Abstract hook: subclasses must map ``data`` to learned summary statistics."""
    raise NotImplementedError

def compute_metrics(self, data: dict[str, Tensor], stage: str = "training") -> dict[str, Tensor]:
outputs = self(data, stage=stage)
summary_inputs = data["summary_variables"]
if data.get("summary_conditions") is not None:
summary_inputs = keras.ops.concatenate([summary_inputs, data["summary_conditions"]], axis=-1)

outputs = self(summary_inputs, training=stage == "training")

if any(self.metrics):
# TODO: what should we do here?
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
stack_dicts,
)

from .tensor_utils import repeat_tensor
from .tensor_utils import repeat_tensor, process_output

from .git import (
issue_url,
Expand Down
26 changes: 25 additions & 1 deletion bayesflow/utils/tensor_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,37 @@
import logging

import keras

from bayesflow.types import Tensor


def repeat_tensor(tensor: Tensor, num_repeats: int, axis=1):
def repeat_tensor(tensor: Tensor, num_repeats: int, axis=1) -> Tensor:
    """Insert a new axis at ``axis`` and tile the tensor ``num_repeats`` times along it."""
    expanded = keras.ops.expand_dims(tensor, axis=axis)
    # Tile factor of 1 everywhere except the newly inserted axis.
    tiling = [1 for _ in range(expanded.ndim)]
    tiling[axis] = num_repeats
    return keras.ops.tile(expanded, repeats=tiling)


def process_output(outputs: Tensor, convert_to_numpy: bool = True) -> Tensor:
    """Apply common post-processing steps to the outputs of an approximator.

    Parameters
    ----------
    outputs : Tensor
        Raw approximator output, with data sets stacked along the first axis.
    convert_to_numpy : bool, optional
        If True (default), convert the result to a numpy array.

    Returns
    -------
    Tensor or np.ndarray
        The post-processed outputs, squeezed along the leading axis when it
        holds a single data set.
    """

    # Remove leading axis for single data sets
    if keras.ops.shape(outputs)[0] == 1:
        outputs = keras.ops.squeeze(outputs, axis=0)

    # Warn if any NaNs present in output; cast the boolean mask before
    # summing so the count works across all Keras backends.
    nan_mask = keras.ops.isnan(outputs)
    if keras.ops.any(nan_mask):
        num_nans = keras.ops.sum(keras.ops.cast(nan_mask, "int32"))
        logging.warning("A total of %s NaNs found in output.", num_nans)

    # Warn if any inf present in output
    inf_mask = keras.ops.isinf(outputs)
    if keras.ops.any(inf_mask):
        num_infs = keras.ops.sum(keras.ops.cast(inf_mask, "int32"))
        logging.warning("A total of %s inf values found in output.", num_infs)

    if convert_to_numpy:
        # Backend-agnostic conversion: `outputs.numpy()` exists only on some
        # backends (e.g. tensorflow/torch), not on jax arrays.
        return keras.ops.convert_to_numpy(outputs)
    return outputs

0 comments on commit 8313c7d

Please sign in to comment.