Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Encoder/decoder integration #667

Merged
merged 4 commits into from
Aug 15, 2022
Merged

Conversation

yashpatel5400
Copy link
Contributor

(Same as the previous PR, but cleaned up and now with all the tests passing)

@codecov
Copy link

codecov bot commented Aug 8, 2022

Codecov Report

Merging #667 (231e3c4) into master (434cb60) will increase coverage by 8.10%.
The diff coverage is 91.95%.

@@            Coverage Diff             @@
##           master     #667      +/-   ##
==========================================
+ Coverage   79.09%   87.19%   +8.10%     
==========================================
  Files          24       25       +1     
  Lines        3506     3695     +189     
==========================================
+ Hits         2773     3222     +449     
+ Misses        733      473     -260     
Flag Coverage Δ
unittests 87.19% <91.95%> (+8.10%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
bliss/catalog.py 84.12% <ø> (ø)
bliss/encoder.py 84.54% <37.50%> (-13.16%) ⬇️
bliss/models/galsim_decoder.py 89.34% <94.87%> (+1.63%) ⬆️
bliss/models/psf_decoder.py 95.40% <95.40%> (ø)
bliss/models/decoder.py 97.87% <100.00%> (+0.81%) ⬆️
bliss/models/galsim_encoder.py 21.59% <100.00%> (+21.59%) ⬆️
bliss/models/prior.py 90.60% <100.00%> (+3.21%) ⬆️
bliss/models/lensing_binary_encoder.py 76.37% <0.00%> (+76.37%) ⬆️
bliss/models/lens_encoder.py 81.00% <0.00%> (+81.00%) ⬆️

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

bliss/encoder.py Outdated
else:
lensed_galaxy_bools = (
torch.rand_like(lensed_galaxy_probs) > 0.5
).float() * is_on_array.unsqueeze(-1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Black auto formatting isn't so great for long lines. It's often clearer to create a temporary/intermediate variable to reduce the length of long lines

Copy link
Collaborator

@ismael-mendoza ismael-mendoza left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @yashpatel5400 the code looks really good. My biggest suggestion is on the if statements scattered in the decoder related to the lens_params which I think will make the code harder to read/maintain in the long run. In some cases they can also prevent code optimization.

Do you think there is a way we can get rid of them (I was imagining a separate LensTileDecoder but maybe this is unfeasible)? Let's discuss over slack or in-person if you are around west hall when you have a chance.

The other comments are pretty minor and related to black formatting mostly

bliss/encoder.py Outdated
Comment on lines 196 to 199
if deterministic:
lensed_galaxy_bools = (
lensed_galaxy_probs > 0.5
).float() * is_on_array.unsqueeze(-1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better to split these lines into multiple statements (one per line)

bliss/encoder.py Outdated
Comment on lines 201 to 202
lensed_galaxy_bools = (
torch.rand_like(lensed_galaxy_probs) > 0.5
).float() * is_on_array.unsqueeze(-1)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

bliss/encoder.py Outdated
Comment on lines 205 to 207
lensed_galaxy_bools *= (
galaxy_bools # currently only support lensing where galaxy is present
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to put the comment in a separate line to avoid black formatting

bliss/encoder.py Outdated
Comment on lines 209 to 214
tile_samples.update(
{
"lensed_galaxy_bools": lensed_galaxy_bools,
"lensed_galaxy_probs": lensed_galaxy_probs,
}
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you might consider saving some lines by doing e.g. tile_samples['lensed_galaxy_bools'] = lensed_galaxy_bools instead

n_bands: int = 1,
psf_image_file: Optional[str] = None,
psf_params_file: Optional[str] = None,
psf_slen: Optional[int] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this really optional?

Comment on lines 180 to 185
catalog_params.update(
{
"lensed_galaxy_bools": lensed_galaxy_bools,
"lens_params": lens_params,
}
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save some lines by catalog_params['lens_params'] = lens_params

Comment on lines 156 to 158
galaxy_bools, star_bools, lensed_galaxy_bools = self._sample_n_galaxies_and_stars(
is_on_array
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would make a new function _sample_n_lens or something like that which takes is_on_array and galaxy_bool and returns lensed_galaxy_bools

@@ -190,7 +190,7 @@ def encode(self, image_ptiles: Tensor, tile_locs: Tensor) -> Tensor:
ms=max_sources,
)

def sample(self, image_ptiles: Tensor, tile_locs: Tensor):
def sample(self, image_ptiles: Tensor, tile_locs: Tensor, deterministic: Optional[bool]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did you add this deterministic flag?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was just to match the other GalaxyEncoder (since we support both in encoder.py)

@@ -195,6 +195,15 @@ def _render_ptiles(self, tile_catalog: TileCatalog) -> Tensor:
tile_catalog["star_fluxes"], "b nth ntw s band -> (b nth ntw) s band"
)

if "lens_params" in tile_catalog:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should talk about avoiding this if statements scattered throughout the code, in general tihs makes the code harder to read and maintain. I was imagining separate lens_decoder just like we have for the stars would be cleaner to maintain. Ofc it would have to be conditioned on the galaxy_bools or something. Could you tell me if this is possible ?

@@ -580,7 +503,15 @@ def _render_single_galaxies(self, galaxy_params, galaxy_bools):

# forward only galaxies that are on!
# no background
gal_on = self.galaxy_decoder(z[b == 1])
if lens_params is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another if statement on lens_params, we should talk about whether we can remove this...

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@yashpatel5400
Copy link
Contributor Author

alright, think I've addressed all the comments from above!

@jeff-regier jeff-regier self-requested a review August 15, 2022 01:25
Copy link
Contributor

@jeff-regier jeff-regier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great!

# currently only support lensing where galaxy is present
lensed_galaxy_bools = uniform < self.prob_lensed_galaxy
lensed_galaxy_bools = (
lensed_galaxy_bools * galaxy_bools * is_on_array.unsqueeze(-1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use *= here so it all fits on one line?

@jeff-regier jeff-regier merged commit 832893c into prob-ml:master Aug 15, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants