Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Pad model outputs if they are smaller than the inputs (#681)
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Mar 9, 2022
1 parent 5336a67 commit 8a78ec8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ gets uploaded to AzureML, by skipping all test folders.

### Fixed

- ([#681](https://github.com/microsoft/InnerEye-DeepLearning/pull/681)) Pad model outputs if they are smaller than the inputs.
- ([#683](https://github.com/microsoft/InnerEye-DeepLearning/pull/683)) Fix missing separator error in docs Makefile.
- ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Fix caching and checkpointing for TCGA CRCk dataset.
- ([#649](https://github.com/microsoft/InnerEye-DeepLearning/pull/649)) Fix for the _convert_to_tensor_if_necessary method so that PIL.Image as well as np.array get converted to torch.Tensor.
Expand Down
8 changes: 8 additions & 0 deletions InnerEye/ML/pipelines/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,14 @@ def predict_whole_image(self, image_channels: np.ndarray,
locations = patches_batch[tio.LOCATION]
# perform the forward pass
patches_posteriors = self.model(input_tensor).detach()
# pad posteriors if they are smaller than the input
input_shape = input_tensor.shape[-3:]
patches_posteriors_shape = patches_posteriors.shape[-3:]
if input_shape != patches_posteriors_shape:
difference = np.array(input_shape) - np.array(patches_posteriors_shape)
assert not np.any(difference % 2) # the differences in shape are expected to be even
padding = tuple(np.repeat(difference // 2, 2))
patches_posteriors = torch.nn.functional.pad(patches_posteriors, padding)
# collect the predictions over each of the batches
aggregator.add_batch(patches_posteriors, locations)
posteriors = aggregator.get_output_tensor().numpy()
Expand Down

0 comments on commit 8a78ec8

Please sign in to comment.