Skip to content

Commit

Permalink
Merge branch 'streamlined-backend' of https://github.com/stefanradev9…
Browse files Browse the repository at this point in the history
…3/BayesFlow into streamlined-backend
  • Loading branch information
stefanradev93 committed Jun 14, 2024
2 parents eab5427 + 68fe48c commit c49760d
Show file tree
Hide file tree
Showing 18 changed files with 231 additions and 156 deletions.
2 changes: 1 addition & 1 deletion bayesflow/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
diagnostics,
distributions,
networks,
simulation,
simulators,
)

from .approximators import Approximator
Expand Down
8 changes: 3 additions & 5 deletions bayesflow/experimental/datasets/offline_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,26 @@
import keras
import math

from bayesflow.experimental.utils import nested_getitem


class OfflineDataset(keras.utils.PyDataset):
"""
A dataset that is pre-simulated and stored in memory.
"""
# TODO: fix
def __init__(self, data: dict, batch_size: int, **kwargs):
super().__init__(**kwargs)
self.batch_size = batch_size

self.data = data
self.indices = keras.ops.arange(len(data[next(iter(data.keys()))]))
self.indices = keras.ops.arange(len(data[next(iter(data.keys()))]), dtype="int64")

self.shuffle()

def __getitem__(self, item: int) -> (dict, dict):
""" Get a batch of pre-simulated data """
item = slice(item * self.batch_size, (item + 1) * self.batch_size)
item = self.indices[item]
return nested_getitem(self.data, item)

return {key: keras.ops.take(value, item, axis=0) for key, value in self.data.items()}

def __len__(self) -> int:
return math.ceil(len(self.indices) / self.batch_size)
Expand Down
15 changes: 10 additions & 5 deletions bayesflow/experimental/datasets/online_dataset.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@

import keras

from bayesflow.experimental.simulation import JointDistribution
from bayesflow.experimental.simulators import Simulator


class OnlineDataset(keras.utils.PyDataset):
"""
A dataset that is generated on-the-fly.
"""
def __init__(self, distribution, batch_size: int, **kwargs):
def __init__(self, simulator: Simulator, batch_size: int, **kwargs):
super().__init__(**kwargs)
self.distribution = distribution

if kwargs.get("use_multiprocessing"):
# keras workaround: https://github.com/keras-team/keras/issues/19346
import multiprocessing as mp
mp.set_start_method("spawn", force=True)

self.simulator = simulator
self.batch_size = batch_size

def __getitem__(self, item: int) -> (dict, dict):
""" Sample a batch of data from the joint distribution unconditionally """
return self.distribution.sample((self.batch_size,))
return self.simulator.sample((self.batch_size,))

@property
def num_batches(self):
Expand Down
14 changes: 10 additions & 4 deletions bayesflow/experimental/datasets/rounds_dataset.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@

import keras

from bayesflow.experimental.simulation import JointDistribution
from bayesflow.experimental.simulators import Simulator


class RoundsDataset(keras.utils.PyDataset):
"""
A dataset that is generated on-the-fly at the beginning of every n-th epoch.
"""
def __init__(self, joint_distribution: JointDistribution, batch_size: int, batches_per_epoch: int, epochs_per_round: int, **kwargs):
def __init__(self, simulator: Simulator, batch_size: int, batches_per_epoch: int, epochs_per_round: int, **kwargs):
super().__init__(**kwargs)
self.joint_distribution = joint_distribution

if kwargs.get("use_multiprocessing"):
# keras workaround: https://github.com/keras-team/keras/issues/19346
import multiprocessing as mp
mp.set_start_method("spawn", force=True)

self.simulator = simulator
self.batch_size = batch_size
self.batches_per_epoch = batches_per_epoch
self.epochs_per_round = epochs_per_round
Expand All @@ -36,4 +42,4 @@ def on_epoch_end(self) -> None:

def regenerate(self) -> None:
""" Sample new batches of data from the joint distribution unconditionally """
self.data = [self.joint_distribution.sample((self.batch_size,)) for _ in range(self.batches_per_epoch)]
self.data = [self.simulator.sample((self.batch_size,)) for _ in range(self.batches_per_epoch)]
3 changes: 0 additions & 3 deletions bayesflow/experimental/simulation/__init__.py

This file was deleted.

2 changes: 0 additions & 2 deletions bayesflow/experimental/simulation/decorators/__init__.py

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

3 changes: 3 additions & 0 deletions bayesflow/experimental/simulators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

from .sequential_simulator import SequentialSimulator
from .simulator import Simulator
58 changes: 58 additions & 0 deletions bayesflow/experimental/simulators/sequential_simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@

