-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Encoder/decoder integration #667
Conversation
Codecov Report
@@ Coverage Diff @@
## master #667 +/- ##
==========================================
+ Coverage 79.09% 87.19% +8.10%
==========================================
Files 24 25 +1
Lines 3506 3695 +189
==========================================
+ Hits 2773 3222 +449
+ Misses 733 473 -260
Flags with carried forward coverage won't be shown. Click here to find out more.
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
37daece
to
6e1a14a
Compare
bliss/encoder.py
Outdated
else: | ||
lensed_galaxy_bools = ( | ||
torch.rand_like(lensed_galaxy_probs) > 0.5 | ||
).float() * is_on_array.unsqueeze(-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Black auto formatting isn't so great for long lines. It's often clearer to create a temporary/intermediate variable to reduce the length of long lines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @yashpatel5400 the code looks really good. My biggest suggestion is on the if
statements scattered in the decoder
related to the lens_params
which I think will make the code harder to read/maintain in the long run. In some cases they can also prevent code optimization.
Do you think there is a way we can get rid of them (I was imagining a separate LensTileDecoder
but maybe this is unfeasible)? Let's discuss over slack or in-person if you are around west hall when you have a chance.
The other comments are pretty minor and related to black formatting mostly
bliss/encoder.py
Outdated
if deterministic: | ||
lensed_galaxy_bools = ( | ||
lensed_galaxy_probs > 0.5 | ||
).float() * is_on_array.unsqueeze(-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
better to split these lines into multiple statements (one per line)
bliss/encoder.py
Outdated
lensed_galaxy_bools = ( | ||
torch.rand_like(lensed_galaxy_probs) > 0.5 | ||
).float() * is_on_array.unsqueeze(-1) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here
bliss/encoder.py
Outdated
lensed_galaxy_bools *= ( | ||
galaxy_bools # currently only support lensing where galaxy is present | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better to put the comment in a separate line to avoid black formatting
bliss/encoder.py
Outdated
tile_samples.update( | ||
{ | ||
"lensed_galaxy_bools": lensed_galaxy_bools, | ||
"lensed_galaxy_probs": lensed_galaxy_probs, | ||
} | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you might consider saving some lines by doing e.g. tile_samples['lensed_galaxy_bools'] = lensed_galaxy_bools
instead
n_bands: int = 1, | ||
psf_image_file: Optional[str] = None, | ||
psf_params_file: Optional[str] = None, | ||
psf_slen: Optional[int] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this really optional?
bliss/models/prior.py
Outdated
catalog_params.update( | ||
{ | ||
"lensed_galaxy_bools": lensed_galaxy_bools, | ||
"lens_params": lens_params, | ||
} | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
save some lines by catalog_params['lens_params'] = lens_params
bliss/models/prior.py
Outdated
galaxy_bools, star_bools, lensed_galaxy_bools = self._sample_n_galaxies_and_stars( | ||
is_on_array | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would make a new function _sample_n_lens
or something like that which takes is_on_array
and galaxy_bool
and returns lensed_galaxy_bools
@@ -190,7 +190,7 @@ def encode(self, image_ptiles: Tensor, tile_locs: Tensor) -> Tensor: | |||
ms=max_sources, | |||
) | |||
|
|||
def sample(self, image_ptiles: Tensor, tile_locs: Tensor): | |||
def sample(self, image_ptiles: Tensor, tile_locs: Tensor, deterministic: Optional[bool]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why did you add this deterministic
flag?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was just to match the other GalaxyEncoder (since we support both in encoder.py)
bliss/models/decoder.py
Outdated
@@ -195,6 +195,15 @@ def _render_ptiles(self, tile_catalog: TileCatalog) -> Tensor: | |||
tile_catalog["star_fluxes"], "b nth ntw s band -> (b nth ntw) s band" | |||
) | |||
|
|||
if "lens_params" in tile_catalog: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should talk about avoiding this if
statements scattered throughout the code, in general tihs makes the code harder to read and maintain. I was imagining separate lens_decoder
just like we have for the stars would be cleaner to maintain. Ofc it would have to be conditioned on the galaxy_bools or something. Could you tell me if this is possible ?
bliss/models/decoder.py
Outdated
@@ -580,7 +503,15 @@ def _render_single_galaxies(self, galaxy_params, galaxy_bools): | |||
|
|||
# forward only galaxies that are on! | |||
# no background | |||
gal_on = self.galaxy_decoder(z[b == 1]) | |||
if lens_params is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
another if
statement on lens_params
, we should talk about whether we can remove this...
6e1a14a
to
fc8cf79
Compare
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
ea6ad4d
to
231e3c4
Compare
alright, think I've addressed all the comments from above! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great!
bliss/models/prior.py
Outdated
# currently only support lensing where galaxy is present | ||
lensed_galaxy_bools = uniform < self.prob_lensed_galaxy | ||
lensed_galaxy_bools = ( | ||
lensed_galaxy_bools * galaxy_bools * is_on_array.unsqueeze(-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use *=
here so it all fits on one line?
(Same as the previous PR, but cleaned up and now with all the tests passing)