Commit

run ruff linter + formatter
LarsKue committed Jun 20, 2024
1 parent f604862 commit 2f0fdcc
Showing 94 changed files with 328 additions and 364 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -32,4 +32,4 @@ docs/
.tox

# MacOS
.DS_Store
.DS_Store
2 changes: 1 addition & 1 deletion CITATION.cff
@@ -69,4 +69,4 @@ preferred-citation:
type: article
url: "https://joss.theoj.org/papers/10.21105/joss.05702"
volume: 8
title: "BayesFlow: Amortized Bayesian Workflows With Neural Networks"
title: "BayesFlow: Amortized Bayesian Workflows With Neural Networks"
1 change: 0 additions & 1 deletion bayesflow/__init__.py
@@ -1,4 +1,3 @@

from . import (
approximators,
configurators,
1 change: 0 additions & 1 deletion bayesflow/approximators/__init__.py
@@ -1,2 +1 @@

from .approximator import Approximator
9 changes: 5 additions & 4 deletions bayesflow/approximators/approximator.py
@@ -1,4 +1,3 @@

import keras
from keras.saving import register_keras_serializable

@@ -20,7 +19,7 @@
@register_keras_serializable(package="bayesflow.amortizers")
class Approximator(BaseApproximator):
def __init__(self, **kwargs):
""" The main workhorse for learning amortized neural approximators for distributions arising
"""The main workhorse for learning amortized neural approximators for distributions arising
in inverse problems and Bayesian inference (e.g., posterior distributions, likelihoods, marginal
likelihoods).
@@ -64,14 +63,16 @@ def __init__(self, **kwargs):
if "configurator" not in kwargs:
# try to set up a default configurator
if "inference_variables" not in kwargs:
raise ValueError(f"You must specify either a configurator or arguments for the default configurator.")
raise ValueError("You must specify either a configurator or arguments for the default configurator.")

inference_variables = kwargs.pop("inference_variables")
inference_conditions = kwargs.pop("inference_conditions", None)
summary_variables = kwargs.pop("summary_variables", None)
summary_conditions = kwargs.pop("summary_conditions", None)

kwargs["configurator"] = Configurator(inference_variables, inference_conditions, summary_variables, summary_conditions)
kwargs["configurator"] = Configurator(
inference_variables, inference_conditions, summary_variables, summary_conditions
)

kwargs.setdefault("summary_network", None)
super().__init__(**kwargs)
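For orientation, the default-configurator branch shown above is equivalent to constructing the Configurator by hand. A minimal sketch follows; "theta" and "x" are placeholder variable keys for illustration, not repository defaults.

import keras  # bayesflow builds on keras for serialization
from bayesflow.configurators import Configurator

# What Approximator assembles internally when only variable names are passed:
# kwargs["configurator"] = Configurator(inference_variables, inference_conditions,
#                                       summary_variables, summary_conditions)
configurator = Configurator(
    inference_variables=["theta"],  # required
    inference_conditions=["x"],     # optional, stored as [] when omitted
)
print(configurator.get_config())
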
33 changes: 22 additions & 11 deletions bayesflow/approximators/base_approximator.py
@@ -1,4 +1,3 @@

import keras
from keras.saving import (
deserialize_keras_object,
@@ -15,15 +14,23 @@

@register_keras_serializable(package="bayesflow.approximators")
class BaseApproximator(keras.Model):
def __init__(self, inference_network: InferenceNetwork, summary_network: SummaryNetwork, configurator: BaseConfigurator, **kwargs):
def __init__(
self,
inference_network: InferenceNetwork,
summary_network: SummaryNetwork,
configurator: BaseConfigurator,
**kwargs,
):
super().__init__(**keras_kwargs(kwargs))
self.inference_network = inference_network
self.summary_network = summary_network
self.configurator = configurator

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "BaseApproximator":
config["inference_network"] = deserialize_keras_object(config["inference_network"], custom_objects=custom_objects)
config["inference_network"] = deserialize_keras_object(
config["inference_network"], custom_objects=custom_objects
)
config["summary_network"] = deserialize_keras_object(config["summary_network"], custom_objects=custom_objects)
config["configurator"] = deserialize_keras_object(config["configurator"], custom_objects=custom_objects)

@@ -63,10 +70,12 @@ def evaluate(self, *args, **kwargs):

if val_logs is None:
# https://github.com/keras-team/keras/issues/19835
warnings.warn(f"Found no validation logs due to a bug in keras. "
f"Applying workaround, but incorrect loss values may be logged. "
f"If possible, increase the size of your dataset, "
f"or lower the number of validation steps used.")
warnings.warn(
"Found no validation logs due to a bug in keras. "
"Applying workaround, but incorrect loss values may be logged. "
"If possible, increase the size of your dataset, "
"or lower the number of validation steps used."
)

val_logs = {}

@@ -103,16 +112,18 @@ def compute_metrics(self, data: dict[str, Tensor], stage: str = "training") -> d
return metrics | summary_metrics | inference_metrics

def compute_loss(self, *args, **kwargs):
raise RuntimeError(f"Use compute_metrics()['loss'] instead.")
raise RuntimeError("Use compute_metrics()['loss'] instead.")

def fit(self, *args, **kwargs):
if not self.built:
try:
dataset = kwargs.get("x") or args[0]
self.build_from_data(dataset[0])
except Exception:
raise RuntimeError(f"Could not automatically build the approximator. Please pass a dataset as the "
f"first argument to `approximator.fit()` or manually call `approximator.build()` "
f"with a dictionary specifying your data shapes.")
raise RuntimeError(
"Could not automatically build the approximator. Please pass a dataset as the "
"first argument to `approximator.fit()` or manually call `approximator.build()` "
"with a dictionary specifying your data shapes."
)

return super().fit(*args, **kwargs)
22 changes: 16 additions & 6 deletions bayesflow/approximators/jax_approximator.py
@@ -1,4 +1,3 @@

import jax
import keras

@@ -13,7 +12,14 @@ def train_step(self, *args, **kwargs):
def test_step(self, *args, **kwargs):
return self.stateless_test_step(*args, **kwargs)

def stateless_compute_metrics(self, trainable_variables: any, non_trainable_variables: any, metrics_variables: any, data: dict[str, Tensor], stage: str = "training") -> (Tensor, tuple):
def stateless_compute_metrics(
self,
trainable_variables: any,
non_trainable_variables: any,
metrics_variables: any,
data: dict[str, Tensor],
stage: str = "training",
) -> (Tensor, tuple):
"""
Things we do for jax:
1. Accept trainable variables as the first argument
@@ -47,11 +53,13 @@ def stateless_train_step(self, state: tuple, data: dict[str, Tensor]) -> (dict[s

grad_fn = jax.value_and_grad(self.stateless_compute_metrics, has_aux=True)

(loss, aux), grads = grad_fn(trainable_variables, non_trainable_variables, metrics_variables, data, stage="training")
(loss, aux), grads = grad_fn(
trainable_variables, non_trainable_variables, metrics_variables, data, stage="training"
)
metrics, non_trainable_variables, metrics_variables = aux

trainable_variables, optimizer_variables = (
self.optimizer.stateless_apply(optimizer_variables, grads, trainable_variables)
trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)

metrics_variables = self._update_loss(loss, metrics_variables)
@@ -62,7 +70,9 @@ def stateless_train_step(self, state: tuple, data: dict[str, Tensor]) -> (dict[s
def stateless_test_step(self, state: tuple, data: dict[str, Tensor]) -> (dict[str, Tensor], tuple):
trainable_variables, non_trainable_variables, metrics_variables = state

loss, aux = self.stateless_compute_metrics(trainable_variables, non_trainable_variables, metrics_variables, data, stage="validation")
loss, aux = self.stateless_compute_metrics(
trainable_variables, non_trainable_variables, metrics_variables, data, stage="validation"
)
metrics, non_trainable_variables, metrics_variables = aux

metrics_variables = self._update_loss(loss, metrics_variables)
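The reformatted calls above follow the usual stateless JAX pattern: a pure metrics function returns (loss, aux) and is differentiated with jax.value_and_grad(has_aux=True). A self-contained toy sketch of that pattern (the linear model and all names are illustrative, not bayesflow code):

import jax
import jax.numpy as jnp

def compute_metrics(params, data):
    # pure function: returns the loss plus auxiliary values, no side effects
    pred = data["x"] @ params
    loss = jnp.mean((pred - data["y"]) ** 2)
    aux = {"mse": loss}
    return loss, aux

grad_fn = jax.value_and_grad(compute_metrics, has_aux=True)

params = jnp.zeros(3)
data = {"x": jnp.ones((4, 3)), "y": jnp.ones(4)}
(loss, aux), grads = grad_fn(params, data)
params = params - 0.1 * grads  # a stateless update returns new parameters
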
5 changes: 1 addition & 4 deletions bayesflow/approximators/numpy_approximator.py
@@ -1,11 +1,8 @@

import numpy as np

from bayesflow.types import Tensor

from .base_approximator import BaseApproximator


class NumpyApproximator(BaseApproximator):
def train_step(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
raise NotImplementedError(f"Keras currently has no support for numpy training.")
raise NotImplementedError("Keras currently has no support for numpy training.")
1 change: 0 additions & 1 deletion bayesflow/approximators/tensorflow_approximator.py
@@ -1,4 +1,3 @@

import tensorflow as tf

from bayesflow.types import Tensor
1 change: 0 additions & 1 deletion bayesflow/approximators/torch_approximator.py
@@ -1,4 +1,3 @@

import torch

from .base_approximator import BaseApproximator
1 change: 0 additions & 1 deletion bayesflow/configurators/__init__.py
@@ -1,3 +1,2 @@

from .base_configurator import BaseConfigurator
from .configurator import Configurator
1 change: 0 additions & 1 deletion bayesflow/configurators/base_configurator.py
@@ -1,4 +1,3 @@

from keras.saving import register_keras_serializable

from bayesflow.types import Tensor
5 changes: 2 additions & 3 deletions bayesflow/configurators/configurator.py
@@ -1,4 +1,3 @@

from keras.saving import register_keras_serializable

from bayesflow.types import Tensor
@@ -14,7 +13,7 @@ def __init__(
inference_variables: list[str],
inference_conditions: list[str] = None,
summary_variables: list[str] = None,
summary_conditions: list[str] = None
summary_conditions: list[str] = None,
):
self.inference_variables = inference_variables
self.inference_conditions = inference_conditions or []
@@ -30,7 +29,7 @@ def get_config(self) -> dict:
"inference_variables": self.inference_variables,
"inference_conditions": self.inference_conditions,
"summary_variables": self.summary_variables,
"summary_conditions": self.summary_conditions
"summary_conditions": self.summary_conditions,
}

def configure_inference_variables(self, data: dict[str, Tensor]) -> Tensor | None:
2 changes: 0 additions & 2 deletions bayesflow/datasets/__init__.py
@@ -1,5 +1,3 @@

from .offline_dataset import OfflineDataset
from .online_dataset import OnlineDataset
from .rounds_dataset import RoundsDataset

6 changes: 3 additions & 3 deletions bayesflow/datasets/offline_dataset.py
@@ -1,4 +1,3 @@

import keras
import math

@@ -7,6 +6,7 @@ class OfflineDataset(keras.utils.PyDataset):
"""
A dataset that is pre-simulated and stored in memory.
"""

def __init__(self, data: dict, batch_size: int, **kwargs):
super().__init__(**kwargs)
self.batch_size = batch_size
@@ -17,7 +17,7 @@ def __init__(self, data: dict, batch_size: int, **kwargs):
self.shuffle()

def __getitem__(self, item: int) -> (dict, dict):
""" Get a batch of pre-simulated data """
"""Get a batch of pre-simulated data"""
item = slice(item * self.batch_size, (item + 1) * self.batch_size)
item = self.indices[item]

@@ -30,5 +30,5 @@ def on_epoch_end(self) -> None:
self.shuffle()

def shuffle(self) -> None:
""" Shuffle the dataset in-place. """
"""Shuffle the dataset in-place."""
self.indices = keras.random.shuffle(self.indices)
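The slicing in __getitem__ above reads as: a shuffled index vector is sliced per batch and used to gather rows. A standalone NumPy sketch of that logic (the gather over the data dictionary is an assumption about the part of the method not shown in this hunk):

import numpy as np

data = {"theta": np.random.normal(size=(10, 2)), "x": np.random.normal(size=(10, 5))}
batch_size = 4
indices = np.random.permutation(10)  # shuffled once, reshuffled on epoch end

def get_batch(item: int) -> dict:
    batch_indices = indices[item * batch_size : (item + 1) * batch_size]
    return {key: value[batch_indices] for key, value in data.items()}

print(get_batch(0)["x"].shape)  # (4, 5)
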
3 changes: 2 additions & 1 deletion bayesflow/datasets/online_dataset.py
@@ -1,4 +1,3 @@

import keras

from bayesflow.simulators.simulator import Simulator
@@ -8,12 +7,14 @@ class OnlineDataset(keras.utils.PyDataset):
"""
A dataset that is generated on-the-fly.
"""

def __init__(self, simulator: Simulator, batch_size: int, **kwargs):
super().__init__(**kwargs)

if kwargs.get("use_multiprocessing"):
# keras workaround: https://github.com/keras-team/keras/issues/19346
import multiprocessing as mp

mp.set_start_method("spawn", force=True)

self.simulator = simulator
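For context, an on-the-fly keras.utils.PyDataset in the spirit of OnlineDataset looks roughly like the sketch below; the toy simulator and the batch count are placeholders, not the repository's Simulator API.

import keras
import numpy as np

class ToyOnlineDataset(keras.utils.PyDataset):
    def __init__(self, batch_size: int, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def __len__(self) -> int:
        return 32  # batches per epoch (arbitrary here)

    def __getitem__(self, item: int) -> dict:
        # fresh draws on every call, so no epoch ever reuses data
        theta = np.random.normal(size=(self.batch_size, 2))
        x = theta + np.random.normal(scale=0.1, size=theta.shape)
        return {"theta": theta, "x": x}
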
7 changes: 4 additions & 3 deletions bayesflow/datasets/rounds_dataset.py
@@ -1,4 +1,3 @@

import keras

from bayesflow.simulators.simulator import Simulator
@@ -8,12 +7,14 @@ class RoundsDataset(keras.utils.PyDataset):
"""
A dataset that is generated on-the-fly at the beginning of every n-th epoch.
"""

def __init__(self, simulator: Simulator, batch_size: int, batches_per_epoch: int, epochs_per_round: int, **kwargs):
super().__init__(**kwargs)

if kwargs.get("use_multiprocessing"):
# keras workaround: https://github.com/keras-team/keras/issues/19346
import multiprocessing as mp

mp.set_start_method("spawn", force=True)

self.simulator = simulator
@@ -27,7 +28,7 @@ def __init__(self, simulator: Simulator, batch_size: int, batches_per_epoch: int
self.regenerate()

def __getitem__(self, item: int) -> (dict, dict):
""" Get a batch of pre-simulated data """
"""Get a batch of pre-simulated data"""
return self.data[item]

@property
@@ -41,5 +42,5 @@ def on_epoch_end(self) -> None:
self.regenerate()

def regenerate(self) -> None:
""" Sample new batches of data from the joint distribution unconditionally """
"""Sample new batches of data from the joint distribution unconditionally"""
self.data = [self.simulator.sample((self.batch_size,)) for _ in range(self.batches_per_epoch)]
1 change: 0 additions & 1 deletion bayesflow/distributions/__init__.py
@@ -1,3 +1,2 @@

from .distribution import Distribution
from .diagonal_normal import DiagonalNormal
2 changes: 1 addition & 1 deletion bayesflow/distributions/diagonal_normal.py
@@ -1,4 +1,3 @@

import math

import keras
@@ -16,6 +15,7 @@ class DiagonalNormal(Distribution):
- ``_log_unnormalized_prob`` method is used as a loss function
- ``log_prob`` is used for density computation
"""

def __init__(self, mean: float | Tensor = 0.0, std: float | Tensor = 1.0, **kwargs):
super().__init__(**kwargs)
self.mean = mean
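One way to read the docstring's split between the unnormalized loss term and log_prob, assuming the unnormalized form simply drops the x-independent normalization constant (the snippet is an illustrative calculation, not the class implementation):

import math
import numpy as np

mean, std, dim = 0.0, 1.0, 3
x = np.array([0.3, -1.2, 2.0])

log_unnormalized = -0.5 * np.sum(((x - mean) / std) ** 2)
log_normalization = -0.5 * dim * math.log(2.0 * math.pi) - dim * math.log(std)
log_prob = log_unnormalized + log_normalization  # full log density
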
1 change: 0 additions & 1 deletion bayesflow/distributions/distribution.py
@@ -1,4 +1,3 @@

import keras

from bayesflow.types import Shape, Tensor
6 changes: 2 additions & 4 deletions bayesflow/networks/__init__.py
@@ -1,11 +1,9 @@

from .coupling_flow import CouplingFlow
from .deep_set import DeepSet
from .flow_matching import FlowMatching
from .inference_network import InferenceNetwork
from .mlp import MLP
from .resnet import ResNet
from .lstnet import LSTNet
from .summary_network import SummaryNetwork
from .transformers import SetTransformer

from .inference_network import InferenceNetwork
from .summary_network import SummaryNetwork
1 change: 0 additions & 1 deletion bayesflow/networks/coupling_flow/__init__.py
@@ -1,2 +1 @@

from .coupling_flow import CouplingFlow
10 changes: 5 additions & 5 deletions bayesflow/networks/coupling_flow/actnorm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from keras import ops
from keras.saving import register_keras_serializable

@@ -14,21 +13,22 @@ class ActNorm(InvertibleLayer):
Activation Normalization is learned invertible normalization, using
a Scale (s) and Bias (b) vector::
y = s * x + b (forward)
x = (y - b) / s (inverse)
y = s * x + b(forward)
x = (y - b) / s(inverse)
References
----------
.. [1] Kingma, D. P., & Dhariwal, P. (2018).
Glow: Generative flow with invertible 1x1 convolutions.
.. [1] Kingma, D. P., & Dhariwal, P. (2018).
Glow: Generative flow with invertible 1x1 convolutions.
Advances in Neural Information Processing Systems, 31.
.. [2] Salimans, Tim, and Durk P. Kingma. (2016).
Weight normalization: A simple reparameterization to accelerate
training of deep neural networks.
Advances in Neural Information Processing Systems, 29.
"""

def __init__(self, **kwargs):
super().__init__(**keras_kwargs(kwargs))
self.scale = None
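The forward/inverse relation spelled out in the ActNorm docstring above can be checked numerically; the scale and bias values below are arbitrary placeholders, not learned parameters.

import numpy as np

x = np.array([0.5, -1.0, 2.0])
scale = np.array([2.0, 0.5, 1.5])
bias = np.array([0.1, -0.2, 0.0])

y = scale * x + bias        # forward: y = s * x + b
x_rec = (y - bias) / scale  # inverse: x = (y - b) / s
assert np.allclose(x, x_rec)
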