From d7e5d8b5e503438c8964609c0b58c191062152a1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fernando=20P=C3=A9rez-Garc=C3=ADa?=
Date: Wed, 23 Feb 2022 10:28:11 +0000
Subject: [PATCH] Replace RadIO with TorchIO for patch-based inference (#666)

* Replace RadIO with TorchIO

* Ensure patches are float32 for forward pass

* Update changelog

* Ignore some types to fix mypy errors

* Remove APEX from conda environment in docs example

Co-authored-by: Javier
---
 CHANGELOG.md                       |   4 +-
 InnerEye/ML/pipelines/inference.py | 308 ++++++-----------------------
 InnerEye/README.md                 |   4 +-
 environment.yml                    |   2 +-
 4 files changed, 61 insertions(+), 257 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4bedef1db..3619f4292 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,8 @@ created.
 ## Upcoming

 ### Added
+
+- ([#666](https://github.com/microsoft/InnerEye-DeepLearning/pull/666)) Replace RadIO with TorchIO for patch-based inference.
 - ([#643](https://github.com/microsoft/InnerEye-DeepLearning/pull/643)) Test for recovery of SSL job. Tracks learning rate and train loss.
 - ([#594](https://github.com/microsoft/InnerEye-DeepLearning/pull/594)) When supplying a "--tag" argument, the AzureML jobs use that value as the display name, to more easily distinguish run.
@@ -125,7 +127,7 @@ in inference-only runs when using lightning containers.
 - ([#638](https://github.com/microsoft/InnerEye-DeepLearning/pull/638)) SimClr cosine LR scheduler was using wrong length information when using with long linear head datasets
 - ([#612](https://github.com/microsoft/InnerEye-DeepLearning/pull/612)) SSL online evaluator was not doing distributed training
 - ([#652](https://github.com/microsoft/InnerEye-DeepLearning/pull/652)) Run pytest build on Windows after Linux agent version upgrade
-- ([#655](https://github.com/microsoft/InnerEye-DeepLearning/pull/655)) Run pytest on Linux again, but with Ubuntu 20.04 
+- ([#655](https://github.com/microsoft/InnerEye-DeepLearning/pull/655)) Run pytest on Linux again, but with Ubuntu 20.04

 ### Removed

diff --git a/InnerEye/ML/pipelines/inference.py b/InnerEye/ML/pipelines/inference.py
index f7099a668..bfe3755f2 100644
--- a/InnerEye/ML/pipelines/inference.py
+++ b/InnerEye/ML/pipelines/inference.py
@@ -7,21 +7,18 @@
 import logging
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Optional, Tuple

 import numpy as np
 import torch
-from radio import CTImagesMaskedBatch
-from radio.batchflow import Dataset, action, inbatch_parallel
+import torchio as tio

 from InnerEye.Common.type_annotations import TupleFloat3
 from InnerEye.ML import config
-from InnerEye.ML.common import ModelExecutionMode
 from InnerEye.ML.config import SegmentationModelBase
 from InnerEye.ML.lightning_helpers import load_from_checkpoint_and_adjust_for_inference
 from InnerEye.ML.lightning_models import SegmentationLightning
 from InnerEye.ML.model_config_base import ModelConfigBase
-from InnerEye.ML.models.architectures.base_model import BaseSegmentationModel
 from InnerEye.ML.utils import image_util, ml_util
 from InnerEye.ML.utils.image_util import compute_uncertainty_map_from_posteriors, gaussian_smooth_posteriors, \
     posteriors_to_segmentation
@@ -218,6 +215,24 @@ def create_from_checkpoint(path_to_checkpoint: Path,
         assert isinstance(lightning_model, SegmentationLightning)
         return InferencePipeline(model=lightning_model, model_config=model_config, pipeline_id=pipeline_id)

+    def post_process_posteriors(self, posteriors: np.ndarray, mask: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Perform post-processing on the computed outputs of a single pass of the pipeline.
+        Currently the following operations are performed:
+        -------------------------------------------------------------------------------------
+        1) the mask is applied to the posteriors (if provided).
+        2) an argmax over the final posteriors generates a multi-label segmentation.
+        3) the largest foreground connected component is extracted from the segmentation (if required).
+        """
+        if mask is not None:
+            posteriors = image_util.apply_mask_to_posteriors(posteriors=posteriors, mask=mask)
+
+        # create segmentation using an argmax over the posterior probabilities
+        segmentation = image_util.posteriors_to_segmentation(posteriors)
+
+        return posteriors, segmentation
+
+    @torch.no_grad()
     def predict_whole_image(self, image_channels: np.ndarray,
                             voxel_spacing_mm: TupleFloat3,
                             mask: np.ndarray = None,
@@ -238,259 +253,48 @@ def predict_whole_image(self, image_channels: np.ndarray,
         if mask is not None:
             ml_util.check_size_matches(image_channels, mask, 4, 3, [-1, -2, -3])
         self.model.eval()
-        # create the dataset for the batch
-        batch_dataset = Dataset(index=[patient_id], batch_class=InferenceBatch)
-        # setup the pipeline
-        pipeline = (batch_dataset.p
-                    # define pipeline variables
-                    .init_variables([InferencePipeline.Variables.Model,
-                                     InferencePipeline.Variables.ModelConfig,
-                                     InferencePipeline.Variables.CropSize,
-                                     InferencePipeline.Variables.OutputSize,
-                                     InferencePipeline.Variables.OutputImageShape,
-                                     InferencePipeline.Variables.Stride])
-                    # update the variables for the batch actions
-                    .update_variable(name=InferencePipeline.Variables.Model, value=self.model)
-                    .update_variable(name=InferencePipeline.Variables.ModelConfig, value=self.model_config)
-                    # perform cascaded batch actions
-                    .load(image_channels=image_channels, mask=mask)
-                    .pre_process()
-                    .predict()
-                    .post_process()
-                    )
-        # run the batch through the pipeline
-        logging.info(f"Inference pipeline ({self.pipeline_id}), Predicting patient: {patient_id}")
-        processed_batch: InferenceBatch = pipeline.next_batch(batch_size=1)
-        posteriors = processed_batch.get_component(InferenceBatch.Components.Posteriors)
-        image_util.check_array_range(posteriors, error_prefix="Whole image posteriors")
-        # prepare pipeline results from the processed batch
-        return InferencePipeline.Result(
-            patient_id=patient_id,
-            segmentation=processed_batch.get_component(InferenceBatch.Components.Segmentation),
-            posteriors=posteriors,
-            voxel_spacing_mm=voxel_spacing_mm
-        )
-
-
-class InferenceBatch(CTImagesMaskedBatch):
-    """
-    Batch class for IO with the inference pipeline. One instance of a batch will load the image
-    into the 'images' component of the pipeline, and store the results of the full pass
-    of the pipeline into the 'segmentation' and 'posteriors' components.
-    """
-    class Components(Enum):
-        """
-        Components associated with the inference batch class
-        """
+        image = tio.ScalarImage(tensor=image_channels)
+        subject = tio.Subject(image=image)

-        # the input image channels in Channels x Z x Y x X format.
-        ImageChannels = 'channels'
-        # a set of 2D image slices (ie: a 3D image channel), stacked in Z x Y x X format.
-        Images = 'images'
-        # a binary mask used to ignore predictions in Z x Y x X format.
-        Mask = 'mask'
-        # a numpy.ndarray in Z x Y x X format with class labels for each voxel in the original image.
-        Segmentation = 'segmentation'
-        # a numpy.ndarray with the first dimension indexing each class in C x Z x Y x X format
-        # with each Z x Y x X being the same shape as the Images component, and consisting of
-        # [0, 1] values representing the model confidence for each voxel.
-        Posteriors = 'posteriors'
-
-    def __init__(self, index: int, *args: Any, **kwargs: Any):
-        super().__init__(index, *args, **kwargs)
-        self.components = [x.value for x in InferenceBatch.Components]
-
-    @action
-    def load(self, image_channels: np.ndarray, mask: np.ndarray) -> InferenceBatch:
-        """
-        Load image channels and mask into their respective pipeline components.
-        """
-        self.set_component(component=InferenceBatch.Components.ImageChannels, data=image_channels)
-        model_config = self.get_configs()
-        if model_config is None:
-            raise ValueError("model_config is None")
-        if model_config.test_crop_size is None:
-            raise ValueError("model_config.test_crop_size is None")
-        if model_config.inference_stride_size is None:
-            raise ValueError("model_config.inference_stride_size is None")
-
-        # fetch the image channels from the batch
-        image_channels = self.get_component(InferenceBatch.Components.ImageChannels)
-        self.pipeline.set_variable(name=InferencePipeline.Variables.OutputImageShape, value=image_channels[0].shape)
         # There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size
         # to always fit into image. If test_crop_size is smaller than the image, crop will remain unchanged.
-        image_size = image_channels.shape[1:]
-        model: BaseSegmentationModel = self.pipeline.get_variable(InferencePipeline.Variables.Model).model
-        effective_crop, effective_stride = \
-            model.crop_size_constraints.restrict_crop_size_to_image(image_size,
-                                                                    model_config.test_crop_size,
-                                                                    model_config.inference_stride_size)
-        self.pipeline.set_variable(name=InferencePipeline.Variables.CropSize, value=effective_crop)
-        self.pipeline.set_variable(name=InferencePipeline.Variables.Stride, value=effective_stride)
-        logging.debug(
-            f"Inference on image size {image_size} will run "
-            f"with crop size {effective_crop} and stride {effective_stride}")
-        # In most cases, we will be able to read the output size from the pre-computed values
-        # via get_output_size. Only if we have a non-standard (smaller) crop size, re-computed the output size.
-        output_size = model_config.get_output_size(execution_mode=ModelExecutionMode.TEST)
-        if effective_crop != model_config.test_crop_size:
-            output_size = model.get_output_shape(input_shape=effective_crop)  # type: ignore
-        self.pipeline.set_variable(name=InferencePipeline.Variables.OutputSize, value=output_size)
-
-        if mask is not None:
-            self.set_component(component=InferenceBatch.Components.Mask, data=mask)
-
-        return self
-
-    @action
-    def pre_process(self) -> InferenceBatch:
-        """
-        Prepare the input components of the batch for further processing.
-        """
-        model_config = self.get_configs()
-
-        # fetch the image channels from the batch
-        image_channels = self.get_component(InferenceBatch.Components.ImageChannels)
-
-        crop_size = self.pipeline.get_variable(InferencePipeline.Variables.CropSize)
-        output_size = self.pipeline.get_variable(InferencePipeline.Variables.OutputSize)
-        image_channels = image_util.pad_images_for_inference(
-            images=image_channels,
-            crop_size=crop_size,
-            output_size=output_size,
-            padding_mode=model_config.padding_mode
+        restrict_patch_size = self.model.model.crop_size_constraints.restrict_crop_size_to_image  # type: ignore
+        effective_patch_size, effective_stride = restrict_patch_size(image.spatial_shape,  # type: ignore
+                                                                     self.model_config.test_crop_size,
+                                                                     self.model_config.inference_stride_size)
+
+        patch_overlap = np.array(effective_patch_size) - np.array(effective_stride)
+        grid_sampler = tio.inference.GridSampler(
+            subject,
+            effective_patch_size,
+            patch_overlap,
+            padding_mode=self.model_config.padding_mode.value,
         )
+        batch_size = self.model_config.inference_batch_size
+        patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=batch_size)  # type: ignore
+        aggregator = tio.inference.GridAggregator(grid_sampler)

-        # update the post-processed components
-        self.set_component(component=InferenceBatch.Components.ImageChannels, data=image_channels)
-
-        return self
-
-    @action
-    def predict(self) -> InferenceBatch:
-        """
-        Perform a forward pass of the model on the provided image, this generates
-        a set of posterior maps for each class, as well as a segmentation output
-        stored in the respective 'posteriors' and 'segmentation' components.
-        """
-        model_config = self.get_configs()
-
-        # extract patches for each image channel: Num patches x Channels x Z x Y x X
-        patches = self._extract_patches_for_image_channels()
-
-        # split the generated patches into batches and perform forward passes
-        predictions = []
-        batch_size = model_config.inference_batch_size
-
-        for batch_idx in range(0, len(patches), batch_size):
-            # slice over the batches to prepare batch
-            batch = torch.tensor(patches[batch_idx: batch_idx + batch_size, ...]).float()
-            if model_config.use_gpu:
-                batch = batch.cuda()
+        logging.debug(
+            f"Inference on image size {image.spatial_shape} will run "
+            f"with crop size {effective_patch_size} and stride {effective_stride}")
+        for patches_batch in patch_loader:
+            input_tensor = patches_batch['image'][tio.DATA].float()
+            if self.model_config.use_gpu:
+                input_tensor = input_tensor.cuda()
+            locations = patches_batch[tio.LOCATION]
             # perform the forward pass
-            batch_predictions = self._model_fn(batch).detach().cpu().numpy()
+            patches_posteriors = self.model(input_tensor).detach()
             # collect the predictions over each of the batches
-            predictions.append(batch_predictions)
-
-        # map the batched predictions to the original batch shape
-        # of shape but with an added class dimension: Num patches x Class x Z x Y x X
-        predictions = np.concatenate(predictions, axis=0)
-
-        # create posterior output for each class with the shape: Class x Z x Y x x. We use float32 as these
-        # arrays can be big.
-        output_image_shape = self.pipeline.get_variable(InferencePipeline.Variables.OutputImageShape)
-        posteriors = np.zeros(shape=[model_config.number_of_classes] + list(output_image_shape), dtype=np.float32)
-        stride = self.pipeline.get_variable(InferencePipeline.Variables.Stride)
-
-        for c in range(len(posteriors)):
-            # stitch the patches for each posterior class
-            self.load_from_patches(predictions[:, c, ...],  # type: ignore
-                                   stride=stride,
-                                   scan_shape=output_image_shape,
-                                   data_attr=InferenceBatch.Components.Posteriors.value)
-            # extract computed output from the component so the pipeline buffer can be reused
-            posteriors[c] = self.get_component(InferenceBatch.Components.Posteriors)
-
-        # store the stitched up results for the batch
-        self.set_component(component=InferenceBatch.Components.Posteriors, data=posteriors)
-
-        return self
-
-    @action
-    def post_process(self) -> InferenceBatch:
-        """
-        Perform post processing on the computed outputs of the a single pass of the pipelines.
-        Currently the following operations are performed:
-        -------------------------------------------------------------------------------------
-        1) the mask is applied to the posteriors (if required).
-        2) the final posteriors are used to perform an argmax to generate a multi-label segmentation.
-        3) extract the largest foreground connected component in the segmentation if required
-        """
-        mask = self.get_component(InferenceBatch.Components.Mask)
-        posteriors = self.get_component(InferenceBatch.Components.Posteriors)
-        if mask is not None:
-            posteriors = image_util.apply_mask_to_posteriors(posteriors=posteriors, mask=mask)
+            aggregator.add_batch(patches_posteriors, locations)
+        posteriors = aggregator.get_output_tensor().numpy()
+        posteriors, segmentation = self.post_process_posteriors(posteriors, mask=mask)

-        # create segmentation using an argmax over the posterior probabilities
-        segmentation = image_util.posteriors_to_segmentation(posteriors)
-
-        # update the post-processed posteriors and save the segmentation
-        self.set_component(component=InferenceBatch.Components.Posteriors, data=posteriors)
-        self.set_component(component=InferenceBatch.Components.Segmentation, data=segmentation)
-
-        return self
-
-    def get_configs(self) -> config.SegmentationModelBase:
-        return self.pipeline.get_variable(InferencePipeline.Variables.ModelConfig)
-
-    def get_component(self, component: InferenceBatch.Components) -> np.ndarray:
-        return getattr(self, component.value) if hasattr(self, component.value) else None
-
-    @inbatch_parallel(init='indices', post='_post_custom_components', target='threads')
-    def set_component(self, batch_idx: int, component: InferenceBatch.Components, data: np.ndarray) \
-            -> Dict[str, Any]:
-        logging.debug("Updated data in pipeline component: {}, for batch: {}.".format(component.value, batch_idx))
-        return {
-            component.value: {'type': component.value, 'data': data}
-        }
-
-    def _extract_patches_for_image_channels(self) -> np.ndarray:
-        """
-        Extracts deterministically, patches from each image channel
-        :return: Patches for each image channel in format: Num patches x Channels x Z x Y x X
-        """
-        model_config = self.get_configs()
-        image_channels = self.get_component(InferenceBatch.Components.ImageChannels)
-        # There may be cases where the test image is smaller than the test_crop_size. Adjust crop_size
-        # to always fit into image, and adjust stride accordingly. If test_crop_size is smaller than the
-        # image, crop and stride will remain unchanged.
- crop_size = self.pipeline.get_variable(InferencePipeline.Variables.CropSize) - stride = self.pipeline.get_variable(InferencePipeline.Variables.Stride) - patches = [] - for channel_index, channel in enumerate(image_channels): - # set the current image channel component to process - self.set_component(component=InferenceBatch.Components.Images, data=channel) - channel_patches = self.get_patches(patch_shape=crop_size, - stride=stride, - padding=model_config.padding_mode.value, - data_attr=InferenceBatch.Components.Images.value) - logging.debug( - f"Image channel {channel_index}: Tensor with extracted patches has size {channel_patches.shape}") - patches.append(channel_patches) - # reset the images component - self.set_component(component=InferenceBatch.Components.Images, data=[]) - - return np.stack(patches, axis=1) - - def _model_fn(self, patches: torch.Tensor) -> torch.Tensor: - """ - Wrapper function to handle the model forward pass - :param patches: Image patches to be passed to the model in format Patches x Channels x Z x Y x X - :return posteriors: Confidence maps [0,1] for each patch per class - in format: Patches x Channels x Class x Z x Y x X - """ - model = self.pipeline.get_variable(InferencePipeline.Variables.Model) - # Model forward pass returns posteriors - with torch.no_grad(): - return model(patches) + image_util.check_array_range(posteriors, error_prefix="Whole image posteriors") + # prepare pipeline results from the processed batch + return InferencePipeline.Result( + patient_id=patient_id, + segmentation=segmentation, + posteriors=posteriors, + voxel_spacing_mm=voxel_spacing_mm + ) diff --git a/InnerEye/README.md b/InnerEye/README.md index 0b6736b44..72a39009e 100644 --- a/InnerEye/README.md +++ b/InnerEye/README.md @@ -5,7 +5,7 @@ * You need to have a Conda installation on your machine. * Create a Conda environment file `environment.yml` in your source code with this contents: -``` +```yaml name: MyEnv channels: - defaults @@ -15,8 +15,6 @@ dependencies: - python=3.7.3 - pytorch=1.3.0 - pip: - - git+https://github.com/analysiscenter/radio.git@6d53e25#egg=radio - - git+https://github.com/ptrblck/apex.git@4ad9b3b#egg=apex - innereye ``` diff --git a/environment.yml b/environment.yml index bd4456ed4..1c334c21b 100644 --- a/environment.yml +++ b/environment.yml @@ -11,7 +11,6 @@ dependencies: - python-blosc=1.7.0 - torchvision=0.11.1 - pip: - - git+https://github.com/analysiscenter/radio.git@6d53e25#egg=radio - azure-mgmt-resource==12.1.0 - azure-mgmt-datafactory==1.1.0 - azure-storage-blob==12.6.0 @@ -70,6 +69,7 @@ dependencies: - tabulate==0.8.7 - tensorboard==2.3.0 - tensorboardX==2.1 + - torchio==0.18.73 - torchmetrics==0.6.0 - umap-learn==0.5.2 - yacs==0.1.8
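
---

Note for reviewers: below is a minimal, self-contained sketch of the TorchIO grid-sampling pattern this patch adopts in `predict_whole_image`, handy for trying the API outside InnerEye. The toy `Conv3d` "model", the channel/class counts, and the patch and overlap sizes are illustrative assumptions only; the real pipeline's crop-size constraints, `padding_mode`, masking, and `post_process_posteriors` logic are in the diff above.

```python
import torch
import torchio as tio

# Toy stand-in for the segmentation network: any module that maps
# (batch, channels, x, y, z) -> (batch, classes, x, y, z) works here.
model = torch.nn.Conv3d(in_channels=3, out_channels=2, kernel_size=1)
model.eval()

# A hypothetical 3-channel volume, wrapped the same way the patch wraps image_channels.
subject = tio.Subject(image=tio.ScalarImage(tensor=torch.rand(3, 64, 64, 64)))

patch_size = (32, 32, 32)
patch_overlap = (16, 16, 16)  # patch size minus stride, mirroring the diff above
grid_sampler = tio.inference.GridSampler(subject, patch_size, patch_overlap)
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4)
aggregator = tio.inference.GridAggregator(grid_sampler)

with torch.no_grad():
    for patches_batch in patch_loader:
        # Patches arrive as (batch, channels, x, y, z); the locations tell the
        # aggregator where each patch sits in the full volume for stitching.
        input_tensor = patches_batch['image'][tio.DATA].float()
        locations = patches_batch[tio.LOCATION]
        aggregator.add_batch(model(input_tensor), locations)

# Stitched posteriors for the whole volume, shape (classes, x, y, z).
posteriors = aggregator.get_output_tensor()
# Argmax over the class dimension yields a multi-label segmentation,
# analogous to image_util.posteriors_to_segmentation in the diff.
segmentation = posteriors.argmax(dim=0)
```

The `padding_mode` argument passed in the diff makes `GridSampler` pad the volume before sampling so that patches always fit; it is omitted here because the toy volume divides evenly into patches at this stride.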