Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Generalize SSL functionality to work on other datasets (#555)
Browse files Browse the repository at this point in the history
This PR contains some changes needed to make the SSLContainer compatible with new datasets and allow a user to run by simply creating a new augmentation config or defining a child class

* _get_transforms has been changed to accept new datasets without the need to touch the class
* get_cxr_ssl_transform has been changed to avoid the hidden channel expansion and make that optional. It has been also renamed to get_ssl_transform because it has nothing specific to cxr
* drop_last is now set as parameter of the InnerEyeVisionDataModule and the SSLContainer - that means it can be changed when initializing a new SSLContainer
* documentation about bringing your own SSL model has been updated
  • Loading branch information
vale-salvatelli committed Sep 15, 2021
1 parent 5b7d571 commit 521c004
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 66 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs that run in AzureML.

### Changed
- ([#531](https://github.com/microsoft/InnerEye-DeepLearning/pull/531)) Updated PL to 1.3.8, torchmetrics and pl-bolts and changed relevant metrics and SSL code API.
- ([#555](https://github.com/microsoft/InnerEye-DeepLearning/pull/555)) Make the SSLContainer compatible with new datasets
- ([#533](https://github.com/microsoft/InnerEye-DeepLearning/pull/533)) Better defaults for inference on ensemble children.
- ([#536](https://github.com/microsoft/InnerEye-DeepLearning/pull/536)) Inference will not run on the validation set by default, this can be turned on
via the `--inference_on_val_set` flag.
Expand Down
8 changes: 5 additions & 3 deletions InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self,
num_workers: int = 6,
batch_size: int = 32,
seed: int = 42,
drop_last: bool = True,
*args: Any, **kwargs: Any) -> None:
"""
Wrapper around VisionDatamodule to load torchvision dataset into a pytorch-lightning module.
Expand All @@ -42,16 +43,17 @@ def __init__(self,
:param val_transforms: transforms to use at validation time
:param data_dir: data directory where to find the data
:param val_split: proportion of training dataset to use for validation
:param num_workers: number of processes for dataloaders.
:param batch_size: batch size for training & validation.
:param num_workers: number of processes for dataloaders
:param batch_size: batch size for training & validation
:param seed: random seed for dataset splitting
:param drop_last: bool, if true it drops the last incomplete batch
"""
data_dir = data_dir if data_dir is not None else os.getcwd()
super().__init__(data_dir=data_dir,
val_split=val_split,
num_workers=num_workers,
batch_size=batch_size,
drop_last=True,
drop_last=drop_last,
train_transforms=train_transforms,
val_transforms=val_transforms,
seed=seed,
Expand Down
23 changes: 14 additions & 9 deletions InnerEye/ML/SSL/datamodules_and_datasets/transforms_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,17 @@
from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform
from yacs.config import CfgNode

from InnerEye.ML.augmentations.transform_pipeline import create_cxr_transforms_from_config
from InnerEye.ML.augmentations.transform_pipeline import create_transforms_from_config


def get_cxr_ssl_transforms(config: CfgNode,
return_two_views_per_sample: bool,
use_training_augmentations_for_validation: bool = False) -> Tuple[Any, Any]:
def get_ssl_transforms_from_config(config: CfgNode,
return_two_views_per_sample: bool,
use_training_augmentations_for_validation: bool = False,
expand_channels: bool = True) -> Tuple[Any, Any]:
"""
Returns training and validation transforms for CXR.
Transformations are constructed in the following way:
1. Construct the pipeline of augmentations in create_chest_xray_transform (e.g. resize, flip, affine) as defined
1. Construct the pipeline of augmentations in create_transform_from_config (e.g. resize, flip, affine) as defined
by the config.
2. If we just want to construct the transformation pipeline for a classification model or for the linear evaluator
of the SSL module, return this pipeline.
Expand All @@ -29,14 +30,18 @@ def get_cxr_ssl_transforms(config: CfgNode,
:param config: configuration defining which augmentations to apply as well as their intensities.
:param return_two_views_per_sample: if True the resulting transforms will return two versions of each sample they
are called on. If False, simply return one transformed version of the sample.
are called on. If False, simply return one transformed version of the sample centered and cropped.
:param use_training_augmentations_for_validation: If True, use augmentation at validation time too.
This is required for SSL validation loss to be meaningful. If False, only apply basic processing step
(no augmentations)
:param expand_channels: if True the expand channel transformation from InnerEye.ML.augmentations.image_transforms
will be added to the transformation passed through the config. This is needed for single channel images as CXR.
"""
train_transforms = create_cxr_transforms_from_config(config, apply_augmentations=True)
val_transforms = create_cxr_transforms_from_config(config,
apply_augmentations=use_training_augmentations_for_validation)
train_transforms = create_transforms_from_config(config, apply_augmentations=True,
expand_channels=expand_channels)
val_transforms = create_transforms_from_config(config,
apply_augmentations=use_training_augmentations_for_validation,
expand_channels=expand_channels)
if return_two_views_per_sample:
train_transforms = DualViewTransformWrapper(train_transforms) # type: ignore
val_transforms = DualViewTransformWrapper(val_transforms) # type: ignore
Expand Down
36 changes: 26 additions & 10 deletions InnerEye/ML/SSL/lightning_containers/ssl_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataModule, InnerEyeVisionDataModule
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import InnerEyeCIFARLinearHeadTransform, \
InnerEyeCIFARTrainTransform, \
get_cxr_ssl_transforms
get_ssl_transforms_from_config
from InnerEye.ML.SSL.encoders import get_encoder_output_dim
from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
Expand Down Expand Up @@ -96,6 +96,7 @@ class SSLContainer(LightningContainer):
learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4,
doc="Learning rate for linear head training during "
"SSL training.")
drop_last = param.Boolean(default=True, doc="If True drops the last incomplete batch")

def setup(self) -> None:
from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer
Expand Down Expand Up @@ -166,8 +167,8 @@ def create_model(self) -> LightningModule:
f"Found {self.ssl_training_type.value}")
model.hparams.update({'ssl_type': self.ssl_training_type.value,
"num_classes": self.data_module.num_classes})
self.encoder_output_dim = get_encoder_output_dim(model, self.data_module)

self.encoder_output_dim = get_encoder_output_dim(model, self.data_module)
return model

def get_data_module(self) -> InnerEyeDataModuleTypes:
Expand All @@ -186,7 +187,7 @@ def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisio
"""
Returns torch lightning data module for encoder or linear head
:param is_ssl_encoder_module: whether to return the data module for SSL training or for linear heard. If true,
:param is_ssl_encoder_module: whether to return the data module for SSL training or for linear head. If true,
:return transforms with two views per sample (batch like (img_v1, img_v2, label)). If False, return only one
view per sample but also return the index of the sample in the dataset (to make sure we don't use twice the same
batch in one training epoch (batch like (index, img_v1, label), as classifier dataloader expected to be shorter
Expand All @@ -209,7 +210,8 @@ def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisio
data_dir=str(datamodule_args.dataset_path),
batch_size=batch_size_per_gpu,
num_workers=self.num_workers,
seed=self.random_seed)
seed=self.random_seed,
drop_last=self.drop_last)
dm.prepare_data()
dm.setup()
return dm
Expand All @@ -223,25 +225,39 @@ def _get_transforms(self, augmentation_config: Optional[CfgNode],
examples.
:param dataset_name: name of the dataset, value has to be in SSLDatasetName, determines which transformation
pipeline to return.
:param is_ssl_encoder_module: if True the transformation pipeline will yield two version of the image it is
applied on. If False, return only one transformation.
:param is_ssl_encoder_module: if True the transformation pipeline will yield two versions of the image it is
applied on and it applies the training transformations also at validation time. Note that if your transformation
does not contain any randomness, the pipeline will return two identical copies. If False, it will return only one
transformation.
:return: training transformation pipeline and validation transformation pipeline.
"""
if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value,
SSLDatasetName.NIHCXR.value,
SSLDatasetName.CheXpert.value,
SSLDatasetName.Covid.value]:
assert augmentation_config is not None
train_transforms, val_transforms = get_cxr_ssl_transforms(augmentation_config,
return_two_views_per_sample=is_ssl_encoder_module,
use_training_augmentations_for_validation=is_ssl_encoder_module)
train_transforms, val_transforms = get_ssl_transforms_from_config(
augmentation_config,
return_two_views_per_sample=is_ssl_encoder_module,
use_training_augmentations_for_validation=is_ssl_encoder_module
)
elif dataset_name in [SSLDatasetName.CIFAR10.value, SSLDatasetName.CIFAR100.value]:
train_transforms = \
InnerEyeCIFARTrainTransform(32) if is_ssl_encoder_module else InnerEyeCIFARLinearHeadTransform(32)
val_transforms = \
InnerEyeCIFARTrainTransform(32) if is_ssl_encoder_module else InnerEyeCIFARLinearHeadTransform(32)
elif augmentation_config:
train_transforms, val_transforms = get_ssl_transforms_from_config(
augmentation_config,
return_two_views_per_sample=is_ssl_encoder_module,
use_training_augmentations_for_validation=is_ssl_encoder_module,
expand_channels=False,
)
logging.warning(f"Dataset {dataset_name} unknown. The config will be consumed by "
f"get_ssl_transforms() to create the augmentation pipeline, make sure "
f"the transformations in your configs are compatible. ")
else:
raise ValueError(f"Dataset {dataset_name} unknown.")
raise ValueError(f"Dataset {dataset_name} unknown and no config has been passed.")

return train_transforms, val_transforms

Expand Down
16 changes: 11 additions & 5 deletions InnerEye/ML/augmentations/transform_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,22 @@ def __call__(self, data: ImageData) -> torch.Tensor:
return self.transform_image(data)


def create_cxr_transforms_from_config(config: CfgNode,
apply_augmentations: bool) -> ImageTransformationPipeline:
def create_transforms_from_config(config: CfgNode,
apply_augmentations: bool,
expand_channels: bool = True) -> ImageTransformationPipeline:
"""
Defines the image transformations pipeline used in Chest-Xray datasets. Can be used for other types of
images data, type of augmentations to use and strength are expected to be defined in the config.
Defines the image transformations pipeline from a config file. It has been designed for Chest X-Ray
images but it can be used for other types of images data, type of augmentations to use and strength are
expected to be defined in the config. The channel expansion is needed for gray images.
:param config: config yaml file fixing strength and type of augmentation to apply
:param apply_augmentations: if True return transformation pipeline with augmentations. Else,
disable augmentations i.e. only resize and center crop the image.
:param expand_channels: if True the expand channel transformation from InnerEye.ML.augmentations.image_transforms
will be added to the transformation passed through the config. This is needed for single channel images as CXR.
"""
transforms: List[Any] = [ExpandChannels()]
transforms: List[Any] = []
if expand_channels:
transforms.append(ExpandChannels())
if apply_augmentations:
if config.augmentation.use_random_affine:
transforms.append(RandomAffine(
Expand Down
7 changes: 4 additions & 3 deletions InnerEye/ML/configs/classification/CovidModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName
from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier
from InnerEye.ML.SSL.utils import create_ssl_image_classifier, load_yaml_augmentation_config
from InnerEye.ML.augmentations.transform_pipeline import create_cxr_transforms_from_config
from InnerEye.ML.augmentations.transform_pipeline import create_transforms_from_config

from InnerEye.ML.common import ModelExecutionMode

from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_linear_head_augmentation_cxr
Expand Down Expand Up @@ -137,9 +138,9 @@ def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> Datas
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
config = load_yaml_augmentation_config(path_linear_head_augmentation_cxr)
train_transforms = Compose(
[DicomPreparation(), create_cxr_transforms_from_config(config, apply_augmentations=True)])
[DicomPreparation(), create_transforms_from_config(config, apply_augmentations=True)])
val_transforms = Compose(
[DicomPreparation(), create_cxr_transforms_from_config(config, apply_augmentations=False)])
[DicomPreparation(), create_transforms_from_config(config, apply_augmentations=False)])

return ModelTransformsPerExecutionMode(train=train_transforms,
val=val_transforms,
Expand Down
52 changes: 30 additions & 22 deletions Tests/ML/augmentations/test_transform_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import PIL
import pytest
import torch
from torchvision.transforms import CenterCrop, ColorJitter, RandomAffine, RandomErasing, RandomHorizontalFlip, \
RandomResizedCrop, Resize, ToTensor
from torchvision.transforms import (CenterCrop, ColorJitter, RandomAffine, RandomErasing, RandomHorizontalFlip,
RandomResizedCrop, Resize, ToTensor)
from torchvision.transforms.functional import to_tensor

from InnerEye.ML.augmentations.image_transforms import AddGaussianNoise, ElasticTransform, ExpandChannels, RandomGamma
from InnerEye.ML.augmentations.image_transforms import (AddGaussianNoise, ElasticTransform,
ExpandChannels, RandomGamma)
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline, \
create_cxr_transforms_from_config
create_transforms_from_config

from Tests.SSL.test_data_modules import cxr_augmentation_config

Expand All @@ -31,7 +32,6 @@
test_4d_scan_as_tensor = torch.ones([5, 4, *image_size]) * 255.
test_4d_scan_as_tensor[..., 10:15, 10:20] = 1


@pytest.mark.parametrize("use_different_transformation_per_channel", [True, False])
def test_torchvision_on_various_input(use_different_transformation_per_channel: bool) -> None:
"""
Expand Down Expand Up @@ -107,17 +107,16 @@ def test_custom_tf_on_various_input(use_different_transformation_per_channel: bo
assert torch.isclose(transformed[0, 0], transformed[1, 1]).all() != use_different_transformation_per_channel


def test_create_transform_pipeline_from_config() -> None:
@pytest.mark.parametrize("expand_channels", [True, False])
def test_create_transform_pipeline_from_config(expand_channels: bool) -> None:
"""
Tests that the pipeline returned by create_transform_pipeline_from_config returns the expected transformation.
"""
transformation_pipeline = create_cxr_transforms_from_config(cxr_augmentation_config, apply_augmentations=True)
transformation_pipeline = create_transforms_from_config(cxr_augmentation_config, apply_augmentations=True,
expand_channels=expand_channels)
fake_cxr_as_array = np.ones([256, 256]) * 255.
fake_cxr_as_array[100:150, 100:200] = 1
fake_cxr_image = PIL.Image.fromarray(fake_cxr_as_array).convert("L")

all_transforms = [ExpandChannels(),
RandomAffine(degrees=180, translate=(0, 0), shear=40),
all_transforms = [RandomAffine(degrees=180, translate=(0, 0), shear=40),
RandomResizedCrop(scale=(0.4, 1.0), size=256),
RandomHorizontalFlip(p=0.5),
RandomGamma(scale=(0.5, 1.5)),
Expand All @@ -128,23 +127,28 @@ def test_create_transform_pipeline_from_config() -> None:
AddGaussianNoise(std=0.05, p_apply=0.5)
]

if expand_channels:
all_transforms.insert(0, ExpandChannels())
# expand channels is used for single-channel input images
fake_image = PIL.Image.fromarray(fake_cxr_as_array).convert("L")
# In the pipeline the image is converted to tensor before applying the transformations. Do the same here.
image = ToTensor()(fake_image).reshape([1, 1, 256, 256])
else:
fake_3d_array = np.dstack([fake_cxr_as_array, fake_cxr_as_array, fake_cxr_as_array])
fake_image = PIL.Image.fromarray(fake_3d_array.astype(np.uint8)).convert("RGB")
# In the pipeline the image is converted to tensor before applying the transformations. Do the same here.
image = ToTensor()(fake_image).reshape([1, 3, 256, 256])

np.random.seed(3)
torch.manual_seed(3)
random.seed(3)

transformed_image = transformation_pipeline(fake_cxr_image)
transformed_image = transformation_pipeline(fake_image)
assert isinstance(transformed_image, torch.Tensor)
# Expected pipeline
image = np.ones([256, 256]) * 255.
image[100:150, 100:200] = 1
image = PIL.Image.fromarray(image).convert("L")
# In the pipeline the image is converted to tensor before applying the transformations. Do the same here.
image = ToTensor()(image).reshape([1, 1, 256, 256])

# Expected pipeline
np.random.seed(3)
torch.manual_seed(3)
random.seed(3)

expected_transformed = image
for t in all_transforms:
expected_transformed = t(expected_transformed)
Expand All @@ -154,10 +158,14 @@ def test_create_transform_pipeline_from_config() -> None:
assert torch.isclose(expected_transformed, transformed_image).all()

# Test the evaluation pipeline
transformation_pipeline = create_cxr_transforms_from_config(cxr_augmentation_config, apply_augmentations=False)
transformation_pipeline = create_transforms_from_config(cxr_augmentation_config, apply_augmentations=False,
expand_channels=expand_channels)
transformed_image = transformation_pipeline(image)
assert isinstance(transformed_image, torch.Tensor)
all_transforms = [ExpandChannels(), Resize(size=256), CenterCrop(size=224)]
all_transforms = [Resize(size=256), CenterCrop(size=224)]
if expand_channels:
all_transforms.insert(0, ExpandChannels())

expected_transformed = image
for t in all_transforms:
expected_transformed = t(expected_transformed)
Expand Down
Loading

0 comments on commit 521c004

Please sign in to comment.