-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'streamlined-backend' of https://github.com/stefanradev9…
…3/BayesFlow into streamlined-backend
- Loading branch information
Showing
18 changed files
with
231 additions
and
156 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
77 changes: 0 additions & 77 deletions
77
bayesflow/experimental/simulation/decorators/distribution_decorator.py
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
14 changes: 0 additions & 14 deletions
14
bayesflow/experimental/simulation/distributions/joint_distribution.py
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
|
||
from .sequential_simulator import SequentialSimulator | ||
from .simulator import Simulator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
|
||
import keras | ||
|
||
from typing import Sequence | ||
|
||
from bayesflow.experimental.types import Sampler, Shape, Tensor | ||
from bayesflow.experimental.utils import batched_call | ||
|
||
from .simulator import Simulator | ||
|
||
|
||
class SequentialSimulator(Simulator): | ||
r""" | ||
Implements a sequentially factorized simulator: | ||
.. math:: | ||
p(x) = \prod_{i = 1}^{n - 1} p(x_{i} | x_{i + 1}, ..., x_{n}) p(x_{n} | ||
Examples: | ||
>>> import numpy as np | ||
>>> def sample_contexts(): | ||
>>> return dict(contexts=np.random.normal()) | ||
>>> def sample_parameters(shape: Shape, **kwargs): | ||
>>> return dict(parameters=np.random.normal()) | ||
>>> def sample_observables(contexts: Tensor, parameters: Tensor, **kwargs): | ||
>>> observables = contexts + parameters + np.random.normal() | ||
>>> return dict(observables=observables) | ||
>>> simulator = SequentialSimulator([sample_contexts, sample_parameters, sample_observables]) | ||
>>> simulator.sample((2,)) | ||
{'contexts': tensor(..., shape=(2, 1)), 'parameters': tensor(..., shape=(2, 1)), 'observables': tensor(..., shape=(2, 1))} | ||
""" | ||
def __init__(self, samplers: Sequence[Sampler]): | ||
super().__init__() | ||
self.samplers = list(samplers) | ||
|
||
def sample(self, shape: Shape) -> dict[str, Tensor]: | ||
data = {} | ||
|
||
for sampler in self.samplers: | ||
try: | ||
data |= batched_call(sampler, shape[0], **data) | ||
except TypeError as e: | ||
if keras.backend.backend() == "torch" and "device" in str(e): | ||
raise RuntimeError(f"Encountered an unexpected device error when sampling. " | ||
f"This can happen when you use numpy in conjunction with automatic " | ||
f"vectorization for samplers with arguments. Note that the arguments passed " | ||
f"to the samplers are always tensors, which may live on the GPU. " | ||
f"Performing numpy operations on these is prohibited.") from e | ||
else: | ||
raise e | ||
|
||
for key, value in data.items(): | ||
if keras.ops.ndim(value) == 1: | ||
data[key] = keras.ops.expand_dims(value, -1) | ||
|
||
return data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
|
||
from bayesflow.experimental.types import Shape, Tensor | ||
|
||
|
||
class Simulator: | ||
def sample(self, batch_shape: Shape) -> dict[str, Tensor]: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.