Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keras 3: Streamlined Backend #159

Draft
wants to merge 504 commits into
base: dev
Choose a base branch
from
Draft

Keras 3: Streamlined Backend #159

wants to merge 504 commits into from

Conversation

LarsKue
Copy link
Collaborator

@LarsKue LarsKue commented Apr 15, 2024

Work In Progress

This PR is still in progress. Please discuss open issues and raise possible concerns below.

Summary

  • Streamline the definition of Prior and other distributions
  • When sampling parameters, they are now always passed by name in a dictionary
  • Add decorators for easy construction and batching of backend-conforming distribution objects
@bf.distribution
def prior():
    return {"theta": np.random.normal(size=2)}

batch_of_samples = prior.sample((32,))
assert batch_of_samples["theta"].shape == (32, 2)
  • Amortizers are now Keras 3 Models, which allows backend-agnostic training
# set backend to "tensorflow", "jax", or "torch" before importing bayesflow
import os
os.environ["KERAS_BACKEND"] = "torch"

import bayesflow as bf

amortizer = bf.AmortizedPosterior(...)
amortizer.fit()
  • Move behavior of Configurator into Amortizer
  • Add default configuration behavior for edge cases, like no summary network
  • General overhaul of the data flow inside Amortizer
  • Add a Dataset object that takes care of data loading in multiple worker processes
  • Move Training Strategy from Trainer into Dataset
# online
dataset = bf.datasets.OnlineDataset(generative_model=..., workers=12, use_multiprocessing=True)
amortizer.fit(dataset, steps_per_epoch=1000)
# offline, in memory
data = {...}  # some dictionary
dataset = OfflineDataset(data, workers=12, use_multiprocessing=True)
amortizer.fit(dataset)
# offline, on disk
import keras

class MyDataset(keras.utils.PyDataset):
    ...  # user-implemented data loading from disk

dataset = MyDataset(workers=12, use_multiprocessing=True)
amortizer.fit(dataset)

In Progress

  • Allow the user to specify what variables are observed vs. inferred vs. conditioned on by name
# two moons
posterior = Amortizer(
    inference_network,
    observed_variables=["x1", "x2"],
    inferred_variables=["theta1", "theta2"],
    inference_conditions=["r", "alpha"]
)
  • Named arguments with distribution decorators (@LarsKue)
@bf.Prior(is_conditional=True)
def prior(alpha, beta):
    return {"a": 2 * alpha, "b": alpha + beta}
  • Update Documentation (@LarsKue, mostly done)
  • Implement hierarchical amortizers for multi-level models (@daniel-habermann)
  • Port existing networks to Keras 3 (@Chase-Grajeda)
  • Add support for non-batchable context and make this the default (@LarsKue)

Postponed

We should probably do these things after merging with dev and before merging with main:

  • Add support for graph structured priors
  • Split sampled data by parameter names and return as a dict (@jerrymhuang)
  • Constrain predicted parameters to user-defined subspaces (@Kucharssim)
  • Add coverage statistics to README.md with a workflow
  • Expand test coverage to at least 75%
  • Update example notebooks
settings = bf.settings.propose(
    training="offline",
    dataset_size=600_000,
    data_shape=(200, 2),
    data_type="time series",
    parameters_shape=(8,)
)

In Discussion

  • Add a WorkFlow (name wip) object that encapsulates both amortizer and dataset for easier post-processing and model sharing
  • Rename batchable / non-batchable context

Dropped

  • Allow the user to pass a dictionary of distributions instead of defining a prior
import tensorflow_probability as tfp
D = tfp.distributions
prior = {"theta": D.Normal([0.0, 0.0], [1.0, 1.0])}

Reason: We should enforce a single way to do things. Also poor support with pure jax.

  • Return sampled data as a DataFrame

Reason: Too restrictive for data structure.

@LarsKue LarsKue added refactoring Some code shall be redesigned unit tests A new set of tests needs to be added. labels Apr 15, 2024
@LarsKue LarsKue self-assigned this Apr 15, 2024
@paul-buerkner
Copy link
Collaborator

paul-buerkner commented Apr 15, 2024

@LarsKue Thank you so much! This already looks amazing!

Could you perhaps add a simple fully runnable example here for people to get started playing around with it? It is kind of there above, but I think it would make things easier to have one chunk of example code to copy and edit from there.

Everyone, please try out the new interface and tell us what you think!

@LarsKue
Copy link
Collaborator Author

LarsKue commented Apr 15, 2024

@paul-buerkner Yes, I am working on it. I hope I have one ready today.

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@marvinschmitt
Copy link
Collaborator

Great work!! 👏

I'll write down some thoughts on the installation process. Those don't need any changes in the streamlined codebase but are just reminders for our future selves shortly before the release.

  • In addition to keras (which will replace the current tensorflow dependencies), the user has to install their favorite backend. How do we approach this for users who aren't proficient in Python-ecosystem stuff? A few initial thoughts:
    • Extras like pip install bayesflow[torch]. Advantage: Easy interface. Disadvantage: Restricted to pip, no mamba equivalent. I don't like that option.
    • Message during the bayesflow installation. That's annoying (if possible at all?) and I wouldn't do that either.
    • After installation, upon running bayesflow: Catch any errors that relate to missing backend packages and provide a comprehensive error message with concrete pointers on how to install the necessary backend to fix the issue. Currently my favorite option.
  • Python >=3.11 is required for typing.Self

LarsKue and others added 30 commits July 10, 2024 19:11
Also rename FunctionalSimulator to LambdaSimulator
this makes the implementation more explicit, which is easier to debug
but also means a little bit more code clutter
all in all I think this is better
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
refactoring Some code shall be redesigned unit tests A new set of tests needs to be added.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants