feat: batched sampling for MCMC #1176
base: main
Conversation
…rs' into amortizedsample
…from-different-posteriors' into amortizedsample
… reshapes in rejection
This reverts commit 17c5343.
Co-authored-by: Jan <[email protected]>
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #1176 +/- ##
==========================================
- Coverage 84.55% 75.65% -8.91%
==========================================
Files 96 96
Lines 7603 7685 +82
==========================================
- Hits 6429 5814 -615
- Misses 1174 1871 +697
Flags with carried forward coverage won't be shown.
I've made some progress now towards this PR, and would like some feedback before I continue.
Great, it looks good. I like that the choice on iid or not can now be made at the level of the potential function. I would also opt for your suggested option. The question arises because we squeeze the batch_shape into a single dimension, right? For PyTorch broadcasting, one would expect something like (1, batch_x_dim, x_dim) and (batch_theta_dim, batch_x_dim, theta_dim) -> (batch_theta_dim, batch_x_dim), so by squeezing the xs and thetas into 2d one would always get a dimension that is a multiple of batch_x_dim (otherwise it cannot be represented by a fixed-size tensor). For (1, batch_x_dim, x_dim) and (batch_theta_dim, 1, theta_dim), PyTorch broadcasting semantics would compute all combinations. Unfortunately, after squeezing, these distinctions between the cases can no longer be fully preserved.
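To make the broadcasting point above concrete, here is a small PyTorch sketch (the shape values and variable names are purely illustrative, not sbi internals):

```python
import torch

# Illustrative sizes only.
batch_x_dim, batch_theta_dim, x_dim, theta_dim = 3, 5, 2, 4

xs = torch.randn(1, batch_x_dim, x_dim)              # (1, batch_x_dim, x_dim)
thetas = torch.randn(batch_theta_dim, 1, theta_dim)  # (batch_theta_dim, 1, theta_dim)

# With standard PyTorch broadcasting, the leading dims expand pairwise, so a
# potential evaluated on these shapes would cover all (theta, x) combinations:
print(torch.broadcast_shapes(xs.shape[:-1], thetas.shape[:-1]))  # torch.Size([5, 3])

# After squeezing the batch dims into a single leading dimension, both inputs
# are plain (N, event_dim) tensors and the pairing information is gone:
xs_flat = xs.expand(batch_theta_dim, -1, -1).reshape(-1, x_dim)          # (15, 2)
thetas_flat = thetas.expand(-1, batch_x_dim, -1).reshape(-1, theta_dim)  # (15, 4)
print(xs_flat.shape, thetas_flat.shape)
```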
Great effort, thanks a lot for tackling this 👏
I do have a couple of comments and questions. Happy to discuss in person if needed.
|
||
x_ = x.repeat_interleave(num_chains, dim=0) | ||
|
||
self.potential_fn.set_x(x_, interpret_as_iid=False) |
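As a standalone illustration of what the first of these two lines does (shapes here are hypothetical):

```python
import torch

num_chains = 4
x = torch.arange(6.0).reshape(2, 3)  # a batch of 2 observations, x_dim = 3

# Each observation is repeated num_chains times along dim 0, so every chain
# belonging to a given observation conditions on the same x:
x_ = x.repeat_interleave(num_chains, dim=0)
print(x_.shape)  # torch.Size([8, 3])
print(x_[:4])    # the first 4 rows are all copies of x[0]
```

The subsequent `set_x(x_, interpret_as_iid=False)` then tells the potential to treat these rows as separate conditioning values rather than iid trials.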
I don't understand why `interpret_as_iid=False` is hardcoded here, but maybe it will become clear below.
Makes sense now. I am just wondering: what if `x` was already set before using `set_default_x`, and it was set with iid samples? Maybe we should add a warning then? Effectively, `sample_batched` then overwrites the default `x` by default. It should all be clear from the API of course, especially because one has to pass a new `x` here, but for users it might not be clear that they cannot mix iid and batched evaluation. What do you think?
I like adding a warning here. We can check if `self._potential_fn` already has `x_is_iid` set as True, and then raise a warning in this case that the user is mixing iid and batched evaluation.
In this case, we should also raise a warning for `MCMCPosterior.sample()` if `x_is_iid` was previously set as False.
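A minimal sketch of the kind of check discussed here; the helper name is hypothetical and the `x_is_iid` attribute is taken from the discussion above, so the real implementation may differ:

```python
from warnings import warn


def _warn_if_mixing_iid_and_batched(potential_fn, now_iid: bool) -> None:
    # Hypothetical helper: compare how the default x was set against how the
    # current call interprets its batch dimension.
    previously_iid = getattr(potential_fn, "x_is_iid", None)
    if previously_iid is None:
        return  # no default x was set, nothing to warn about
    if previously_iid and not now_iid:
        warn(
            "The default `x` was set as iid, but `sample_batched` interprets the "
            "batch dimension as independent observations; do not mix iid and "
            "batched evaluation.",
            stacklevel=2,
        )
    elif not previously_iid and now_iid:
        warn(
            "The default `x` was set as non-iid (batched), but `sample` "
            "interprets the batch dimension as iid trials.",
            stacklevel=2,
        )
```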
Thanks for the review! I implemented your suggestions. An additional point - For
Looks great! I added just a couple of last questions.
if not x_o_is_iid:
    warn(
        "The default `x_o` has `x_is_iid = False`, but you are now sampling "
        "with a batch `x` with `x_is_iid = True`. If you want to sample non-iid "
        "`x`, please reset `x_is_iid = False` in the potential function.",
        stacklevel=2,
    )
When does this happen? Does the user have to explicitly use the `posterior.potential_fn` object to make it happen?
If not, i.e., if it can happen to a user who constructed the posterior using `build_posterior` (maybe without knowing anything about potentials), then we should add more details to this warning, e.g., "by setting `posterior.potential_fn.x_is_iid=False`".
Or am I missing something here?
Now that we support both `sample_batched` and `sample` when a batch of `x` is passed, we want to warn the user to use the correct one. That is, if `x` is batched iid, the user should use `sample`, and if `x` is batched and NOT iid, the user should use `sample_batched`. I think instead of warning the user to change the definition of `x` in the potential, it's sufficient to just warn them to use the correct one of `sample` and `sample_batched`. But this is already covered in `sbi.utils.sbiutils`, where we already raise a warning if the batch dimension of `x` is greater than 1, telling the user to choose `sample` or `sample_batched` so that `x` is interpreted correctly. So maybe here we remove the warning altogether?
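For illustration, a self-contained sketch of the kind of check described; this mimics the idea, not the actual `sbi.utils.sbiutils` code:

```python
from warnings import warn

import torch


def warn_on_batched_x(x: torch.Tensor) -> None:
    # If x carries a batch dimension > 1, the user must pick the right method:
    # `sample` treats the batch as iid trials of one observation, while
    # `sample_batched` treats it as independent observations.
    if x.ndim > 1 and x.shape[0] > 1:
        warn(
            "The passed `x` has a batch dimension larger than 1. Use `sample` "
            "for iid trials and `sample_batched` for independent observations.",
            stacklevel=2,
        )


warn_on_batched_x(torch.randn(10, 3))  # triggers the warning
```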
@@ -321,6 +334,7 @@ def sample(
    thin=thin,  # type: ignore
    warmup_steps=warmup_steps,  # type: ignore
    vectorized=(method == "slice_np_vectorized"),
    interchangeable_chains=True,
This is just true for `slice_np` sampling? Or should we expose it via the `kwargs`?
`interchangeable_chains` corresponds to whether the batch of `x` provided to the sampler is iid or not. Right now, we only support `sample_batched` with `slice_np_vectorized`. In the future, we might want to support doing this with one of the other MCMC samplers we have available, in which case we can expose `interchangeable_chains` in the `kwargs` - although I would suggest this can be a future PR?
theta = ensure_theta_batched(torch.as_tensor(theta)).to(self.device)
Why is this not needed anymore? Especially the `.to(self.device)`.
You're right, it's still needed. I will add this back.
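For context, a small sketch of what the removed line accomplishes; `ensure_theta_batched_sketch` below only mimics sbi's `ensure_theta_batched` for illustration:

```python
import torch


def ensure_theta_batched_sketch(theta: torch.Tensor) -> torch.Tensor:
    # A single parameter vector gets a leading batch dimension of 1, so the
    # potential can always assume shape (batch, theta_dim).
    if theta.ndim == 1:
        theta = theta.unsqueeze(0)
    return theta


theta = torch.tensor([0.1, 0.2, 0.3])                        # shape (3,)
theta = ensure_theta_batched_sketch(torch.as_tensor(theta))  # shape (1, 3)
theta = theta.to("cpu")  # stand-in for `.to(self.device)` in the original line
print(theta.shape)
```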
What does this implement/fix? Explain your changes
This pull request aims to implement the `sample_batched` method for MCMC.

Current problem
- `BasePotential` can either "allow_iid" or not. Hence, each batch dimension will be interpreted as IID samples.
- Suggested change: replace `allow_iid` with a mutable attribute (or optional input argument) `interpret_as_iid`.
- The current implementation will let you sample the correct shape, BUT will output the wrong solution. This is because the potential function will broadcast, repeat, and finally sum up the first dimension, which is incorrect.