Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/streamlined-backend' into stream…
Browse files Browse the repository at this point in the history
…lined-backend
  • Loading branch information
LarsKue committed Jun 18, 2024
2 parents 3150d11 + e2355de commit c8060fc
Show file tree
Hide file tree
Showing 37 changed files with 547 additions and 179 deletions.
File renamed without changes.
File renamed without changes.
18 changes: 14 additions & 4 deletions .github/workflows/tests.yml → .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,29 @@ jobs:
- name: Run JAX Tests
if: ${{ matrix.backend == 'jax' }}
run: |
python -m pytest tests/ -v -m "not (numpy or tensorflow or torch)"
pytest --cov bayesflow tests/ -v -m "not (numpy or tensorflow or torch)"
- name: Run NumPy Tests
if: ${{ matrix.backend == 'numpy' }}
run: |
python -m pytest tests/ -v -m "not (jax or tensorflow or torch)"
pytest --cov bayesflow tests/ -v -m "not (jax or tensorflow or torch)"
- name: Run TensorFlow Tests
if: ${{ matrix.backend == 'tensorflow' }}
run: |
python -m pytest tests/ -v -m "not (jax or numpy or torch)"
pytest --cov bayesflow tests/ -v -m "not (jax or numpy or torch)"
- name: Run PyTorch Tests
if: ${{ matrix.backend == 'torch' }}
run: |
python -m pytest tests/ -v -m "not (jax or numpy or tensorflow)"
pytest --cov bayesflow tests/ -v -m "not (jax or numpy or tensorflow)"
- name: Create Coverage Report
run: |
coverage xml
- name: Upload Coverage Reports to CodeCov
uses: codecov/codecov-action@v4
with:
# do not use the files attribute here, otherwise the reports are not merged correctly
token: ${{ secrets.CODECOV_TOKEN }}
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# BayesFlow <img src="img/bayesflow_hex.png" style="float: right; width: 20%; height: 20%;" alt="BayesFlow Logo" />

