Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: batched sampling for MCMC #1176

Merged
merged 81 commits into from
Jul 30, 2024
Merged
Changes from 1 commit
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
17c5343
Base estimator class
manuelgloeckler Apr 29, 2024
705e9df
intermediate commit
michaeldeistler May 3, 2024
07b53cd
make autoreload work
michaeldeistler May 3, 2024
dd02e22
`amortized_sample` works for MCMCPosterior
michaeldeistler May 5, 2024
663185b
fixes current bug!
manuelgloeckler May 7, 2024
df8899a
Added tests
manuelgloeckler May 7, 2024
aa82aab
batched_rejection_sampling
manuelgloeckler May 7, 2024
00cdade
intermediate commit
michaeldeistler May 3, 2024
cb8e4d8
make autoreload work
michaeldeistler May 3, 2024
d64557f
`amortized_sample` works for MCMCPosterior
michaeldeistler May 5, 2024
f16622d
Merge branch 'amortizedsample' of https://github.com/sbi-dev/sbi into…
manuelgloeckler May 7, 2024
07084e2
Merge branch '990-add-sample_batched-and-log_prob_batched-to-posterio…
manuelgloeckler May 7, 2024
e54a2fb
Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-…
manuelgloeckler May 7, 2024
52d0e7e
Merge branch '1154-density-estimator-batched-sample-mixes-up-samples-…
manuelgloeckler May 7, 2024
cd808d5
sample works, try log_prob_batched
manuelgloeckler May 7, 2024
f542224
log_prob_batched works
manuelgloeckler May 7, 2024
48a1a28
abstract method implement for other methods
manuelgloeckler May 7, 2024
5a37330
temp fix mcmcposterior
manuelgloeckler May 7, 2024
2b23e42
meh for general use i.e. in the restriction prior we have to add some…
manuelgloeckler May 7, 2024
6362051
... test class
manuelgloeckler May 7, 2024
294609d
Revert "Base estimator class"
manuelgloeckler May 8, 2024
99abbb1
removing previous change
manuelgloeckler May 8, 2024
ef9e99c
removing some artifacts
manuelgloeckler May 8, 2024
5eb1007
revert wierd change
manuelgloeckler May 8, 2024
82127ab
docs and tests
manuelgloeckler May 8, 2024
41617a8
MCMC sample_batched works but not log_prob batched
manuelgloeckler May 14, 2024
82951db
adding some docs
manuelgloeckler May 14, 2024
c5fac1d
batch_log_prob for MCMC requires at best changes for potential -> rem…
manuelgloeckler May 14, 2024
0d82422
intermediate commit
michaeldeistler May 3, 2024
57cfde3
make autoreload work
michaeldeistler May 3, 2024
de5d647
`amortized_sample` works for MCMCPosterior
michaeldeistler May 5, 2024
f8b6604
intermediate commit
michaeldeistler May 3, 2024
1dcf882
make autoreload work
michaeldeistler May 3, 2024
5a31970
`amortized_sample` works for MCMCPosterior
michaeldeistler May 5, 2024
871c4de
Base estimator class
manuelgloeckler Apr 29, 2024
f87d6b6
Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-…
manuelgloeckler May 7, 2024
dbd0109
fixes current bug!
manuelgloeckler May 7, 2024
264b6c4
Added tests
manuelgloeckler May 7, 2024
339b57b
batched_rejection_sampling
manuelgloeckler May 7, 2024
676c271
sample works, try log_prob_batched
manuelgloeckler May 7, 2024
7a8a84d
log_prob_batched works
manuelgloeckler May 7, 2024
5daab92
abstract method implement for other methods
manuelgloeckler May 7, 2024
40897a0
temp fix mcmcposterior
manuelgloeckler May 7, 2024
a2b7e32
meh for general use i.e. in the restriction prior we have to add some…
manuelgloeckler May 7, 2024
cb4d8ae
... test class
manuelgloeckler May 7, 2024
ab9b1e1
Revert "Base estimator class"
manuelgloeckler May 8, 2024
d2b1a62
removing previous change
manuelgloeckler May 8, 2024
a0c0c97
removing some artifacts
manuelgloeckler May 8, 2024
8fc5a46
revert wierd change
manuelgloeckler May 8, 2024
18c7d36
docs and tests
manuelgloeckler May 8, 2024
6ad6cb7
MCMC sample_batched works but not log_prob batched
manuelgloeckler May 14, 2024
03c10f3
adding some docs
manuelgloeckler May 14, 2024
24c4821
batch_log_prob for MCMC requires at best changes for potential -> rem…
manuelgloeckler May 14, 2024
1769d6e
Merge branch 'amortizedsample' of https://github.com/sbi-dev/sbi into…
manuelgloeckler Jun 11, 2024
a445a6c
Fixing bug from rebase...
manuelgloeckler Jun 11, 2024
86767a1
tracking all acceptance rates
manuelgloeckler Jun 11, 2024
9502af3
Comment on NFlows
manuelgloeckler Jun 11, 2024
c80e6ff
Also testing SNRE batched sampling, Need to test ensemble implementation
manuelgloeckler Jun 11, 2024
7aac84c
fig bug
manuelgloeckler Jun 11, 2024
7d4eb55
Ensemble sample_batched is working (with tests)
manuelgloeckler Jun 11, 2024
f53e1ec
GPU compatibility
manuelgloeckler Jun 11, 2024
2dc6ebd
restriction priopr requires float as output of accept_reject
manuelgloeckler Jun 11, 2024
7dfda13
Adding a few comments
manuelgloeckler Jun 11, 2024
89b6e8f
2d sample_shape tests
manuelgloeckler Jun 11, 2024
35dcf40
Merge branch 'main' into amortizedsample
janfb Jun 13, 2024
93ca374
Apply suggestions from code review
manuelgloeckler Jun 14, 2024
86f3531
Adding comment about squeeze
manuelgloeckler Jun 14, 2024
c55e6e4
Formating new mcmc branch
manuelgloeckler Jun 18, 2024
c18958a
mcmc sample batched for likelihood estimator
gmoss13 Jun 25, 2024
9ff2ce8
batch sampling for snpe,snre
gmoss13 Jun 27, 2024
05da5e3
Merge branch 'main' into amortized_sample_mcmc
gmoss13 Jun 27, 2024
f759e23
ruff fixes after merge
gmoss13 Jun 27, 2024
94732aa
pytest not catching xfail
gmoss13 Jun 27, 2024
69f459e
mcmc_posterior sample_batched disappeared in merge
gmoss13 Jun 27, 2024
ce24632
move mcmc chain shape handling to mcmcposterior away from potentials
gmoss13 Jul 11, 2024
25f7e2c
batched init strategies for mcmc
gmoss13 Jul 12, 2024
f98bf4d
Merge branch 'main' into amortized_sample_mcmc
gmoss13 Jul 15, 2024
4524853
update raio_based_potential for new RatioEstimator class
gmoss13 Jul 15, 2024
2c7fc0e
mcmc sample shape out fix and process_x utils
gmoss13 Jul 15, 2024
fd11a72
suggestions from jan
gmoss13 Jul 19, 2024
813ee75
warning on batched x
gmoss13 Jul 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adding comment about squeeze
  • Loading branch information
manuelgloeckler committed Jun 14, 2024
commit 86f3531d2ee584d45afb0f099ce311e82a79678a
9 changes: 4 additions & 5 deletions sbi/inference/posteriors/mcmc_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,9 @@ def sample(
raise NameError(f"The sampling method {method} is not implemented!")

samples = self.theta_transform.inv(transformed_samples)
# NOTE: Currently MCMCPosteriors will require a single dimension for the
# parameter dimension. With recent ConditionalDensity(Ratio) estimators, we
# can have multiple dimensions for the parameter dimension.
samples = samples.reshape((*sample_shape, -1)) # type: ignore

return samples
Expand Down Expand Up @@ -591,11 +594,7 @@ def _slice_np_mcmc(

def multi_obs_potential(params):
# Params are of shape (num_chains * num_obs, event).
# We now reshape them to (num_chains, num_obs, event).
# params = np.reshape(params, (num_chains, num_obs, -1))

# `all_potentials` is of shape (num_chains, num_obs).
all_potentials = potential_function(params)
all_potentials = potential_function(params) # Shape: (num_chains, num_obs)
return all_potentials.flatten()

posterior_sampler = SliceSamplerMultiChain(
Expand Down
Loading