import keras

from typing import Sequence

from bayesflow.experimental.types import Sampler, Shape, Tensor
from bayesflow.experimental.utils import batched_call

from .simulator import Simulator


class SequentialSimulator(Simulator):
    r"""
    Implements a sequentially factorized simulator:

    .. math::
        p(x) = \prod_{i = 1}^{n - 1} p(x_{i} | x_{i + 1}, ..., x_{n}) p(x_{n})

    The samplers are called in the order in which they are given. Each sampler receives
    the outputs of all previous samplers as keyword arguments and returns a dictionary
    of new variables, which is merged into the running result.

    Examples:
        >>> import numpy as np
        >>> def sample_contexts():
        ...     return dict(contexts=np.random.normal())
        >>> def sample_parameters(shape: Shape, **kwargs):
        ...     return dict(parameters=np.random.normal(size=shape))
        >>> def sample_observables(contexts: Tensor, parameters: Tensor, **kwargs):
        ...     observables = contexts + parameters + np.random.normal()
        ...     return dict(observables=observables)
        >>> simulator = SequentialSimulator([sample_contexts, sample_parameters, sample_observables])
        >>> simulator.sample((2,))
        {'contexts': tensor(..., shape=(2, 1)), 'parameters': tensor(..., shape=(2, 1)), 'observables': tensor(..., shape=(2, 1))}
    """
    def __init__(self, samplers: Sequence[Sampler]):
        """ Store the samplers (copied into a list) in the order they will be called. """
        super().__init__()
        self.samplers = list(samplers)

    def sample(self, shape: Shape) -> dict[str, Tensor]:
        """
        Run all samplers in sequence and collect their outputs in one dictionary.

        Only ``shape[0]`` is used; it is passed to ``batched_call`` as the batch size.
        Outputs with exactly one dimension get a trailing axis of size 1 appended.

        Raises a RuntimeError (chained to the original TypeError) if a sampler fails with
        a device-related TypeError on the torch backend; any other TypeError propagates.
        """
        data = {}

        for sampler in self.samplers:
            try:
                # previous outputs are forwarded to the next sampler as keyword arguments
                data |= batched_call(sampler, shape[0], **data)
            except TypeError as e:
                if keras.backend.backend() == "torch" and "device" in str(e):
                    raise RuntimeError("Encountered an unexpected device error when sampling. "
                                       "This can happen when you use numpy in conjunction with automatic "
                                       "vectorization for samplers with arguments. Note that the arguments passed "
                                       "to the samplers are always tensors, which may live on the GPU. "
                                       "Performing numpy operations on these is prohibited.") from e
                # not a device problem: re-raise unchanged
                raise

        # promote (batch_size,) outputs to (batch_size, 1); other ranks are left untouched
        return {
            key: keras.ops.expand_dims(value, -1) if keras.ops.ndim(value) == 1 else value
            for key, value in data.items()
        }
7 changes: 7 additions & 0 deletions bayesflow/experimental/simulators/simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

from bayesflow.experimental.types import Shape, Tensor


class Simulator:
    """ Minimal simulator interface: subclasses are expected to override sample(). """

    def sample(self, batch_shape: Shape) -> dict[str, Tensor]:
        """ Return a dictionary of sampled tensors; not implemented on the base class. """
        raise NotImplementedError()
12 changes: 12 additions & 0 deletions bayesflow/experimental/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@

from typing import Callable


Shape = tuple[int, ...]

# this is ugly, but:
Expand All @@ -13,3 +17,11 @@
except ModuleNotFoundError:
import torch
Tensor: type(torch.Tensor) = torch.Tensor


# Sampler call signatures.
# "Batched" samplers take the batch shape as their first argument; "unbatched" samplers do not.
# "Conditional" samplers additionally take tensor arguments.
# NOTE(review): the trailing `...` inside the argument lists is shorthand for "further tensor
# arguments"; static type checkers may not accept this form -- confirm before relying on it.
BatchedConditionalSampler = Callable[[Shape, Tensor, ...], dict[str, Tensor]]
BatchedUnconditionalSampler = Callable[[Shape], dict[str, Tensor]]
UnbatchedConditionalSampler = Callable[[Tensor, ...], dict[str, Tensor]]
UnbatchedUnconditionalSampler = Callable[[], dict[str, Tensor]]

# Any of the four sampler kinds above.
Sampler = (
    BatchedConditionalSampler
    | BatchedUnconditionalSampler
    | UnbatchedConditionalSampler
    | UnbatchedUnconditionalSampler
)
2 changes: 1 addition & 1 deletion bayesflow/experimental/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@

