This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Replace RadIO with TorchIO for patch-based inference (#666)
* Replace RadIO with TorchIO

* Ensure patches are float32 for forward pass

* Update changelog

* Ignore some types to fix mypy errors

* Remove APEX from conda environment in docs example

Co-authored-by: Javier <[email protected]>
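For reference, here is a minimal, self-contained sketch of the TorchIO grid-sampling pattern this commit adopts (the toy `Conv3d` model, shapes, and patch settings are illustrative, not InnerEye's):

```python
import torch
import torchio as tio

# Toy stand-in for a 3D segmentation network: 1 input channel, 2 classes.
model = torch.nn.Conv3d(in_channels=1, out_channels=2, kernel_size=1)
model.eval()

# Wrap a [channels, Z, Y, X] volume in a TorchIO Subject.
subject = tio.Subject(image=tio.ScalarImage(tensor=torch.rand(1, 64, 64, 64)))

# Sample overlapping patches on a regular grid, and prepare an aggregator
# that stitches patch predictions back into a whole-image volume.
grid_sampler = tio.inference.GridSampler(subject, patch_size=32, patch_overlap=8)
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4)
aggregator = tio.inference.GridAggregator(grid_sampler)

with torch.no_grad():
    for patches_batch in patch_loader:
        input_tensor = patches_batch['image'][tio.DATA].float()  # ensure float32
        locations = patches_batch[tio.LOCATION]
        aggregator.add_batch(model(input_tensor), locations)

posteriors = aggregator.get_output_tensor()  # stitched volume, here [2, 64, 64, 64]
```

The same `GridSampler`/`GridAggregator` pair replaces RadIO's patch extraction and stitching in `InnerEye/ML/pipelines/inference.py` below.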
fepegar and javier-alvarez committed Feb 23, 2022
1 parent e2ec5cc commit d7e5d8b
Showing 4 changed files with 61 additions and 257 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -13,6 +13,8 @@ created.
## Upcoming

### Added

- ([#666](https://github.com/microsoft/InnerEye-DeepLearning/pull/666)) Replace RadIO with TorchIO for patch-based inference.
- ([#643](https://github.com/microsoft/InnerEye-DeepLearning/pull/643)) Test for recovery of SSL job. Tracks learning rate and train
loss.
- ([#594](https://github.com/microsoft/InnerEye-DeepLearning/pull/594)) When supplying a "--tag" argument, the AzureML jobs use that value as the display name, to more easily distinguish runs.
@@ -125,7 +127,7 @@ in inference-only runs when using lightning containers.
- ([#638](https://github.com/microsoft/InnerEye-DeepLearning/pull/638)) SimClr cosine LR scheduler was using wrong length information when used with long linear head datasets
- ([#612](https://github.com/microsoft/InnerEye-DeepLearning/pull/612)) SSL online evaluator was not doing distributed training
- ([#652](https://github.com/microsoft/InnerEye-DeepLearning/pull/652)) Run pytest build on Windows after Linux agent version upgrade
- ([#655](https://github.com/microsoft/InnerEye-DeepLearning/pull/655)) Run pytest on Linux again, but with Ubuntu 20.04

### Removed

308 changes: 56 additions & 252 deletions InnerEye/ML/pipelines/inference.py
@@ -7,21 +7,18 @@
import logging
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Optional, Tuple

import numpy as np
import torch
from radio import CTImagesMaskedBatch
from radio.batchflow import Dataset, action, inbatch_parallel
import torchio as tio

from InnerEye.Common.type_annotations import TupleFloat3
from InnerEye.ML import config
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.lightning_helpers import load_from_checkpoint_and_adjust_for_inference
from InnerEye.ML.lightning_models import SegmentationLightning
from InnerEye.ML.model_config_base import ModelConfigBase
from InnerEye.ML.models.architectures.base_model import BaseSegmentationModel
from InnerEye.ML.utils import image_util, ml_util
from InnerEye.ML.utils.image_util import compute_uncertainty_map_from_posteriors, gaussian_smooth_posteriors, \
posteriors_to_segmentation
@@ -218,6 +215,24 @@ def create_from_checkpoint(path_to_checkpoint: Path,
assert isinstance(lightning_model, SegmentationLightning)
return InferencePipeline(model=lightning_model, model_config=model_config, pipeline_id=pipeline_id)

def post_process_posteriors(self, posteriors: np.ndarray, mask: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
"""
Perform post processing on the computed outputs of a single pass of the pipeline.
Currently the following operations are performed:
-------------------------------------------------------------------------------------
1) the mask is applied to the posteriors (if required).
2) the final posteriors are used to perform an argmax to generate a multi-label segmentation.
3) extract the largest foreground connected component in the segmentation if required
"""
if mask is not None:
posteriors = image_util.apply_mask_to_posteriors(posteriors=posteriors, mask=mask)

# create segmentation using an argmax over the posterior probabilities
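# (conceptually: for posteriors of shape [classes, Z, Y, X], segmentation = np.argmax(posteriors, axis=0))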
segmentation = image_util.posteriors_to_segmentation(posteriors)

return posteriors, segmentation

@torch.no_grad()
def predict_whole_image(self, image_channels: np.ndarray,
voxel_spacing_mm: TupleFloat3,
mask: np.ndarray = None,
@@ -238,259 +253,48 @@ def predict_whole_image(self, image_channels: np.ndarray,
if mask is not None:
ml_util.check_size_matches(image_channels, mask, 4, 3, [-1, -2, -3])
self.model.eval()
# create the dataset for the batch
batch_dataset = Dataset(index=[patient_id], batch_class=InferenceBatch)
# setup the pipeline
pipeline = (batch_dataset.p
# define pipeline variables
.init_variables([InferencePipeline.Variables.Model,
InferencePipeline.Variables.ModelConfig,
InferencePipeline.Variables.CropSize,
InferencePipeline.Variables.OutputSize,
InferencePipeline.Variables.OutputImageShape,
InferencePipeline.Variables.Stride])
# update the variables for the batch actions
.update_variable(name=InferencePipeline.Variables.Model, value=self.model)
.update_variable(name=InferencePipeline.Variables.ModelConfig, value=self.model_config)
# perform cascaded batch actions
.load(image_channels=image_channels, mask=mask)
.pre_process()
.predict()
.post_process()
)
# run the batch through the pipeline
logging.info(f"Inference pipeline ({self.pipeline_id}), Predicting patient: {patient_id}")
processed_batch: InferenceBatch = pipeline.next_batch(batch_size=1)
posteriors = processed_batch.get_component(InferenceBatch.Components.Posteriors)
image_util.check_array_range(posteriors, error_prefix="Whole image posteriors")
# prepare pipeline results from the processed batch
return InferencePipeline.Result(
patient_id=patient_id,
segmentation=processed_batch.get_component(InferenceBatch.Components.Segmentation),
posteriors=posteriors,
voxel_spacing_mm=voxel_spacing_mm
)


class InferenceBatch(CTImagesMaskedBatch):
"""
Batch class for IO with the inference pipeline. One instance of a batch will load the image
into the 'images' component of the pipeline, and store the results of the full pass
of the pipeline into the 'segmentation' and 'posteriors' components.
"""

class Components(Enum):
"""
Components associated with the inference batch class
"""
image = tio.ScalarImage(tensor=image_channels)
subject = tio.Subject(image=image)

# the input image channels in Channels x Z x Y x X format.
ImageChannels = 'channels'
# a set of 2D image slices (ie: a 3D image channel), stacked in Z x Y x X format.
Images = 'images'
# a binary mask used to ignore predictions in Z x Y x X format.
Mask = 'mask'
# a numpy.ndarray in Z x Y x X format with class labels for each voxel in the original image.
Segmentation = 'segmentation'
# a numpy.ndarray with the first dimension indexing each class in C x Z x Y x X format
# with each Z x Y x X being the same shape as the Images component, and consisting of
# [0, 1] values representing the model confidence for each voxel.
Posteriors = 'posteriors'

def __init__(self, index: int, *args: Any, **kwargs: Any):
super().__init__(index, *args, **kwargs)
self.components = [x.value for x in InferenceBatch.Components]

@action
def load(self, image_channels: np.ndarray, mask: np.ndarray) -> InferenceBatch:
"""
Load image channels and mask into their respective pipeline components.
"""
self.set_component(component=InferenceBatch.Components.ImageChannels, data=image_channels)
model_config = self.get_configs()
if model_config is None:
raise ValueError("model_config is None")
if model_config.test_crop_size is None:
raise ValueError("model_config.test_crop_size is None")
if model_config.inference_stride_size is None:
raise ValueError("model_config.inference_stride_size is None")

# fetch the image channels from the batch
image_channels = self.get_component(InferenceBatch.Components.ImageChannels)
self.pipeline.set_variable(name=InferencePipeline.Variables.OutputImageShape, value=image_channels[0].shape)
# There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size
# to always fit into image. If test_crop_size is smaller than the image, crop will remain unchanged.
image_size = image_channels.shape[1:]
model: BaseSegmentationModel = self.pipeline.get_variable(InferencePipeline.Variables.Model).model
effective_crop, effective_stride = \
model.crop_size_constraints.restrict_crop_size_to_image(image_size,
model_config.test_crop_size,
model_config.inference_stride_size)
self.pipeline.set_variable(name=InferencePipeline.Variables.CropSize, value=effective_crop)
self.pipeline.set_variable(name=InferencePipeline.Variables.Stride, value=effective_stride)
logging.debug(
f"Inference on image size {image_size} will run "
f"with crop size {effective_crop} and stride {effective_stride}")
# In most cases, we will be able to read the output size from the pre-computed values
# via get_output_size. Only if we have a non-standard (smaller) crop size do we re-compute the output size.
output_size = model_config.get_output_size(execution_mode=ModelExecutionMode.TEST)
if effective_crop != model_config.test_crop_size:
output_size = model.get_output_shape(input_shape=effective_crop) # type: ignore
self.pipeline.set_variable(name=InferencePipeline.Variables.OutputSize, value=output_size)

if mask is not None:
self.set_component(component=InferenceBatch.Components.Mask, data=mask)

return self

@action
def pre_process(self) -> InferenceBatch:
"""
Prepare the input components of the batch for further processing.
"""
model_config = self.get_configs()

# fetch the image channels from the batch
image_channels = self.get_component(InferenceBatch.Components.ImageChannels)

crop_size = self.pipeline.get_variable(InferencePipeline.Variables.CropSize)
output_size = self.pipeline.get_variable(InferencePipeline.Variables.OutputSize)
image_channels = image_util.pad_images_for_inference(
images=image_channels,
crop_size=crop_size,
output_size=output_size,
padding_mode=model_config.padding_mode
restrict_patch_size = self.model.model.crop_size_constraints.restrict_crop_size_to_image # type: ignore
effective_patch_size, effective_stride = restrict_patch_size(image.spatial_shape, # type: ignore
self.model_config.test_crop_size,
self.model_config.inference_stride_size)

patch_overlap = np.array(effective_patch_size) - np.array(effective_stride)
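# (TorchIO's GridSampler takes the overlap between adjacent patches rather than
# the stride: e.g. a patch size of 64 with a stride of 48 is an overlap of 16.)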
grid_sampler = tio.inference.GridSampler(
subject,
effective_patch_size,
patch_overlap,
padding_mode=self.model_config.padding_mode.value,
)
batch_size = self.model_config.inference_batch_size
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=batch_size) # type: ignore
aggregator = tio.inference.GridAggregator(grid_sampler)
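# The aggregator stitches each batch of patch predictions back into a
# whole-image volume at the locations recorded by the grid sampler.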

# update the post-processed components
self.set_component(component=InferenceBatch.Components.ImageChannels, data=image_channels)

return self

@action
def predict(self) -> InferenceBatch:
"""
Perform a forward pass of the model on the provided image; this generates
a set of posterior maps for each class, as well as a segmentation output
stored in the respective 'posteriors' and 'segmentation' components.
"""
model_config = self.get_configs()

# extract patches for each image channel: Num patches x Channels x Z x Y x X
patches = self._extract_patches_for_image_channels()

# split the generated patches into batches and perform forward passes
predictions = []
batch_size = model_config.inference_batch_size

for batch_idx in range(0, len(patches), batch_size):
# slice over the batches to prepare batch
batch = torch.tensor(patches[batch_idx: batch_idx + batch_size, ...]).float()
if model_config.use_gpu:
batch = batch.cuda()
logging.debug(
f"Inference on image size {image.spatial_shape} will run "
f"with crop size {effective_patch_size} and stride {effective_stride}")
for patches_batch in patch_loader:
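# Ensure patches are float32 for the forward pass.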
input_tensor = patches_batch['image'][tio.DATA].float()
if self.model_config.use_gpu:
input_tensor = input_tensor.cuda()
locations = patches_batch[tio.LOCATION]
# perform the forward pass
batch_predictions = self._model_fn(batch).detach().cpu().numpy()
patches_posteriors = self.model(input_tensor).detach()
# collect the predictions over each of the batches
predictions.append(batch_predictions)

# concatenate the batched predictions back into a single array with an added
# class dimension: Num patches x Class x Z x Y x X
predictions = np.concatenate(predictions, axis=0)

# create posterior output for each class with the shape: Class x Z x Y x X. We use float32 as these
# arrays can be big.
output_image_shape = self.pipeline.get_variable(InferencePipeline.Variables.OutputImageShape)
posteriors = np.zeros(shape=[model_config.number_of_classes] + list(output_image_shape), dtype=np.float32)
stride = self.pipeline.get_variable(InferencePipeline.Variables.Stride)

for c in range(len(posteriors)):
# stitch the patches for each posterior class
self.load_from_patches(predictions[:, c, ...], # type: ignore
stride=stride,
scan_shape=output_image_shape,
data_attr=InferenceBatch.Components.Posteriors.value)
# extract computed output from the component so the pipeline buffer can be reused
posteriors[c] = self.get_component(InferenceBatch.Components.Posteriors)

# store the stitched up results for the batch
self.set_component(component=InferenceBatch.Components.Posteriors, data=posteriors)

return self

@action
def post_process(self) -> InferenceBatch:
"""
Perform post processing on the computed outputs of a single pass of the pipeline.
Currently the following operations are performed:
-------------------------------------------------------------------------------------
1) the mask is applied to the posteriors (if required).
2) the final posteriors are used to perform an argmax to generate a multi-label segmentation.
3) extract the largest foreground connected component in the segmentation if required
"""
mask = self.get_component(InferenceBatch.Components.Mask)
posteriors = self.get_component(InferenceBatch.Components.Posteriors)
if mask is not None:
posteriors = image_util.apply_mask_to_posteriors(posteriors=posteriors, mask=mask)
aggregator.add_batch(patches_posteriors, locations)
posteriors = aggregator.get_output_tensor().numpy()
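# The stitched output has the original spatial shape: classes x Z x Y x X.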
posteriors, segmentation = self.post_process_posteriors(posteriors, mask=mask)

# create segmentation using an argmax over the posterior probabilities
segmentation = image_util.posteriors_to_segmentation(posteriors)

# update the post-processed posteriors and save the segmentation
self.set_component(component=InferenceBatch.Components.Posteriors, data=posteriors)
self.set_component(component=InferenceBatch.Components.Segmentation, data=segmentation)

return self

def get_configs(self) -> config.SegmentationModelBase:
return self.pipeline.get_variable(InferencePipeline.Variables.ModelConfig)

def get_component(self, component: InferenceBatch.Components) -> np.ndarray:
return getattr(self, component.value) if hasattr(self, component.value) else None

@inbatch_parallel(init='indices', post='_post_custom_components', target='threads')
def set_component(self, batch_idx: int, component: InferenceBatch.Components, data: np.ndarray) \
-> Dict[str, Any]:
logging.debug("Updated data in pipeline component: {}, for batch: {}.".format(component.value, batch_idx))
return {
component.value: {'type': component.value, 'data': data}
}

def _extract_patches_for_image_channels(self) -> np.ndarray:
"""
Deterministically extracts patches from each image channel.
:return: Patches for each image channel in format: Num patches x Channels x Z x Y x X
"""
model_config = self.get_configs()
image_channels = self.get_component(InferenceBatch.Components.ImageChannels)
# There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size
# to always fit into image, and adjust stride accordingly. If test_crop_size is smaller than the
# image, crop and stride will remain unchanged.
crop_size = self.pipeline.get_variable(InferencePipeline.Variables.CropSize)
stride = self.pipeline.get_variable(InferencePipeline.Variables.Stride)
patches = []
for channel_index, channel in enumerate(image_channels):
# set the current image channel component to process
self.set_component(component=InferenceBatch.Components.Images, data=channel)
channel_patches = self.get_patches(patch_shape=crop_size,
stride=stride,
padding=model_config.padding_mode.value,
data_attr=InferenceBatch.Components.Images.value)
logging.debug(
f"Image channel {channel_index}: Tensor with extracted patches has size {channel_patches.shape}")
patches.append(channel_patches)
# reset the images component
self.set_component(component=InferenceBatch.Components.Images, data=[])

return np.stack(patches, axis=1)

def _model_fn(self, patches: torch.Tensor) -> torch.Tensor:
"""
Wrapper function to handle the model forward pass
:param patches: Image patches to be passed to the model in format Patches x Channels x Z x Y x X
:return posteriors: Confidence maps [0,1] for each patch per class
in format: Patches x Channels x Class x Z x Y x X
"""
model = self.pipeline.get_variable(InferencePipeline.Variables.Model)
# Model forward pass returns posteriors
with torch.no_grad():
return model(patches)
image_util.check_array_range(posteriors, error_prefix="Whole image posteriors")
# prepare pipeline results from the processed batch
return InferencePipeline.Result(
patient_id=patient_id,
segmentation=segmentation,
posteriors=posteriors,
voxel_spacing_mm=voxel_spacing_mm
)
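
A hypothetical call of the new pipeline, assuming `pipeline` was created via `InferencePipeline.create_from_checkpoint(...)` and that the model expects a single image channel; the shapes, spacing, and the `patient_id` argument (whose declaration is elided in this diff view, but which is used in the method body) are illustrative:

```python
import numpy as np

image_channels = np.random.rand(1, 128, 256, 256).astype(np.float32)  # C x Z x Y x X
result = pipeline.predict_whole_image(
    image_channels=image_channels,
    voxel_spacing_mm=(3.0, 1.0, 1.0),
    mask=None,
    patient_id=0,
)
print(result.segmentation.shape)  # Z x Y x X label map
print(result.posteriors.shape)    # classes x Z x Y x X, values in [0, 1]
```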
