Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tutorial for training interface #983

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Add tutorial for training interface #983

wants to merge 2 commits into from

Conversation

michaeldeistler
Copy link
Contributor

@michaeldeistler michaeldeistler commented Mar 11, 2024

from sbi.neural_nets.flow import build_nsf
from sbi.neural_nets.density_estimators.flow import NFlowsFlow
from sbi.inference.posteriors import MCMCPosterior
from sbi.inference.potentials import likelihood_estimator_based_potential


# Build neural density estimator.
net = build_nsf(x, theta)
de = NFlowsFlow(net, condition_shape=(theta.shape[1],))

# Train the density estimator.
opt = Adam(list(de.parameters()), lr=5e-4)
for _ in range(100):
    opt.zero_grad()
    log_probs = de.log_prob(x, condition=theta)
    loss = -torch.mean(log_probs)
    loss.backward()
    opt.step()

# Build posterior and sample with MCMC.
potential, tf = likelihood_estimator_based_potential(de.net, prior, x_o)
posterior = MCMCPosterior(
    potential,
    proposal=prior,
    theta_transform=tf,
    num_chains=100,
    method="slice_np_vectorized"
)
samples = posterior.sample((1000,), x=x_o)
_ = pairplot(samples, limits=[[-3, 3], [-3, 3]], figsize=(3, 3))

TODO:

  • have an abstraction that avoids two lines for the density estimator (not sure about this one...)
  • Make the DirectPosterior and all potentials use the DensityEstimator abstraction
  • Make it easier to use NRE by having an abstraction for the loss

Copy link

codecov bot commented Mar 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 77.01%. Comparing base (afbd5e7) to head (d8db78c).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #983      +/-   ##
==========================================
- Coverage   85.08%   77.01%   -8.07%     
==========================================
  Files          90       90              
  Lines        6643     6643              
==========================================
- Hits         5652     5116     -536     
- Misses        991     1527     +536     
Flag Coverage Δ
unittests 77.01% <ø> (-8.07%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

see 23 files with indirect coverage changes

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant