This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Generalize SSL functionality to work on other datasets #555

Merged
merged 25 commits into from
Sep 15, 2021
Changes from 14 commits
42072a6
extending _get_transforms to accept new datasets
vale-salvatelli Aug 24, 2021
a8ebe13
expand get_cxr_ssl_transform to avoid hidden channel expansion
vale-salvatelli Aug 24, 2021
c2c5fe7
drop_last set as parameter of InnerEyeVisionDataModule
vale-salvatelli Aug 24, 2021
682d2ab
drop_last is now a SSLContainer parameter
vale-salvatelli Aug 24, 2021
68cb45c
Updating Changelog
vale-salvatelli Aug 24, 2021
bdf4ca6
Fix PEP8
vale-salvatelli Aug 24, 2021
fcf27ed
fixing mypy error
vale-salvatelli Aug 25, 2021
bccdb6b
still one fix
vale-salvatelli Aug 25, 2021
68ae373
Merge branch 'main' into vsalva/generalize_ssl
vale-salvatelli Aug 25, 2021
26522c8
Updating to main
vale-salvatelli Aug 25, 2021
68dd10c
generalize function names for readibility
vale-salvatelli Aug 26, 2021
d72a36b
Updating documentation
vale-salvatelli Aug 26, 2021
daeaab1
Updating documentation
vale-salvatelli Aug 26, 2021
6509e4f
removing unexpected changes in amlignore
vale-salvatelli Aug 26, 2021
bc5a81c
Adding test
vale-salvatelli Aug 26, 2021
fc22df9
Adding bits to the test
vale-salvatelli Aug 26, 2021
d74eaf4
committing to switch branch, test_transform pipeline still to be fixed
vale-salvatelli Sep 1, 2021
0cc7893
fixing test
vale-salvatelli Sep 14, 2021
c474713
remove TODO
vale-salvatelli Sep 14, 2021
7cf4459
fixing conlicts
vale-salvatelli Sep 14, 2021
9cda074
fixing flake8
vale-salvatelli Sep 14, 2021
7fa0dbd
fixing flake8 for real
vale-salvatelli Sep 14, 2021
1b978dd
fixing more flake8
vale-salvatelli Sep 14, 2021
7af305c
docstring changed
vale-salvatelli Sep 15, 2021
0c255a5
docstring changed, thanks Mel
vale-salvatelli Sep 15, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -25,6 +25,7 @@ jobs that run in AzureML.

 ### Changed
 - ([#531](https://github.com/microsoft/InnerEye-DeepLearning/pull/531)) Updated PL to 1.3.8, torchmetrics and pl-bolts and changed relevant metrics and SSL code API.
+- ([#555](https://github.com/microsoft/InnerEye-DeepLearning/pull/555)) Make the SSLContainer compatible with new datasets
 - ([#533](https://github.com/microsoft/InnerEye-DeepLearning/pull/533)) Better defaults for inference on ensemble children.
 - ([#536](https://github.com/microsoft/InnerEye-DeepLearning/pull/536)) Inference will not run on the validation set by default, this can be turned on
   via the `--inference_on_val_set` flag.
8 changes: 5 additions & 3 deletions InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py
@@ -29,6 +29,7 @@ def __init__(self,
                  num_workers: int = 6,
                  batch_size: int = 32,
                  seed: int = 42,
+                 drop_last: bool = True,
                  *args: Any, **kwargs: Any) -> None:
         """
         Wrapper around VisionDatamodule to load torchvision dataset into a pytorch-lightning module.
@@ -42,16 +43,17 @@ def __init__(self,
         :param val_transforms: transforms to use at validation time
         :param data_dir: data directory where to find the data
         :param val_split: proportion of training dataset to use for validation
-        :param num_workers: number of processes for dataloaders.
-        :param batch_size: batch size for training & validation.
+        :param num_workers: number of processes for dataloaders
+        :param batch_size: batch size for training & validation
         :param seed: random seed for dataset splitting
+        :param drop_last: bool, if true it drops the last incomplete batch
         """
         data_dir = data_dir if data_dir is not None else os.getcwd()
         super().__init__(data_dir=data_dir,
                          val_split=val_split,
                          num_workers=num_workers,
                          batch_size=batch_size,
-                         drop_last=True,
+                         drop_last=drop_last,
                          train_transforms=train_transforms,
                          val_transforms=val_transforms,
                          seed=seed,
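
To make the effect of the new `drop_last` parameter concrete, here is a minimal sketch in plain Python with no PyTorch dependency. The helper `count_batches` is hypothetical and not part of InnerEye; it only mirrors the usual data loader convention that `drop_last=True` discards a final incomplete batch.

```python
import math


def count_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches per epoch under the usual drop_last convention.

    Hypothetical helper for illustration only, not an InnerEye function.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        # Only complete batches are kept.
        return num_samples // batch_size
    # The trailing, smaller batch is kept as well.
    return math.ceil(num_samples / batch_size)


if __name__ == "__main__":
    print(count_batches(100, 32, drop_last=True))
    print(count_batches(100, 32, drop_last=False))
```

With 100 samples and a batch size of 32, the first call gives 3 batches and the second gives 4. A dataset smaller than one batch yields zero batches when `drop_last=True`, which is likely why exposing the flag helps with small local test runs.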
21 changes: 13 additions & 8 deletions InnerEye/ML/SSL/datamodules_and_datasets/transforms_utils.py
@@ -10,16 +10,17 @@
 from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform
 from yacs.config import CfgNode

-from InnerEye.ML.augmentations.transform_pipeline import create_cxr_transforms_from_config
+from InnerEye.ML.augmentations.transform_pipeline import create_transforms_from_config


-def get_cxr_ssl_transforms(config: CfgNode,
-                           return_two_views_per_sample: bool,
-                           use_training_augmentations_for_validation: bool = False) -> Tuple[Any, Any]:
+def get_ssl_transforms_from_config(config: CfgNode,
+                                   return_two_views_per_sample: bool,
+                                   use_training_augmentations_for_validation: bool = False,
+                                   expand_channels: bool = True) -> Tuple[Any, Any]:
     """
     Returns training and validation transforms for CXR.
     Transformations are constructed in the following way:
-    1. Construct the pipeline of augmentations in create_chest_xray_transform (e.g. resize, flip, affine) as defined
+    1. Construct the pipeline of augmentations in create_transforms_from_config (e.g. resize, flip, affine) as defined
     by the config.
     2. If we just want to construct the transformation pipeline for a classification model or for the linear evaluator
     of the SSL module, return this pipeline.
@@ -33,10 +34,14 @@ def get_cxr_ssl_transforms(config: CfgNode,
     :param use_training_augmentations_for_validation: If True, use augmentation at validation time too.
     This is required for SSL validation loss to be meaningful. If False, only apply basic processing step
     (no augmentations)
+    :param expand_channels: if True, the expand channel transformation from InnerEye.ML.augmentations.image_transforms
+    will be added to the transformation passed through the config. This is needed for single channel images such as CXR.
     """
-    train_transforms = create_cxr_transforms_from_config(config, apply_augmentations=True)
-    val_transforms = create_cxr_transforms_from_config(config,
-                                                       apply_augmentations=use_training_augmentations_for_validation)
+    train_transforms = create_transforms_from_config(config, apply_augmentations=True,
+                                                     expand_channels=expand_channels)
+    val_transforms = create_transforms_from_config(config,
+                                                   apply_augmentations=use_training_augmentations_for_validation,
+                                                   expand_channels=expand_channels)
     if return_two_views_per_sample:
         train_transforms = DualViewTransformWrapper(train_transforms)  # type: ignore
         val_transforms = DualViewTransformWrapper(val_transforms)  # type: ignore
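
The two-view behaviour selected by `return_two_views_per_sample` can be sketched as follows. `TwoViewWrapper` is a hypothetical stand-in, not the `DualViewTransformWrapper` from the repository, and `random_jitter` is an assumed toy augmentation; the point is only that the same stochastic pipeline is applied twice to one sample.

```python
import random
from typing import Any, Callable, List, Tuple


class TwoViewWrapper:
    """Applies one (possibly random) transform twice and returns both views."""

    def __init__(self, transform: Callable[[Any], Any]) -> None:
        self.transform = transform

    def __call__(self, sample: Any) -> Tuple[Any, Any]:
        # Two independent calls, so random augmentations generally differ.
        return self.transform(sample), self.transform(sample)


def random_jitter(pixels: List[float]) -> List[float]:
    """Toy augmentation: adds small uniform noise to every value."""
    return [p + random.uniform(-0.1, 0.1) for p in pixels]


if __name__ == "__main__":
    random.seed(0)
    view_1, view_2 = TwoViewWrapper(random_jitter)([0.0, 0.5, 1.0])
    print(view_1)
    print(view_2)
```

Both views come from the same input, which is what a contrastive objective of the SimCLR or BYOL kind would typically compare.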
28 changes: 21 additions & 7 deletions InnerEye/ML/SSL/lightning_containers/ssl_container.py
@@ -17,7 +17,7 @@
 from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataModule, InnerEyeVisionDataModule
 from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import InnerEyeCIFARLinearHeadTransform, \
     InnerEyeCIFARTrainTransform, \
-    get_cxr_ssl_transforms
+    get_ssl_transforms_from_config
 from InnerEye.ML.SSL.encoders import get_encoder_output_dim
 from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
 from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
@@ -96,6 +96,7 @@ class SSLContainer(LightningContainer):
     learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4,
                                                                  doc="Learning rate for linear head training during "
                                                                      "SSL training.")
+    drop_last = param.Boolean(default=True, doc="If True drops the last incomplete batch")

     def setup(self) -> None:
         from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer
@@ -166,8 +167,8 @@ def create_model(self) -> LightningModule:
                              f"Found {self.ssl_training_type.value}")
         model.hparams.update({'ssl_type': self.ssl_training_type.value,
                               "num_classes": self.data_module.num_classes})
-        self.encoder_output_dim = get_encoder_output_dim(model, self.data_module)

+        self.encoder_output_dim = get_encoder_output_dim(model, self.data_module)
         return model

     def get_data_module(self) -> InnerEyeDataModuleTypes:
@@ -209,7 +210,8 @@ def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisio
                                       data_dir=str(datamodule_args.dataset_path),
                                       batch_size=batch_size_per_gpu,
                                       num_workers=self.num_workers,
-                                      seed=self.random_seed)
+                                      seed=self.random_seed,
+                                      drop_last=self.drop_last)
         dm.prepare_data()
         dm.setup()
         return dm
@@ -232,16 +234,28 @@ def _get_transforms(self, augmentation_config: Optional[CfgNode],
                             SSLDatasetName.CheXpert.value,
                             SSLDatasetName.Covid.value]:
             assert augmentation_config is not None
-            train_transforms, val_transforms = get_cxr_ssl_transforms(augmentation_config,
-                                                                      return_two_views_per_sample=is_ssl_encoder_module,
-                                                                      use_training_augmentations_for_validation=is_ssl_encoder_module)
+            train_transforms, val_transforms = get_ssl_transforms_from_config(
+                augmentation_config,
+                return_two_views_per_sample=is_ssl_encoder_module,
+                use_training_augmentations_for_validation=is_ssl_encoder_module
+            )
         elif dataset_name in [SSLDatasetName.CIFAR10.value, SSLDatasetName.CIFAR100.value]:
             train_transforms = \
                 InnerEyeCIFARTrainTransform(32) if is_ssl_encoder_module else InnerEyeCIFARLinearHeadTransform(32)
             val_transforms = \
                 InnerEyeCIFARTrainTransform(32) if is_ssl_encoder_module else InnerEyeCIFARLinearHeadTransform(32)
+        elif augmentation_config:
+            train_transforms, val_transforms = get_ssl_transforms_from_config(
+                augmentation_config,
+                return_two_views_per_sample=is_ssl_encoder_module,
+                use_training_augmentations_for_validation=is_ssl_encoder_module,
+                expand_channels=False,
+            )
+            logging.warning(f"Dataset {dataset_name} unknown. The config will be consumed by "
+                            f"get_ssl_transforms_from_config() to create the augmentation pipeline, make sure "
+                            f"the transformations in your configs are compatible.")
         else:
-            raise ValueError(f"Dataset {dataset_name} unknown.")
+            raise ValueError(f"Dataset {dataset_name} unknown and no config has been passed.")

         return train_transforms, val_transforms
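
The new branching in `_get_transforms` can be summarised with a small, self-contained sketch. The function `choose_transform_source` and the plain-string dataset names are assumptions made for illustration; the real method compares against `SSLDatasetName` enum values (only some of which are visible in this hunk) and returns transform objects rather than labels.

```python
from typing import Optional, Tuple

# Hypothetical string names; the real code uses SSLDatasetName enum values.
CXR_DATASETS = {"CheXpert", "Covid"}
CIFAR_DATASETS = {"CIFAR10", "CIFAR100"}


def choose_transform_source(dataset_name: str, has_config: bool) -> Tuple[str, Optional[bool]]:
    """Returns (where the transforms come from, expand_channels setting)."""
    if dataset_name in CXR_DATASETS:
        # Known single-channel datasets: a config is mandatory, channels are expanded.
        assert has_config, "CXR datasets need an augmentation config"
        return "config", True
    if dataset_name in CIFAR_DATASETS:
        # CIFAR uses its built-in transforms, so expand_channels does not apply.
        return "builtin_cifar", None
    if has_config:
        # Unknown dataset with a config: use it as-is, without channel expansion.
        return "config", False
    raise ValueError(f"Dataset {dataset_name} unknown and no config has been passed.")


if __name__ == "__main__":
    print(choose_transform_source("CheXpert", has_config=True))
    print(choose_transform_source("MyNewDataset", has_config=True))
```

The design choice worth noting is the third branch: an unknown dataset no longer fails outright as long as an augmentation config is supplied, and it is assumed to already have the number of channels the encoder expects.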

16 changes: 11 additions & 5 deletions InnerEye/ML/augmentations/transform_pipeline.py
@@ -86,16 +86,22 @@ def __call__(self, data: ImageData) -> torch.Tensor:
         return self.transform_image(data)


-def create_cxr_transforms_from_config(config: CfgNode,
-                                      apply_augmentations: bool) -> ImageTransformationPipeline:
+def create_transforms_from_config(config: CfgNode,
+                                  apply_augmentations: bool,
+                                  expand_channels: bool = True) -> ImageTransformationPipeline:
     """
-    Defines the image transformations pipeline used in Chest-Xray datasets. Can be used for other types of
-    images data, type of augmentations to use and strength are expected to be defined in the config.
+    Defines the image transformations pipeline from a config file. It has been designed for Chest X-Ray
+    images, but it can be used for other types of image data; the type and strength of the augmentations are
+    expected to be defined in the config. The channel expansion is needed for gray images.
     :param config: config yaml file fixing strength and type of augmentation to apply
     :param apply_augmentations: if True return transformation pipeline with augmentations. Else,
     disable augmentations i.e. only resize and center crop the image.
+    :param expand_channels: if True, the expand channel transformation from InnerEye.ML.augmentations.image_transforms
+    will be added to the transformation passed through the config. This is needed for single channel images such as CXR.
     """
-    transforms: List[Any] = [ExpandChannels()]
+    transforms: List[Any] = []
+    if expand_channels:
+        transforms.append(ExpandChannels())
     if apply_augmentations:
         if config.augmentation.use_random_affine:
             transforms.append(RandomAffine(
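
As a rough illustration of why `expand_channels` became optional, the sketch below builds a pipeline with and without a channel-expansion step. Everything here is hypothetical: images are nested lists (channels x height x width), and `expand_channels_step` assumes that expansion means repeating a single gray channel three times, which is not confirmed by this diff.

```python
from typing import Callable, List

Image = List[List[List[float]]]  # channels x height x width


def expand_channels_step(image: Image) -> Image:
    """Copies a single-channel image into three identical channels (assumed behaviour)."""
    if len(image) != 1:
        raise ValueError("expected exactly one channel")
    return [[row[:] for row in image[0]] for _ in range(3)]


def build_pipeline(expand_channels: bool = True) -> List[Callable[[Image], Image]]:
    """Mirrors the new logic: the expansion step is added only on request."""
    steps: List[Callable[[Image], Image]] = []
    if expand_channels:
        steps.append(expand_channels_step)
    # Augmentations read from the config would be appended here.
    return steps


def apply_pipeline(steps: List[Callable[[Image], Image]], image: Image) -> Image:
    for step in steps:
        image = step(image)
    return image


if __name__ == "__main__":
    gray = [[[0.0, 1.0], [1.0, 0.0]]]
    rgb = [gray[0], gray[0], gray[0]]
    print(len(apply_pipeline(build_pipeline(True), gray)))
    print(len(apply_pipeline(build_pipeline(False), rgb)))
```

Under these assumptions a gray image needs the expansion step, while an image that already has three channels must skip it, which matches the `expand_channels=False` call used for unknown datasets.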
6 changes: 3 additions & 3 deletions InnerEye/ML/configs/classification/CovidModel.py
@@ -23,7 +23,7 @@

 from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier
 from InnerEye.ML.SSL.utils import create_ssl_encoder, create_ssl_image_classifier, load_yaml_augmentation_config
-from InnerEye.ML.augmentations.transform_pipeline import create_cxr_transforms_from_config
+from InnerEye.ML.augmentations.transform_pipeline import create_transforms_from_config
 from InnerEye.ML.common import ModelExecutionMode

 from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_linear_head_augmentation_cxr
@@ -137,9 +137,9 @@ def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> Datas
     def get_image_transform(self) -> ModelTransformsPerExecutionMode:
         config = load_yaml_augmentation_config(path_linear_head_augmentation_cxr)
         train_transforms = Compose(
-            [DicomPreparation(), create_cxr_transforms_from_config(config, apply_augmentations=True)])
+            [DicomPreparation(), create_transforms_from_config(config, apply_augmentations=True)])
         val_transforms = Compose(
-            [DicomPreparation(), create_cxr_transforms_from_config(config, apply_augmentations=False)])
+            [DicomPreparation(), create_transforms_from_config(config, apply_augmentations=False)])

         return ModelTransformsPerExecutionMode(train=train_transforms,
                                                val=val_transforms,
6 changes: 3 additions & 3 deletions Tests/ML/augmentations/test_transform_pipeline.py
@@ -13,7 +13,7 @@

 from InnerEye.ML.augmentations.image_transforms import AddGaussianNoise, ElasticTransform, ExpandChannels, RandomGamma
 from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline, \
-    create_cxr_transforms_from_config
+    create_transforms_from_config

 from Tests.SSL.test_data_modules import cxr_augmentation_config

@@ -111,7 +111,7 @@ def test_create_transform_pipeline_from_config() -> None:
     """
     Tests that the pipeline returned by create_transform_pipeline_from_config returns the expected transformation.
     """
-    transformation_pipeline = create_cxr_transforms_from_config(cxr_augmentation_config, apply_augmentations=True)
+    transformation_pipeline = create_transforms_from_config(cxr_augmentation_config, apply_augmentations=True)
     fake_cxr_as_array = np.ones([256, 256]) * 255.
     fake_cxr_as_array[100:150, 100:200] = 1
     fake_cxr_image = PIL.Image.fromarray(fake_cxr_as_array).convert("L")
@@ -154,7 +154,7 @@ def test_create_transform_pipeline_from_config() -> None:
     assert torch.isclose(expected_transformed, transformed_image).all()

     # Test the evaluation pipeline
-    transformation_pipeline = create_cxr_transforms_from_config(cxr_augmentation_config, apply_augmentations=False)
+    transformation_pipeline = create_transforms_from_config(cxr_augmentation_config, apply_augmentations=False)
     transformed_image = transformation_pipeline(image)
     assert isinstance(transformed_image, torch.Tensor)
     all_transforms = [ExpandChannels(), Resize(size=256), CenterCrop(size=224)]
10 changes: 5 additions & 5 deletions Tests/SSL/test_data_modules.py
@@ -16,7 +16,7 @@
 from InnerEye.ML.SSL.datamodules_and_datasets.cxr_datasets import RSNAKaggleCXR
 from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataModule, InnerEyeVisionDataModule
 from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import InnerEyeCIFARLinearHeadTransform, \
-    InnerEyeCIFARTrainTransform, get_cxr_ssl_transforms
+    InnerEyeCIFARTrainTransform, get_ssl_transforms_from_config
 from InnerEye.ML.SSL.lightning_containers.ssl_container import SSLContainer, SSLDatasetName
 from InnerEye.ML.SSL.utils import SSLDataModuleType, load_yaml_augmentation_config
 from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_encoder_augmentation_cxr
@@ -32,8 +32,8 @@ def test_weights_innereye_module() -> None:
     """
     Tests if weights in CXR data module are correctly initialized
     """
-    transforms = get_cxr_ssl_transforms(cxr_augmentation_config,
-                                        return_two_views_per_sample=True)
+    transforms = get_ssl_transforms_from_config(cxr_augmentation_config,
+                                                return_two_views_per_sample=True)
     data_module = InnerEyeVisionDataModule(dataset_cls=RSNAKaggleCXR,
                                            return_index=False,
                                            train_transforms=transforms[0],
@@ -179,8 +179,8 @@ def test_combined_data_module() -> None:
     """
     Tests the behavior of CombinedDataModule
     """
-    _, val_transform = get_cxr_ssl_transforms(cxr_augmentation_config,
-                                              return_two_views_per_sample=False)
+    _, val_transform = get_ssl_transforms_from_config(cxr_augmentation_config,
+                                                      return_two_views_per_sample=False)

     # Datamodule expected to have 12 training batches - 3 val
     long_data_module = InnerEyeVisionDataModule(dataset_cls=RSNAKaggleCXR,
24 changes: 16 additions & 8 deletions docs/self_supervised_models.md
@@ -117,21 +117,29 @@ with the following available arguments:
 * `random_seed`: seed for the run,
 * `num_epochs`: number of epochs to train for.

+In case you wish to first test your model locally, here are some optional arguments that can be useful:
+* `local_dataset`: path to a local dataset; if passed, the Azure dataset will be ignored
+* `is_debug_model`: if True, it will only run on the first batch of each epoch
+* `drop_last`: if False (True by default), it will keep the last batch even if it is incomplete
+
 ### Creating your own datamodules:

 To use this code with your own data, you will need to:

-1. Create a dataset class that reads your new dataset, inheriting from both `VisionDataset`
+1. Define your own Lightning Container that inherits from `SSLContainer` as described in the paragraph above.
+2. Create a dataset class that reads your new dataset, inheriting from both `VisionDataset`
    and `InnerEyeDataClassBaseWithReturnIndex`. See for example how we constructed `RSNAKaggleCXR`
    class. WARNING: the first positional argument of your dataset class constructor MUST be the data directory ("root"),
    as VisionDataModule expects this in the prepare_data step.
-2. Add a member to the `SSLDatasetName` Enum with your new dataset and update the `_SSLDataClassMappings` member of the
-   class so that the code knows which data class to associate to your new dataset name.
-3. Update the `_get_transforms` methods to add the transform specific to your new dataset. To simplify this step, we
-   have defined a series of standard transforms parametrized by an augmentation yaml file in `SSL/transforms_utils.py` (
-   see next paragraph for more details). You could for example construct a transform pipeline similar to the one created
-   with `get_cxr_ssl_transforms` for our CXR examples.
-4. Update all necessary parameters in the model config (cf. previous paragraph)
+3. In your own container, update the `_SSLDataClassMappings` member of the class so that the code knows which data class
+   to associate with your new dataset name.
+4. Create a yaml configuration file that contains the augmentations specific to your dataset. The yaml file will be
+   consumed by the `create_transforms_from_config` function defined in the
+   `InnerEye.ML.augmentations.transform_pipeline` module (see next paragraph for more details). Alternatively, override
+   the `_get_transforms` method. To simplify this step, we have defined a series of standard operations in
+   `SSL/transforms_utils.py`. You could for example construct, in your own method, a transform pipeline similar to the
+   one created inside `create_transforms_from_config`.
+5. Update all necessary parameters in the model config (cf. previous paragraph)

 Once all these steps are updated, the code in the base SSLContainer class will take care of creating the corresponding
 datamodules for SSL training and linear head monitoring.
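
The constructor requirement in the dataset step above (data directory first) can be sketched as below. `MyFolderDataset` is a hypothetical stand-in: a real class would inherit from `VisionDataset` and `InnerEyeDataClassBaseWithReturnIndex`, which are left out so the sketch runs without torchvision or InnerEye. The `train`, `transform` and `**kwargs` arguments, the `*.bin` file pattern and the constant label are all assumptions.

```python
import tempfile
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple


class MyFolderDataset:
    """Hypothetical stand-in for a custom SSL dataset class."""

    def __init__(self, root: str, train: bool = True,
                 transform: Optional[Callable[[bytes], Any]] = None, **kwargs: Any) -> None:
        # `root` is the first positional argument, as the WARNING above requires.
        # `train`, `transform` and `**kwargs` are assumed keyword arguments.
        self.root = Path(root)
        self.train = train
        self.transform = transform
        self.files: List[Path] = sorted(self.root.glob("*.bin")) if self.root.is_dir() else []

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        data: Any = self.files[index].read_bytes()
        if self.transform is not None:
            data = self.transform(data)
        # Placeholder label, purely illustrative.
        return data, 0


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp_dir:
        (Path(tmp_dir) / "sample_0.bin").write_bytes(b"\x00\x01")
        dataset = MyFolderDataset(tmp_dir, train=True, transform=len)
        print(len(dataset), dataset[0])
```

Keeping the data directory as the first positional parameter means the class can be constructed as `MyFolderDataset(data_dir, ...)`, which is the calling pattern the WARNING describes.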