from .dictutils import (
batched_call,
filter_concatenate,
keras_kwargs,
nested_getitem,
)

from .dispatch import (
Expand Down
69 changes: 58 additions & 11 deletions bayesflow/experimental/utils/dictutils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,67 @@

from keras import ops
import inspect
import keras

from typing import Sequence

from bayesflow.experimental.types import Tensor


def nested_getitem(data: dict, item: int) -> dict:
    """ Index every leaf of a (possibly nested) dictionary with item, keeping the key structure. """
    return {
        # sub-dictionaries are descended into; everything else is indexed directly
        key: nested_getitem(value, item) if isinstance(value, dict) else value[item]
        for key, value in data.items()
    }
def convert_kwargs(f, *args, **kwargs) -> dict[str, object]:
    """
    Convert positional and keyword arguments to just keyword arguments for f.

    Positional arguments are paired with the parameter names of f in declaration order.
    Without positional arguments, kwargs is returned as-is and f is not inspected.

    Raises a TypeError if a keyword argument names a parameter that was already
    filled by a positional argument.
    """
    if not args:
        return kwargs

    signature = inspect.signature(f)

    # zip stops at the shorter sequence, so surplus positional arguments are not included
    parameters = dict(zip(signature.parameters, args))

    # not every callable has a __name__ (e.g. functools.partial objects)
    f_name = getattr(f, "__name__", repr(f))

    for name, value in kwargs.items():
        if name in parameters:
            raise TypeError(f"{f_name}() got multiple arguments for argument '{name}'")

        parameters[name] = value

    return parameters


def convert_args(f, *args, **kwargs) -> tuple[object, ...]:
    """
    Convert positional and keyword arguments to just positional arguments for f.

    The result follows the parameter order of f. Parameters that were not supplied
    fall back to their declared default. Keyword arguments that do not match a named
    parameter of f are not included, and ``*args`` / ``**kwargs`` parameters of f are skipped.
    Without keyword arguments, args is returned as-is and f is not inspected.

    Raises a TypeError if a parameter without default was not supplied.
    """
    if not kwargs:
        return args

    signature = inspect.signature(f)

    # convert to just kwargs first
    named = convert_kwargs(f, *args, **kwargs)

    parameters = []
    for name, param in signature.parameters.items():
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue

        if name in named:
            parameters.append(named[name])
        elif param.default is not param.empty:
            parameters.append(param.default)
        else:
            # do not leak the inspect.Parameter.empty sentinel into the call
            f_name = getattr(f, "__name__", repr(f))
            raise TypeError(f"{f_name}() missing required argument '{name}'")

    return tuple(parameters)


def batched_call(f, batch_size, *args, **kwargs):
    """
    Call f, automatically vectorizing to batch_size if required.

    f is first called as ``f((batch_size,), *args, **kwargs)``. If that attempt raises a
    TypeError, f is instead mapped over the leading axis of the arguments with
    ``keras.ops.vectorized_map``. All values of the returned dictionary are converted to tensors.
    """
    def as_tensors(outputs):
        return {name: keras.ops.convert_to_tensor(output) for name, output in outputs.items()}

    try:
        return as_tensors(f((batch_size,), *args, **kwargs))
    except TypeError:
        # NOTE(review): this also catches TypeErrors raised inside f or during tensor conversion,
        # not only signature mismatches -- the unbatched path below is attempted in all of these cases
        pass

    def vectorized(elements):
        # the first element is the dummy below and is not passed on to f
        return as_tensors(f(*elements[1:]))

    positional = convert_args(f, *args, **kwargs)
    # zero-width tensor with batch_size rows; presumably there so that the map has
    # batch_size entries even when f takes no arguments -- confirm
    dummy = keras.ops.zeros((batch_size, 0))
    return keras.ops.vectorized_map(vectorized, (dummy, *positional))


def filter_concatenate(data: dict[str, Tensor], keys: Sequence[str], axis: int = -1) -> Tensor:
Expand All @@ -28,7 +75,7 @@ def filter_concatenate(data: dict[str, Tensor], keys: Sequence[str], axis: int =
tensors = [data[key] for key in keys]

try:
return ops.concatenate(tensors, axis=axis)
return keras.ops.concatenate(tensors, axis=axis)
except ValueError as e:
shapes = [t.shape for t in tensors]
raise ValueError(f"Cannot trivially concatenate tensors {keys} with shapes {shapes}") from e
Expand Down
Loading

0 comments on commit c49760d

Please sign in to comment.