
feat: batched sampling for MCMC #1176

Open · wants to merge 80 commits into base: main
Conversation

@manuelgloeckler (Contributor) commented Jun 18, 2024

What does this implement/fix? Explain your changes

This pull request aims to implement the sample_batched method for MCMC.

Current problem

  • BasePotential can either "allow_iid" or not; if it does, the batch dimension of x will be interpreted as IID samples.
    • Replace allow_iid with a mutable attribute (or optional input argument) interpret_as_iid.
    • Remove the warning for batched x and default to batched evaluation.
  • Refactor all MCMC initialization methods to work with a batch dim:
    • resample should break
    • SIR should break
    • proposal should work
  • Add tests to check that the correct samples end up in each batch dimension (currently, only shapes are checked).
    • The problem is currently not caught by tests...

The current implementation will let you sample the correct shape, BUT it will output the wrong solution. This is because the potential function broadcasts, repeats, and finally sums over the first dimension, which is incorrect; a sketch of this failure mode follows below.
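
A minimal, hypothetical sketch of this failure mode with a toy Gaussian log-potential (illustration only, not sbi's actual potential code):

```python
import torch

# Toy log-potential: log N(x | theta, 1), up to an additive constant.
def log_potential(theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    return -0.5 * (x - theta) ** 2

theta = torch.tensor([0.0, 1.0])    # batch of 2 thetas (e.g., 2 chains)
x = torch.tensor([[0.0], [10.0]])   # batch of 2 *distinct* observations

# IID interpretation (the bug described above): broadcast, then sum over
# the x batch dimension, as if x held repeated trials of one observation.
iid = log_potential(theta, x).sum(dim=0)  # shape (2,): one value per theta

# Batched interpretation (what sample_batched needs): keep one potential
# per (x, theta) combination instead of summing them together.
batched = log_potential(theta, x)         # shape (2, 2)

print(iid.shape, batched.shape)  # torch.Size([2]) torch.Size([2, 2])
```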

manuelgloeckler and others added 30 commits April 29, 2024 09:04

codecov bot commented Jun 18, 2024

Codecov Report

Attention: Patch coverage is 80.67227% with 23 lines in your changes missing coverage. Please review.

Project coverage is 75.65%. Comparing base (2398a7a) to head (fd11a72).
Report is 2 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1176      +/-   ##
==========================================
- Coverage   84.55%   75.65%   -8.91%     
==========================================
  Files          96       96              
  Lines        7603     7685      +82     
==========================================
- Hits         6429     5814     -615     
- Misses       1174     1871     +697     
| Flag | Coverage Δ |
| --- | --- |
| unittests | 75.65% <80.67%> (-8.91%) ⬇️ |

Flags with carried forward coverage won't be shown.

| Files | Coverage Δ |
| --- | --- |
| sbi/inference/posteriors/base_posterior.py | 86.04% <100.00%> (ø) |
| ...inference/potentials/likelihood_based_potential.py | 100.00% <100.00%> (ø) |
| sbi/inference/potentials/ratio_based_potential.py | 100.00% <100.00%> (ø) |
| sbi/utils/sbiutils.py | 78.35% <ø> (-8.21%) ⬇️ |
| sbi/utils/user_input_checks.py | 80.64% <100.00%> (-2.87%) ⬇️ |
| sbi/inference/abc/mcabc.py | 15.87% <0.00%> (-68.26%) ⬇️ |
| sbi/inference/abc/smcabc.py | 12.44% <0.00%> (-69.96%) ⬇️ |
| sbi/inference/potentials/base_potential.py | 92.85% <85.71%> (+0.35%) ⬆️ |
| .../inference/potentials/posterior_based_potential.py | 95.00% <92.85%> (-1.97%) ⬇️ |
| sbi/inference/posteriors/ensemble_posterior.py | 50.00% <0.00%> (-37.97%) ⬇️ |
| ... and 2 more | |

... and 19 files with indirect coverage changes

@janfb changed the title from "Amortized sample for MCMC" to "feat: batched sampling for MCMC" on Jun 18, 2024
@gmoss13 (Contributor) commented Jun 27, 2024

I've now made some progress towards this PR and would like some feedback before I continue.

BasePotential can either "allow_iid" or not.

Given batch_dim_theta != batch_dim_x, we need to decide how to evaluate potential(x, theta). We could return (batch_dim_x, batch_dim_theta) potentials (i.e., every combination), but I am worried this could add a lot of computational overhead, especially when sampling. Instead, in the current implementation I suggest we assume that batch_dim_theta is a multiple of batch_dim_x (i.e., for sampling, we have n chains of theta for each x). In this case we expand the batch dim of x to batch_theta and match which x goes to which theta (see the sketch below). If we are happy with this approach, I'll go ahead and apply it also to the MCMC init_strategy, etc., and make sure it is consistent with other calls.
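
A short sketch of the proposed pairing rule with hypothetical shapes (an illustration of the idea, not the PR's actual code):

```python
import torch

batch_x, n_chains, x_dim, theta_dim = 3, 5, 2, 4
x = torch.randn(batch_x, x_dim)
theta = torch.randn(batch_x * n_chains, theta_dim)  # multiple of batch_x

# Expand x so that row i of x_expanded is the observation paired with
# row i of theta: each x gets a contiguous block of n_chains thetas.
x_expanded = x.repeat_interleave(n_chains, dim=0)   # shape (15, x_dim)
assert x_expanded.shape[0] == theta.shape[0]
```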

Remove warning for batched x and default to batched evaluation
Not sure if we want batched evaluation as the default. I think it's easier to do batched evaluation when sample_batched or log_prob_batched is called, and otherwise assume iid (and warn if batch dim >1 as before).

@gmoss13 requested a review from janfb on June 27, 2024 16:04
@manuelgloeckler (Contributor, Author) commented:

Great, it looks good. I like that the choice of iid or not can now be made in the set_x method, which makes a lot of sense.

I would also opt for your suggested option. The question arises because we squeeze the batch_shape into a single dimension, right? With "PyTorch" broadcasting, one would expect something like (1, batch_x_dim, x_dim) and (batch_theta_dim, batch_x_dim, theta_dim) -> (batch_x_dim, batch_theta_dim), so by squeezing the xs and thetas into 2d, one always gets a dimension that is a multiple of batch_x_dim (otherwise it cannot be represented by a fixed-size tensor).

For (1,batch_x_dim,x_dim) and (batch_theta_dim, 1, theta_dim), PyTorch broadcasting semantics would compute all combinations. Unfortunately, after squeezing, these distinctions between cases can no longer be fully preserved.
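
For reference, a toy example of the all-combinations broadcasting described above (shapes illustrative only):

```python
import torch

batch_x, batch_theta, dim = 3, 5, 2
x = torch.randn(1, batch_x, dim)          # (1, batch_x_dim, x_dim)
theta = torch.randn(batch_theta, 1, dim)  # (batch_theta_dim, 1, theta_dim)

# Broadcasting pairs every theta with every x:
out = (x - theta).pow(2).sum(dim=-1)      # shape (batch_theta_dim, batch_x_dim)
print(out.shape)                          # torch.Size([5, 3])
```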

@janfb (Contributor) left a comment:


Great effort, thanks a lot for tackling this 👏

I do have a couple of comments and questions. Happy to discuss in person if needed.

sbi/inference/posteriors/mcmc_posterior.py (resolved comment)

x_ = x.repeat_interleave(num_chains, dim=0)

self.potential_fn.set_x(x_, interpret_as_iid=False)
Contributor:

I don't understand why the =False is hardcoded here, but maybe it will become clear below.

Contributor:

Makes sense now. I am just wondering: what if x was already set with iid samples before, using set_default_x? Maybe we should add a warning then? Effectively, sample_batched then overwrites the default x. It should all be clear from the API, of course, especially because one has to pass a new x here, but it might not be clear to users that they cannot mix iid and batched evaluation. What do you think?

Contributor:

I like adding a warning here. We can check whether self._potential_fn already has x_is_iid set to True, and in that case raise a warning that the user is mixing iid and batch evaluation.

@gmoss13 (Contributor) commented Jul 19, 2024:

In this case, we should also raise a warning for MCMCPosterior.sample() if x_is_iid was previously set to False.

sbi/inference/posteriors/mcmc_posterior.py (outdated, resolved)
sbi/inference/posteriors/mcmc_posterior.py (outdated, resolved)
sbi/inference/posteriors/mcmc_posterior.py (resolved)
sbi/utils/conditional_density_utils.py (outdated, resolved)
sbi/utils/potentialutils.py (outdated, resolved)
sbi/utils/sbiutils.py (resolved)
tests/posterior_nn_test.py (outdated, resolved)
tests/posterior_nn_test.py (outdated, resolved)
@gmoss13 (Contributor) commented Jul 19, 2024

Great effort, thanks a lot for tackling this 👏

I do have a couple of comments and questions. Happy to discuss in person if needed.

Thanks for the review! I implemented your suggestions.

An additional point: for posterior_based_potential, we should indeed not allow iid x, as iid trials are handled by the PermutationInvariantNetwork. Instead, we now always treat batches of x as not iid. If the user tries to call potential.set_x(x, x_is_iid=True) with a PosteriorBasedPotential, we raise an error stating this (see the sketch below). I added a few test cases in embedding_net_test.py::test_embedding_api_with_multiple_trials to test whether batches of x are interpreted correctly when we use a PermutationInvariantNetwork.
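
A minimal sketch of the guard described above (hypothetical class and attribute names, not the actual sbi implementation):

```python
class PosteriorBasedPotentialSketch:
    """Toy stand-in illustrating the set_x guard described above."""

    def set_x(self, x, x_is_iid: bool = False) -> None:
        if x_is_iid:
            raise ValueError(
                "Batched x cannot be interpreted as iid trials for a "
                "posterior-based potential; iid trials are handled by the "
                "permutation-invariant embedding net instead."
            )
        self._x = x
        self._x_is_iid = x_is_iid
```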

@gmoss13 requested a review from janfb on July 19, 2024 15:51
@janfb (Contributor) left a comment:

Looks great! I added just a couple of last questions.

Comment on lines +254 to +260
if not x_o_is_iid:
    warn(
        "The default `x_o` has `x_is_iid = False`, but you are now sampling "
        "with a batch `x` with `x_is_iid = True`. If you want to sample non-iid "
        "`x`, please reset `x_is_iid = False` in the potential function.",
        stacklevel=2,
    )
Contributor:

When does this happen? Does the user have to explicitly use the posterior.potential_fn object to make it happen?
If not, i.e., if it can happen to a user who constructed the posterior using build_posterior (maybe without knowing anything about potentials), then we should add more details to this warning, e.g., "by setting posterior.potential_fn.x_is_iid=False".

Or am I missing something here?

@gmoss13 (Contributor) commented Jul 24, 2024:

Now that we support both sample_batched and sample when a batch of x is passed, we want to warn the user to use the correct one: if x is batched iid, the user should use sample, and if x is batched and NOT iid, the user should use sample_batched. Instead of warning the user to change the definition of x in the potential, I think it's sufficient to warn them to use the correct one of sample and sample_batched. But this is already covered in sbi.utils.sbiutils, where we raise a warning that if the batch dimension of x is greater than 1, the user should make sure to choose sample or sample_batched so that x is interpreted correctly (sketched below). So maybe we remove the warning here altogether?
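
A sketch of the kind of check described (hypothetical helper name; the actual warning lives in sbi.utils.sbiutils):

```python
import warnings

import torch

def warn_on_batched_x(x: torch.Tensor) -> None:
    """Warn if x has a batch dimension > 1, since iid and non-iid batches
    must be routed to `sample` vs. `sample_batched`."""
    if x.ndim > 1 and x.shape[0] > 1:
        warnings.warn(
            "x has a batch dimension of size > 1. Use `sample` if the batch "
            "holds iid trials of one observation, and `sample_batched` if it "
            "holds distinct observations.",
            stacklevel=2,
        )
```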

@@ -321,6 +334,7 @@ def sample(
     thin=thin,  # type: ignore
     warmup_steps=warmup_steps,  # type: ignore
     vectorized=(method == "slice_np_vectorized"),
+    interchangeable_chains=True,
Contributor:

Is this only true for slice_np sampling? Or should we expose it via the kwargs?

Contributor:

interchangeable_chains corresponds to whether the batch of x provided to the sampler is iid or not. Right now, we only support sample_batched with slice_np_vectorized. In the future, we might want to support this with one of the other MCMC samplers we have available, in which case we can expose interchangeable_chains in the kwargs, although I would suggest that as a future PR?

sbi/inference/posteriors/mcmc_posterior.py (resolved comment)
Comment on lines -104 to -105
theta = ensure_theta_batched(torch.as_tensor(theta)).to(self.device)

Contributor:

Why is this not needed anymore? Especially the .to(self.device).

Contributor:

You're right, it's still needed. I will add this back.

tests/embedding_net_test.py (resolved comment)