This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Generalize SSL functionality to work on other datasets #555

Merged
merged 25 commits into from
Sep 15, 2021
42072a6
extending _get_transforms to accept new datasets
vale-salvatelli Aug 24, 2021
a8ebe13
expand get_cxr_ssl_transform to avoid hidden channel expansion
vale-salvatelli Aug 24, 2021
c2c5fe7
drop_last set as parameter of InnerEyeVisionDataModule
vale-salvatelli Aug 24, 2021
682d2ab
drop_last is now a SSLContainer parameter
vale-salvatelli Aug 24, 2021
68cb45c
Updating Changelog
vale-salvatelli Aug 24, 2021
bdf4ca6
Fix PEP8
vale-salvatelli Aug 24, 2021
fcf27ed
fixing mypy error
vale-salvatelli Aug 25, 2021
bccdb6b
still one fix
vale-salvatelli Aug 25, 2021
68ae373
Merge branch 'main' into vsalva/generalize_ssl
vale-salvatelli Aug 25, 2021
26522c8
Updating to main
vale-salvatelli Aug 25, 2021
68dd10c
generalize function names for readibility
vale-salvatelli Aug 26, 2021
d72a36b
Updating documentation
vale-salvatelli Aug 26, 2021
daeaab1
Updating documentation
vale-salvatelli Aug 26, 2021
6509e4f
removing unexpected changes in amlignore
vale-salvatelli Aug 26, 2021
bc5a81c
Adding test
vale-salvatelli Aug 26, 2021
fc22df9
Adding bits to the test
vale-salvatelli Aug 26, 2021
d74eaf4
committing to switch branch, test_transform pipeline still to be fixed
vale-salvatelli Sep 1, 2021
0cc7893
fixing test
vale-salvatelli Sep 14, 2021
c474713
remove TODO
vale-salvatelli Sep 14, 2021
7cf4459
fixing conlicts
vale-salvatelli Sep 14, 2021
9cda074
fixing flake8
vale-salvatelli Sep 14, 2021
7fa0dbd
fixing flake8 for real
vale-salvatelli Sep 14, 2021
1b978dd
fixing more flake8
vale-salvatelli Sep 14, 2021
7af305c
docstring changed
vale-salvatelli Sep 15, 2021
0c255a5
docstring changed, thanks Mel
vale-salvatelli Sep 15, 2021
Updating documentation
vale-salvatelli committed Aug 26, 2021
commit daeaab1bddfb27e194d299b120e0e46c2ea0e9d8
13 changes: 9 additions & 4 deletions docs/self_supervised_models.md
@@ -117,24 +117,29 @@ with the following available arguments:
 * `random_seed`: seed for the run,
 * `num_epochs`: number of epochs to train for.

+In case you wish to first test your model locally, here some optional arguments that can be useful:
+* `local_dataset`: path to local dataset, if passed the azure dataset will be ignored
+* `is_debug_model`: if True it will only run on the first batch of each epoch
+* `drop_last`: if False (True by default) it will keep the last batch also if incomplete
+
 ### Creating your own datamodules:

 To use this code with your own data, you will need to:

-1. Create a dataset class that reads your new dataset, inheriting from both `VisionDataset`
+1. Define your own Lightening Container that inherits from `SSLContainer` as described in the paragraph above.
+2. Create a dataset class that reads your new dataset, inheriting from both `VisionDataset`
 and `InnerEyeDataClassBaseWithReturnIndex`. See for example how we constructed `RSNAKaggleCXR`
 class. WARNING: the first positional argument of your dataset class constructor MUST be the data directory ("root"),
 as VisionDataModule expects this in the prepare_data step.
-2. Define your own Lightening Container that inherits from `SSLContainer` as described in the paragraph above.
 3. In your own container update the `_SSLDataClassMappings` member of the class so that the code knows which data class
 to associate to your new dataset name.
-3. Create a yaml configuration file that contains the augmentations specific to your dataset. The yaml file will be
+4. Create a yaml configuration file that contains the augmentations specific to your dataset. The yaml file will be
 consumed by the `create_transforms_from_config` function defined in the
 `InnerEye.ML.augmentations.transform_pipeline` module (see next paragraph for more details). Alternatively, overwrite
 the `_get_transforms` method. To simplify this step, we have defined a series of standard operations in
 `SSL/transforms_utils.py` . You could for example construct a transform pipeline similar to the one created
 inside `create_transform_from_config` inside your own method.
-4. Update all necessary parameters in the model config (cf. previous paragraph)
+5. Update all necessary parameters in the model config (cf. previous paragraph)

 Once all these steps are updated, the code in the base SSLContainer class will take care of creating the corresponding
 datamodules for SSL training and linear head monitoring.
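
As a rough illustration of the dataset and mapping steps described in this documentation, the sketch below uses plain-Python stand-ins rather than the real InnerEye or torchvision classes. Every name in it (`VisionDatasetStandIn`, `ReturnIndexMixinStandIn`, `MyToyDataset`, `DATASET_NAME_TO_CLASS`, `build_dataset`) is hypothetical, and the assumption that the return-index base class yields the sample index alongside the item is ours, not something confirmed by this diff. It only shows the shape of the two requirements: the data directory ("root") is the first positional constructor argument, and a name-to-class mapping tells the code which data class to build.

```python
from pathlib import Path
from typing import Any, Dict, List, Tuple, Type


class VisionDatasetStandIn:
    """Hypothetical stand-in for `VisionDataset`: it only stores the data directory."""

    def __init__(self, root: str) -> None:
        self.root = Path(root)


class ReturnIndexMixinStandIn:
    """Hypothetical stand-in for `InnerEyeDataClassBaseWithReturnIndex`.

    Assumption: the real base class returns the sample index together with the item.
    """

    def __getitem__(self, index: int) -> Tuple[int, Any, int]:
        image, label = self.get_item(index)  # provided by the concrete dataset
        return index, image, label


class MyToyDataset(ReturnIndexMixinStandIn, VisionDatasetStandIn):
    """Toy dataset: `root` MUST be the first positional argument."""

    def __init__(self, root: str, train: bool = True) -> None:
        super().__init__(root)  # resolves to VisionDatasetStandIn.__init__
        self.train = train
        # Fake (filename, label) pairs; a real dataset would read these from disk.
        self.samples: List[Tuple[str, int]] = [(f"image_{i}.png", i % 2) for i in range(5)]

    def __len__(self) -> int:
        return len(self.samples)

    def get_item(self, index: int) -> Tuple[Path, int]:
        filename, label = self.samples[index]
        return self.root / filename, label


# Plays the role that `_SSLDataClassMappings` has in the container:
# dataset name -> data class to instantiate.
DATASET_NAME_TO_CLASS: Dict[str, Type[MyToyDataset]] = {"MyToyDataset": MyToyDataset}


def build_dataset(name: str, data_dir: str) -> MyToyDataset:
    """Look up the class by name and pass the data directory positionally."""
    return DATASET_NAME_TO_CLASS[name](data_dir)
```

Calling `build_dataset("MyToyDataset", "/tmp/data")` and indexing the result gives back the index, a path under the data directory, and the label. In a real container you would inherit from the actual InnerEye classes and register your dataset in `_SSLDataClassMappings` instead of a module-level dictionary.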