This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Generalize SSL functionality to work on other datasets #555

Merged
merged 25 commits into main from vsalva/generalize_ssl
Sep 15, 2021
Changes from 1 commit
Commits
25 commits
42072a6
extending _get_transforms to accept new datasets
vale-salvatelli Aug 24, 2021
a8ebe13
expand get_cxr_ssl_transform to avoid hidden channel expansion
vale-salvatelli Aug 24, 2021
c2c5fe7
drop_last set as parameter of InnerEyeVisionDataModule
vale-salvatelli Aug 24, 2021
682d2ab
drop_last is now a SSLContainer parameter
vale-salvatelli Aug 24, 2021
68cb45c
Updating Changelog
vale-salvatelli Aug 24, 2021
bdf4ca6
Fix PEP8
vale-salvatelli Aug 24, 2021
fcf27ed
fixing mypy error
vale-salvatelli Aug 25, 2021
bccdb6b
still one fix
vale-salvatelli Aug 25, 2021
68ae373
Merge branch 'main' into vsalva/generalize_ssl
vale-salvatelli Aug 25, 2021
26522c8
Updating to main
vale-salvatelli Aug 25, 2021
68dd10c
generalize function names for readibility
vale-salvatelli Aug 26, 2021
d72a36b
Updating documentation
vale-salvatelli Aug 26, 2021
daeaab1
Updating documentation
vale-salvatelli Aug 26, 2021
6509e4f
removing unexpected changes in amlignore
vale-salvatelli Aug 26, 2021
bc5a81c
Adding test
vale-salvatelli Aug 26, 2021
fc22df9
Adding bits to the test
vale-salvatelli Aug 26, 2021
d74eaf4
committing to switch branch, test_transform pipeline still to be fixed
vale-salvatelli Sep 1, 2021
0cc7893
fixing test
vale-salvatelli Sep 14, 2021
c474713
remove TODO
vale-salvatelli Sep 14, 2021
7cf4459
fixing conlicts
vale-salvatelli Sep 14, 2021
9cda074
fixing flake8
vale-salvatelli Sep 14, 2021
7fa0dbd
fixing flake8 for real
vale-salvatelli Sep 14, 2021
1b978dd
fixing more flake8
vale-salvatelli Sep 14, 2021
7af305c
docstring changed
vale-salvatelli Sep 15, 2021
0c255a5
docstring changed, thanks Mel
vale-salvatelli Sep 15, 2021
committing to switch branch, test_transform pipeline still to be fixed
vale-salvatelli committed Sep 1, 2021
commit d74eaf47b5b72c7f0268dd6241b6994116454f66
@@ -30,7 +30,7 @@ def get_ssl_transforms_from_config(config: CfgNode,
 
     :param config: configuration defining which augmentations to apply as well as their intensities.
     :param return_two_views_per_sample: if True the resulting transforms will return two versions of each sample they
-    are called on. If False, simply return one transformed version of the sample.
+    are called on. If False, simply return one transformed version of the sample centered and cropped.
     :param use_training_augmentations_for_validation: If True, use augmentation at validation time too.
     This is required for SSL validation loss to be meaningful. If False, only apply basic processing step
     (no augmentations)
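
Aside: a minimal sketch (not InnerEye code) of the two-views behaviour that `return_two_views_per_sample` describes, assuming a torchvision-style transform callable; the class name `DualViewTransform` is illustrative only.

```python
from typing import Callable, Tuple, Union

import torch
from PIL import Image


class DualViewTransform:
    """Wraps a single-image transform so it can return one or two augmented views."""

    def __init__(self, transform: Callable, return_two_views_per_sample: bool) -> None:
        self.transform = transform
        self.return_two_views_per_sample = return_two_views_per_sample

    def __call__(self, image: Image.Image) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if self.return_two_views_per_sample:
            # Two independently augmented views of the same sample, as consumed by
            # contrastive SSL objectives such as SimCLR or BYOL.
            return self.transform(image), self.transform(image)
        # One transformed version of the sample (e.g. center-crop only, for evaluation).
        return self.transform(image)
```

With two views per sample the output matches what contrastive SSL objectives expect; setting `use_training_augmentations_for_validation=True` would additionally reuse the augmenting pipeline at validation time so that the SSL validation loss remains meaningful.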
8 changes: 4 additions & 4 deletions InnerEye/ML/SSL/lightning_containers/ssl_container.py
@@ -187,7 +187,7 @@ def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisionDataModule:
         """
         Returns torch lightning data module for encoder or linear head
 
-        :param is_ssl_encoder_module: whether to return the data module for SSL training or for linear heard. If true,
+        :param is_ssl_encoder_module: whether to return the data module for SSL training or for linear head. If true,
         :return transforms with two views per sample (batch like (img_v1, img_v2, label)). If False, return only one
         view per sample but also return the index of the sample in the dataset (to make sure we don't use twice the same
         batch in one training epoch (batch like (index, img_v1, label), as classifier dataloader expected to be shorter
@@ -225,8 +225,8 @@ def _get_transforms(self, augmentation_config: Optional[CfgNode],
         examples.
         :param dataset_name: name of the dataset, value has to be in SSLDatasetName, determines which transformation
         pipeline to return.
-        :param is_ssl_encoder_module: if True the transformation pipeline will yield two version of the image it is
-        applied on. If False, return only one transformation.
+        :param is_ssl_encoder_module: if True the transformation pipeline will yield two versions of the image it is
+        applied on and it applies the same transformations for validation. If False, return only one transformation.
         :return: training transformation pipeline and validation transformation pipeline.
         """
         if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value,
@@ -252,7 +252,7 @@ def _get_transforms(self, augmentation_config: Optional[CfgNode],
                 expand_channels=False,
             )
             logging.warning(f"Dataset {dataset_name} unknown. The config will be consumed by "
-                            f"get_ssl_transforms() to create the augmentation pipeline, make sure"
+                            f"get_ssl_transforms() to create the augmentation pipeline, make sure "
                             f"the transformations in your configs are compatible. ")
         else:
             raise ValueError(f"Dataset {dataset_name} unknown and no config has been passed.")
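
Aside: the two batch layouts named in the `_create_ssl_data_modules` docstring, (img_v1, img_v2, label) for the encoder and (index, img_v1, label) for the linear head, can be sketched with two hypothetical dataset wrappers; `EncoderViewDataset` and `LinearHeadViewDataset` are illustrative names, not InnerEye classes.

```python
from typing import Callable, Tuple

import torch
from torch.utils.data import Dataset


class EncoderViewDataset(Dataset):
    """Yields (img_v1, img_v2, label): two augmented views per sample, for the SSL encoder."""

    def __init__(self, base: Dataset, transform: Callable) -> None:
        self.base = base
        self.transform = transform

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        image, label = self.base[index]
        return self.transform(image), self.transform(image), label


class LinearHeadViewDataset(Dataset):
    """Yields (index, img_v1, label): one view per sample plus its dataset index, so the
    shorter classifier dataloader never reuses the same batch within one training epoch."""

    def __init__(self, base: Dataset, transform: Callable) -> None:
        self.base = base
        self.transform = transform

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, index: int) -> Tuple[int, torch.Tensor, int]:
        image, label = self.base[index]
        return index, self.transform(image), label
```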
2 changes: 2 additions & 0 deletions Tests/ML/augmentations/test_transform_pipeline.py
@@ -133,6 +133,7 @@ def test_create_transform_pipeline_from_config(expand_channels: bool) -> None:
         all_transforms.insert(0, ExpandChannels())
     else:
         fake_3d_array = np.stack([fake_cxr_as_array for i in range(3)])
+        # TODO this is raising an error - understand what shapes/values are needed here
         fake_cxr_image = PIL.Image.fromarray(fake_3d_array).convert("RGB")
 
     np.random.seed(3)
@@ -159,6 +160,7 @@ def test_create_transform_pipeline_from_config(expand_channels: bool) -> None:
     assert torch.isclose(expected_transformed, transformed_image).all()
 
     # Test the evaluation pipeline
+    # TODO why is this not parametrized?
     transformation_pipeline = create_transforms_from_config(cxr_augmentation_config, apply_augmentations=False,
                                                             expand_channels=expand_channels)
     transformed_image = transformation_pipeline(image)
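
Aside on the first TODO: `PIL.Image.fromarray` expects an RGB array of shape (H, W, 3) with dtype uint8, while `np.stack` over a list of 2D arrays (default axis=0) yields shape (3, H, W), which is a plausible cause of the error the comment mentions. A sketch of a working construction, using a hypothetical stand-in for the test's fake CXR array:

```python
import numpy as np
import PIL.Image

# Hypothetical stand-in for the test's fake CXR: a 2D uint8 array of shape (H, W).
fake_cxr_as_array = np.random.randint(0, 256, size=(256, 256), dtype=np.uint8)

# Stack the channel axis last -> (H, W, 3), the layout PIL expects for RGB images.
fake_3d_array = np.stack([fake_cxr_as_array] * 3, axis=-1)
fake_cxr_image = PIL.Image.fromarray(fake_3d_array, mode="RGB")
```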