[![Actions Status](https://github.com/stefanradev93/bayesflow/workflows/Tests/badge.svg)](https://github.com/stefanradev93/bayesflow/actions)
[![License: MIT](https://img.shields.io/badge/License-MIT-red.svg)](https://opensource.org/licenses/MIT)
[![DOI](https://joss.theoj.org/papers/10.21105/joss.05702/status.svg)](https://doi.org/10.21105/joss.05702)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/dwyl/esta/issues)
![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/stefanradev93/bayesflow/tests.yaml?style=for-the-badge&label=Tests)
![Codecov](https://img.shields.io/codecov/c/github/stefanradev93/bayesflow/streamlined-backend?style=for-the-badge)
[![DOI](https://img.shields.io/badge/DOI-10.21105%2Fjoss.05702-blue?style=for-the-badge)](https://doi.org/10.21105/joss.05702)
![PyPI - License](https://img.shields.io/pypi/l/bayesflow?style=for-the-badge)

BayesFlow is a Python library for simulation-based **Amortized Bayesian Inference** with neural networks.
It provides users with:
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
diagnostics,
distributions,
networks,
simulation,
simulators,
)

from .approximators import Approximator
Expand Down
8 changes: 3 additions & 5 deletions bayesflow/experimental/datasets/offline_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,26 @@
import keras
import math

from bayesflow.experimental.utils import nested_getitem


class OfflineDataset(keras.utils.PyDataset):
"""
A dataset that is pre-simulated and stored in memory.
"""
# TODO: fix
def __init__(self, data: dict, batch_size: int, **kwargs):
super().__init__(**kwargs)
self.batch_size = batch_size

self.data = data
self.indices = keras.ops.arange(len(data[next(iter(data.keys()))]))
self.indices = keras.ops.arange(len(data[next(iter(data.keys()))]), dtype="int64")

self.shuffle()

def __getitem__(self, item: int) -> (dict, dict):
""" Get a batch of pre-simulated data """
item = slice(item * self.batch_size, (item + 1) * self.batch_size)
item = self.indices[item]
return nested_getitem(self.data, item)

return {key: keras.ops.take(value, item, axis=0) for key, value in self.data.items()}

def __len__(self) -> int:
return math.ceil(len(self.indices) / self.batch_size)
Expand Down
15 changes: 10 additions & 5 deletions bayesflow/experimental/datasets/online_dataset.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@

import keras

from bayesflow.experimental.simulation import JointDistribution
from bayesflow.experimental.simulators import Simulator


class OnlineDataset(keras.utils.PyDataset):
"""
A dataset that is generated on-the-fly.
"""
def __init__(self, distribution, batch_size: int, **kwargs):
def __init__(self, simulator: Simulator, batch_size: int, **kwargs):
super().__init__(**kwargs)
self.distribution = distribution

if kwargs.get("use_multiprocessing"):
# keras workaround: https://github.com/keras-team/keras/issues/19346
import multiprocessing as mp
mp.set_start_method("spawn", force=True)

self.simulator = simulator
self.batch_size = batch_size

def __getitem__(self, item: int) -> (dict, dict):
""" Sample a batch of data from the joint distribution unconditionally """
return self.distribution.sample((self.batch_size,))
return self.simulator.sample((self.batch_size,))

@property
def num_batches(self):
Expand Down
14 changes: 10 additions & 4 deletions bayesflow/experimental/datasets/rounds_dataset.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@

import keras

from bayesflow.experimental.simulation import JointDistribution
from bayesflow.experimental.simulators import Simulator


class RoundsDataset(keras.utils.PyDataset):
"""
A dataset that is generated on-the-fly at the beginning of every n-th epoch.
"""
def __init__(self, joint_distribution: JointDistribution, batch_size: int, batches_per_epoch: int, epochs_per_round: int, **kwargs):
def __init__(self, simulator: Simulator, batch_size: int, batches_per_epoch: int, epochs_per_round: int, **kwargs):
super().__init__(**kwargs)
self.joint_distribution = joint_distribution

if kwargs.get("use_multiprocessing"):
# keras workaround: https://github.com/keras-team/keras/issues/19346
import multiprocessing as mp
mp.set_start_method("spawn", force=True)

self.simulator = simulator
self.batch_size = batch_size
self.batches_per_epoch = batches_per_epoch
self.epochs_per_round = epochs_per_round
Expand All @@ -36,4 +42,4 @@ def on_epoch_end(self) -> None:

def regenerate(self) -> None:
""" Sample new batches of data from the joint distribution unconditionally """
self.data = [self.joint_distribution.sample((self.batch_size,)) for _ in range(self.batches_per_epoch)]
self.data = [self.simulator.sample((self.batch_size,)) for _ in range(self.batches_per_epoch)]
6 changes: 4 additions & 2 deletions bayesflow/experimental/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from .coupling_flow import CouplingFlow
from .deep_set import DeepSet
from .flow_matching import FlowMatching
from .inference_network import InferenceNetwork
from .mlp import MLP
from .resnet import ResNet
from .lstnet import LSTNet
from .transformers import SetTransformer
from .summary_network import SummaryNetwork

from .inference_network import InferenceNetwork
from .summary_network import SummaryNetwork
2 changes: 2 additions & 0 deletions bayesflow/experimental/networks/lstnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

from .lstnet import LSTNet
88 changes: 88 additions & 0 deletions bayesflow/experimental/networks/lstnet/lstnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@

import keras
from keras import layers, Sequential
from keras.saving import register_keras_serializable

from bayesflow.experimental.types import Tensor
from bayesflow.experimental.utils import keras_kwargs

from .skip_recurrent import SkipRecurrentNet
from ...networks import MLP


@register_keras_serializable(package="bayesflow.networks.lstnet")
class LSTNet(keras.Model):
"""
Implements a LSTNet Architecture as described in [1]
[1] Y. Zhang and L. Mikelsons, Solving Stochastic Inverse Problems with Stochastic BayesFlow,
2023 IEEE/ASME International Conference on Advanced Intelligent Mechatronics (AIM),
Seattle, WA, USA, 2023, pp. 966-972, doi: 10.1109/AIM46323.2023.10196190.
TODO: Add proper docstring
"""

def __init__(
self,
summary_dim: int = 16,
filters: int | list | tuple = 32,
kernel_sizes: int | list | tuple = 3,
strides: int | list | tuple = 1,
activation: str = "relu",
kernel_initializer: str = "glorot_uniform",
groups: int = 8,
recurrent_type: str | keras.Layer = "gru",
recurrent_dim: int = 128,
bidirectional: bool = True,
dropout: float = 0.05,
skip_steps: int = 4,
**kwargs
):

super().__init__(**keras_kwargs(kwargs))

# Convolutional backbone -> can be extended with inception-like structure
if not isinstance(filters, (list, tuple)):
filters = (filters, )
if not isinstance(kernel_sizes, (list, tuple)):
kernel_sizes = (kernel_sizes, )
if not isinstance(strides, (list, tuple)):
strides = (strides, )
self.conv = Sequential()
for f, k, s in zip(filters, kernel_sizes, strides):
self.conv.add(
layers.Conv1D(
filters=f,
kernel_size=k,
strides=s,
activation=activation,
kernel_initializer=kernel_initializer,
)
)
self.conv.add(
layers.GroupNormalization(groups=groups)
)

# Recurrent and feedforward backbones
self.recurrent = SkipRecurrentNet(
hidden_dim=recurrent_dim,
recurrent_type=recurrent_type,
bidirectional=bidirectional,
input_channels=filters[-1],
skip_steps=skip_steps,
dropout=dropout
)
self.feedforward = MLP(**kwargs.get("mlp_kwargs", {}))

self.output_projector = layers.Dense(summary_dim)

def call(self, time_series: Tensor, **kwargs) -> Tensor:
summary = self.conv(time_series, **kwargs)
summary = self.recurrent(summary, **kwargs)
summary = self.feedforward(summary, **kwargs)
summary = self.output_projector(summary)
return summary

def build(self, input_shape):
self.call(keras.ops.zeros(input_shape))
58 changes: 58 additions & 0 deletions bayesflow/experimental/networks/lstnet/skip_recurrent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@

import keras
from keras.saving import register_keras_serializable

from bayesflow.experimental.types import Tensor
from bayesflow.experimental.utils import keras_kwargs, find_recurrent_net

@register_keras_serializable(package="bayesflow.networks")
class SkipRecurrentNet(keras.Model):
"""
Implements a Skip recurrent layer as described in [1], but allowing a more flexible
recurrent backbone and a more flexible implementation.
[1] Y. Zhang and L. Mikelsons, Solving Stochastic Inverse Problems with Stochastic BayesFlow,
2023 IEEE/ASME International Conference on Advanced Intelligent Mechatronics (AIM),
Seattle, WA, USA, 2023, pp. 966-972, doi: 10.1109/AIM46323.2023.10196190.
TODO: Add proper docstring
"""
def __init__(
self,
hidden_dim: int = 256,
recurrent_type: str | keras.Layer = "gru",
bidirectional: bool = True,
input_channels: int = 64,
skip_steps: int = 4,
dropout: float = 0.05,
**kwargs
):
super().__init__(**keras_kwargs(kwargs))

recurrent_constructor = find_recurrent_net(recurrent_type)

self.recurrent = recurrent_constructor(
units=hidden_dim // 2 if bidirectional else hidden_dim,
dropout=dropout,
recurrent_dropout=dropout
)
self.skip_conv = keras.layers.Conv1D(
filters=input_channels*skip_steps,
kernel_size=skip_steps,
strides=skip_steps
)
self.skip_recurrent = recurrent_constructor(
units=hidden_dim // 2 if bidirectional else hidden_dim,
dropout=dropout,
recurrent_dropout=dropout
)
self.input_channels = input_channels

def call(self, time_series: Tensor, **kwargs) -> Tensor:
direct_summary = self.recurrent(time_series, **kwargs)
skip_summary = self.skip_recurrent(self.skip_conv(time_series), **kwargs)
return keras.ops.concatenate((direct_summary, skip_summary), axis=-1)

def build(self, input_shape):
self.call(keras.ops.zeros(input_shape))
4 changes: 1 addition & 3 deletions bayesflow/experimental/networks/transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@


class SetTransformer:
pass
from .set_transformer import SetTransformer
2 changes: 1 addition & 1 deletion bayesflow/experimental/networks/transformers/isab.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
self.mab0 = MultiHeadAttentionBlock(**mab_kwargs)
self.mab1 = MultiHeadAttentionBlock(**mab_kwargs)

def call(self, set_x: Tensor, **kwargs):
def call(self, set_x: Tensor, **kwargs) -> Tensor:
"""Performs the forward pass through the self-attention layer.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/experimental/networks/transformers/mab.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
self.feedforward.add(layers.Dense(output_dim))
self.ln_post = layers.LayerNormalization() if layer_norm else None

def call(self, set_x: Tensor, set_y: Tensor, **kwargs):
def call(self, set_x: Tensor, set_y: Tensor, **kwargs) -> Tensor:
"""Performs the forward pass through the attention layer.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/experimental/networks/transformers/pma.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
use_bias=use_bias
))

def call(self, set_x: Tensor, **kwargs):
def call(self, set_x: Tensor, **kwargs) -> Tensor:
"""Performs the forward pass through the PMA block.
Parameters
Expand Down
3 changes: 2 additions & 1 deletion bayesflow/experimental/networks/transformers/sab.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

from keras.saving import register_keras_serializable

from bayesflow.experimental.types import Tensor
from .mab import MultiHeadAttentionBlock


Expand All @@ -13,7 +14,7 @@ class SetAttentionBlock(MultiHeadAttentionBlock):
In International conference on machine learning (pp. 3744-3753). PMLR.
"""

def call(self, set_x, **kwargs):
def call(self, set_x: Tensor, **kwargs) -> Tensor:
"""Performs the forward pass through the self-attention layer.
Parameters
Expand Down
3 changes: 0 additions & 3 deletions bayesflow/experimental/simulation/__init__.py

This file was deleted.

2 changes: 0 additions & 2 deletions bayesflow/experimental/simulation/decorators/__init__.py

This file was deleted.

Loading

0 comments on commit c8060fc

Please sign in to comment.