This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Add self-supervised learning capabilities to InnerEye #440

Merged
merged 403 commits into from
May 7, 2021
Changes from 250 commits
Commits
5a3a385
final model
ant0nsc Apr 1, 2021
8d2bfdb
flake
ant0nsc Apr 1, 2021
78f4e21
Big refractoring: remove redundancy in dataset classes, use VisionDat…
Apr 1, 2021
ca98534
Merge branch 'melanibe/support-for-multiple-datasets' of https://gith…
Apr 1, 2021
f7e04f5
Update test and simplify
Apr 1, 2021
e39008a
Update test and simplify
Apr 1, 2021
4b1019c
Add possibility to have a custom dataloader
Apr 6, 2021
2750cdc
Clean up ssl_config
Apr 6, 2021
944f413
Adding support for image classifiers
Apr 6, 2021
0765238
update to fastmri master
ant0nsc Apr 7, 2021
15fe1ab
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Apr 7, 2021
d564119
test fixes
ant0nsc Apr 7, 2021
d14865b
Add image classifier
Apr 7, 2021
bb3fd2e
Clean up test
Apr 7, 2021
64d4c2e
Add load ssl weights from somewhere else
Apr 7, 2021
048f982
Clean up test
Apr 7, 2021
87267f9
import fixes
ant0nsc Apr 7, 2021
7ea565a
Merge branch 'antonsc/byol' of https://github.com/microsoft/InnerEye-…
Apr 7, 2021
0ed57a9
Make augmenation wrapper more generic
Apr 8, 2021
d471cd7
Renaming
Apr 8, 2021
b295cfd
Add args to config
Apr 8, 2021
9e8a4f2
Make checkpoint handling different
Apr 8, 2021
48c7179
Default value
Apr 8, 2021
727fead
Make it versatile enough to use with InnerEyeContainer
Apr 8, 2021
8d6cb08
Make it more versatile
Apr 8, 2021
d3988a1
Make it more versatile
Apr 8, 2021
a975992
Fix
Apr 8, 2021
eef0677
Update the readme a little
Apr 8, 2021
dad87c8
Better
Apr 8, 2021
53a9ef9
Update readme
Apr 8, 2021
2063383
test fixes
ant0nsc Apr 9, 2021
58a6dc6
fix import problems
ant0nsc Apr 9, 2021
193e9c9
fix checkpoint test
ant0nsc Apr 9, 2021
a4910e2
Rename to local_paths
Apr 12, 2021
d71c705
Add details for more advanced usage of SSL than just examples
Apr 12, 2021
ffc43ce
Update tests and make them check more parameters
Apr 12, 2021
d1cb491
Clean up
Apr 12, 2021
3231e45
test fix
ant0nsc Apr 12, 2021
88e29fa
Fix test_recover_training_mean_teacher_model
Shruthi42 Apr 12, 2021
448550d
Flake8
Shruthi42 Apr 12, 2021
b6a3575
test fix
ant0nsc Apr 12, 2021
101d42d
test fix
ant0nsc Apr 12, 2021
3cd1b17
Merge branch 'antonsc/byol' of https://github.com/microsoft/InnerEye-…
ant0nsc Apr 12, 2021
ede2829
Remove main() from patch_sampling.py
Shruthi42 Apr 12, 2021
987e83d
mypy
Shruthi42 Apr 12, 2021
9fad817
Merge branch 'antonsc/byol' of https://github.com/microsoft/InnerEye-…
Shruthi42 Apr 12, 2021
c497568
Fix logging
Apr 12, 2021
4faefdc
Fix
Apr 12, 2021
76e5f02
docu
ant0nsc Apr 12, 2021
5dae881
ensure correct seeding in tests
ant0nsc Apr 12, 2021
b744abb
fix test_invalid_trainer_args
ant0nsc Apr 13, 2021
60f5e32
fix downloading problems in AzureML
ant0nsc Apr 13, 2021
2ae170a
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Apr 13, 2021
dc01b61
adding reports for container models
ant0nsc Apr 13, 2021
59e2ca1
fixing test loss values
ant0nsc Apr 13, 2021
8758973
fix bug with running the unit tests in AzureML
ant0nsc Apr 13, 2021
c514ca0
docu, flake
ant0nsc Apr 13, 2021
ab9b4c2
mypy and flake
ant0nsc Apr 13, 2021
7988af9
downgrading mypy
ant0nsc Apr 13, 2021
4e5b7f6
test fixes
ant0nsc Apr 14, 2021
f70884f
test fixes, reduce logging noise
ant0nsc Apr 14, 2021
069ca6c
mypy and flake
ant0nsc Apr 14, 2021
731483f
simplify mypy runner
ant0nsc Apr 14, 2021
1f5afc0
remove comments
ant0nsc Apr 14, 2021
a362e70
Update InnerEye/ML/model_training.py
ant0nsc Apr 14, 2021
7f17e83
Update docs/bring_your_own_model.md
ant0nsc Apr 14, 2021
a8024d1
PR comments
ant0nsc Apr 14, 2021
ed94083
Merge branch 'antonsc/byol' of https://github.com/microsoft/InnerEye-…
ant0nsc Apr 14, 2021
b6a7d01
fixes
ant0nsc Apr 15, 2021
de15468
updated docu and design as per PR feedback.
ant0nsc Apr 15, 2021
5656489
docu and mypy
ant0nsc Apr 15, 2021
efd9231
update doc, add report test
ant0nsc Apr 16, 2021
2f4c36a
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Apr 16, 2021
268b265
HelloContainer model
ant0nsc Apr 16, 2021
8649681
HelloContainer data
ant0nsc Apr 16, 2021
422c790
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Apr 16, 2021
7e1e516
HelloWorld running
ant0nsc Apr 16, 2021
be0948b
docu
ant0nsc Apr 16, 2021
79795e0
fixes
ant0nsc Apr 16, 2021
4273754
Merge branch 'main' of https://github.com/microsoft/InnerEye-DeepLear…
Apr 16, 2021
d2e1b05
Merge branch 'antonsc/byol' of https://github.com/microsoft/InnerEye-…
Apr 16, 2021
e8583e9
Flake8
Apr 16, 2021
2e52d58
Add auto-restart
Apr 19, 2021
be04f44
Change handling of checkpoints and clean-up
Apr 19, 2021
c939279
Save last k recovery checkpoints
Apr 19, 2021
5ae4f28
Log epoch for keeping last ckpt
Apr 19, 2021
8bebc7c
Keeping k last checkpoints
Apr 19, 2021
2c8afdb
Add possibility to recover from particular checkpoint
Apr 19, 2021
bfc3975
Update tests
Apr 19, 2021
096fdcd
Check k recovery
Apr 19, 2021
020a8c1
Re-add skipif
Apr 19, 2021
aea1b0b
Correct pick up of recovery runs and add test
Apr 19, 2021
9ce3e54
Correct pick up of recovery runs and add test
Apr 19, 2021
0657b61
Remove all start epochs
Apr 19, 2021
0c22b53
Remove all start epochs
Apr 19, 2021
422a859
Spimplify run recovery logic
Apr 19, 2021
9dcfef3
Fix it
Apr 19, 2021
60a4114
Merge branch 'main' into melanibe/checkpoint-saving
Apr 19, 2021
fcba209
Merge conflicts import errors
Apr 19, 2021
b50dcd8
Fix it
Apr 19, 2021
7f3f44a
Fix tests in test_scalar_model.py
Apr 19, 2021
4702280
Fix tests in test_model_util.py
Apr 20, 2021
6ff6492
Fix tests in test_scalar_model.py
Apr 20, 2021
06852fc
Fix tests in test_model_training.py
Apr 20, 2021
d98d41b
Avoid forcing the user to log epoch
Apr 20, 2021
3aebe38
Fix test_get_checkpoints
Apr 20, 2021
0693fce
Fix test_checkpoint_handling.py
Apr 20, 2021
0bb57cc
Fix callback
Apr 20, 2021
23740a4
Update CHANGELOG.md
Apr 20, 2021
310a408
Self PR review comments
Apr 20, 2021
ad488ad
Fix more tests
Apr 20, 2021
c877bc9
Merge branch 'main' of https://github.com/microsoft/InnerEye-DeepLear…
Apr 20, 2021
b9ba095
Fix argument in test
Apr 20, 2021
5ec80e2
Mypy
Apr 20, 2021
c915317
Update InnerEye-DeepLearning.iml
melanibe Apr 20, 2021
e8b6e41
Update InnerEye-DeepLearning.iml
melanibe Apr 20, 2021
f38904d
Merge branch 'main' of https://github.com/microsoft/InnerEye-DeepLear…
Apr 20, 2021
c397a57
Fix some things after merge
Apr 20, 2021
9c02781
Fix mypy errors
Apr 20, 2021
e661b93
Fix more stuff with new container API
Apr 20, 2021
59b442f
Flake8
Apr 20, 2021
3de5393
Fixing some mypy
Apr 20, 2021
2c38f2e
Fixing the thousands of mypy issues
Apr 20, 2021
f56155f
Address PR comment
Apr 20, 2021
7302282
Typo
Apr 21, 2021
68abb81
mypy fix
Apr 21, 2021
96e9c83
just style
Apr 21, 2021
a35ae27
wrong merge
Apr 21, 2021
9cc526e
local torch import
Apr 21, 2021
50b8475
get rid of some tests for speed of build
Apr 21, 2021
3d25344
Remove logging
Apr 21, 2021
145d0a5
Move the gpu logic otherwise LightningContainer do not use gpus
Apr 21, 2021
b11e14f
Update epochs and configs
Apr 21, 2021
4805426
Move SSL folder into ML folder
Apr 21, 2021
2435359
Move SSL folder into ML folder
Apr 21, 2021
d070d02
Fix pbm
Apr 21, 2021
d0b804d
Update CHANGELOG.md
Apr 21, 2021
284e7ea
Unecessary change
Apr 21, 2021
2ff58cc
Updated version of test to hopefuly make faster
Apr 21, 2021
5e8cdab
Merge branch 'melanibe/checkpoint-saving' of https://github.com/micro…
Apr 21, 2021
b7f930d
Simplify and fix it
Apr 21, 2021
d8eefa7
Merge branch 'main' of https://github.com/microsoft/InnerEye-DeepLear…
Apr 21, 2021
ae38760
just style
Apr 21, 2021
778ea9c
use 90% of the data for training
Apr 21, 2021
533a02f
add suport for chexpert
Apr 21, 2021
e39a5b1
update transforms
Apr 21, 2021
494a221
Change place of the sleep
Apr 21, 2021
f9782a3
fix it
Apr 21, 2021
e5ad233
Merge remote-tracking branch 'origin/melanibe/ssl' into melanibe/ssl
Apr 21, 2021
5c7b361
fix it
Apr 21, 2021
b58b7e4
fix it
Apr 21, 2021
bcfc981
Merge remote-tracking branch 'origin/melanibe/ssl' into melanibe/ssl
Apr 21, 2021
42baffc
Merge branch 'main' of https://github.com/microsoft/InnerEye-DeepLear…
Apr 23, 2021
652c430
Deal with num_nodes > 1
Apr 23, 2021
87cd556
Better encoder
Apr 23, 2021
3114cf8
Merge branch 'main' of https://github.com/microsoft/InnerEye-DeepLear…
Apr 26, 2021
8754701
Merge remote-tracking branch 'origin/melanibe/ssl' into melanibe/ssl
Apr 26, 2021
1d52a4f
Merge remote-tracking branch 'origin/melanibe/ssl' into melanibe/ssl
Apr 26, 2021
c16402d
extra runid should not lie in recovery folder
Apr 26, 2021
5e57fd1
Merge remote-tracking branch 'origin/melanibe/ssl' into melanibe/ssl
Apr 26, 2021
747a78c
fix
Apr 26, 2021
04f8677
Fix test with new behavior
Apr 26, 2021
3f84f87
Fix test with new behavior
Apr 26, 2021
151556a
Fix tests for ssl according to latest changes
Apr 26, 2021
4e04ee8
fastmri does not belong to amlignore
Apr 26, 2021
0170b8f
add doc for sslclassifier container
Apr 26, 2021
60d0b38
flake8
Apr 26, 2021
01a2be3
fix it?
Apr 26, 2021
748faf0
Try out another way to fix it
Apr 26, 2021
9933328
Increase time out
Apr 27, 2021
5e06e0a
Increase time out
Apr 27, 2021
7ff00c3
The number of gpu is not updated
Apr 27, 2021
b3836e5
dont change the build
Apr 27, 2021
405154f
Update InnerEye/Azure/azure_config.py
melanibe Apr 28, 2021
4066854
Refractor dataclasses
Apr 28, 2021
97e4ba0
Merge remote-tracking branch 'origin/melanibe/ssl' into melanibe/ssl
Apr 28, 2021
0fa1f6d
rename dataset consumption args
Apr 28, 2021
b36015e
DRY in azure_runner.py
Apr 28, 2021
0737501
Simplify signature
Apr 28, 2021
917d84b
First batch of PR comments
Apr 28, 2021
a8fa626
Second batch of PR comments
Apr 28, 2021
815ca20
Third batch of PR comments
Apr 28, 2021
c9610b7
Fourth batch of PR comments
Apr 28, 2021
94cf81f
Fifth batch of PR comments
Apr 28, 2021
f100745
Remove unecessary file
Apr 28, 2021
3ab91bd
Remove unecessary file
Apr 28, 2021
5711e85
Merge branch 'main' of https://github.com/microsoft/InnerEye-DeepLear…
Apr 28, 2021
cc8b36a
Fix mypy
Apr 28, 2021
ebfe7f9
Adding tests for InnerEyeVisionDataModule
Apr 28, 2021
4fca80e
Fix errors
Apr 28, 2021
1402d1a
Better API
Apr 29, 2021
0aa1414
Transform in new version where not pickable anymore
Apr 29, 2021
1a7f88f
Create shorthand for double inheritance
Apr 29, 2021
30702a0
fix mypy mess with transforms_utils.py
Apr 29, 2021
c6c8aa4
Add test for _get_transforms for CXR
Apr 29, 2021
b23e323
Add test for _get_transforms for Cifar
Apr 29, 2021
3c45de0
flake8
Apr 29, 2021
ea76d84
fix test
Apr 29, 2021
28f1602
fix test
Apr 29, 2021
ddcc9be
fix test
Apr 29, 2021
6c24a74
Adding a test for SSLEncoder init and output_dim computation
Apr 29, 2021
7ad1afc
Adding a test for 7x7 conv flag in encoder definition
Apr 29, 2021
c86d775
Update and new test
Apr 29, 2021
5638ac9
Update and new test
Apr 29, 2021
d43772f
Clean up
Apr 29, 2021
2b7014a
Merge branch 'main' of https://github.com/microsoft/InnerEye-DeepLear…
Apr 29, 2021
a100d6a
Mypy
Apr 29, 2021
3b8e6b0
fix the test
Apr 29, 2021
47b667f
Fix it
Apr 29, 2021
e84dbc8
Docstring was missing
Apr 29, 2021
797173f
Update CHANGELOG.md
Apr 29, 2021
9ba7e89
Update doc string
Apr 29, 2021
725d8cb
Outdated doc string
Apr 29, 2021
63ce329
Point user to download instructions
Apr 29, 2021
1b3ca79
Make the InnerEyeReturnIndex more versatile
Apr 29, 2021
0937fc5
Rename extra_run_recovery_id to pretraining_run_recovery_id
Apr 29, 2021
3b913b3
Make exclusion of lateral scan optional
Apr 29, 2021
baf5e31
readme remove update
Apr 30, 2021
81c84f7
skip tests on windows (too slow)
Apr 30, 2021
372cce2
First init
Apr 30, 2021
6a4f551
Add test for gaussian noise
Apr 30, 2021
d9752ea
Add new tests
Apr 30, 2021
b4b4735
Add test for add channel
Apr 30, 2021
7295e4f
This API makes more sense
Apr 30, 2021
08cfca1
This API makes more sense
Apr 30, 2021
9a4983d
Add yet another transform test
Apr 30, 2021
eac5fe2
skipif badly formatted
May 4, 2021
c62191a
mypy
May 4, 2021
0a9e318
Fix the problem
May 4, 2021
53240c9
Merge branch 'main' of https://github.com/microsoft/InnerEye-DeepLear…
May 4, 2021
ad96366
Fix build
May 4, 2021
dda04c8
Adress some PR comments
May 4, 2021
f4a35bb
Adress some PR comments
May 4, 2021
e5e6448
PR comments
May 5, 2021
c6233b5
Flake8
May 5, 2021
efc06f1
Flake8
May 5, 2021
95a33d5
Update
May 5, 2021
4c63d4e
Update
May 5, 2021
b110f71
PR configs
May 6, 2021
e348325
Simplify dataset mounting
May 6, 2021
7f361a4
Rename metrics
May 6, 2021
060d59b
Revert "Simplify dataset mounting"
May 6, 2021
8323a99
Fix configs
May 6, 2021
d8a863d
Try out with sorting
May 6, 2021
f2d82a2
Make the metrics file better and independent of order of files in fil…
May 6, 2021
3281741
Push changes to the classifier
May 7, 2021
2debf84
Metric name
May 7, 2021
e5ba80f
Update links
May 7, 2021
706e517
Flake8
May 7, 2021
0c492b1
Update keyword
May 7, 2021
2 changes: 2 additions & 0 deletions .amlignore
@@ -14,3 +14,5 @@ TestsOutsidePackage/azureml-models
 tensorboard_runs
 InnerEyeTestVariables.txt
 InnerEyePrivateSettings.yml
+cifar-10-batches-py
+cifar-100-python
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -63,6 +63,12 @@ with only minimum code changes required. See [the MD documentation](docs/bring_y
 - ([#450](https://github.com/microsoft/InnerEye-DeepLearning/pull/450)) Adds the metric "Accuracy at threshold 0.5" to the classification report (`classification_crossval_report.ipynb`).
 - ([#451](https://github.com/microsoft/InnerEye-DeepLearning/pull/451)) Write a file `model_outputs.csv` with columns
   `subject`, `prediction_target`, `label`, `model_output` and `cross_validation_split_index`. This file is not written out for sequence models.
+- ([#440](https://github.com/microsoft/InnerEye-DeepLearning/pull/440)) Added support for training of self-supervised
+  models (BYOL and SimCLR) based on the bring-your-own-model framework. Provides example configurations for training
+  of SSL models on the CIFAR10/100 datasets as well as on chest X-ray datasets such as the NIH Chest-Xray or RSNA Pneumonia
+  Detection Challenge datasets. See the
+  [SSL doc](https://github.com/microsoft/InnerEye-DeepLearning/blob/main/docs/self_supervised_models.md) for more
+  details.

### Changed

@@ -76,7 +82,7 @@ with only minimum code changes required. See [the MD documentation](docs/bring_y
 - ([#432](https://github.com/microsoft/InnerEye-DeepLearning/pull/432)) Upgraded to PyTorch-Lightning 1.2.7. Add
   end-to-end test for classification cross-validation. WARNING: upgrade PL version causes hanging of multi-node
   training.
-- ([#437])(https://github.com/microsoft/InnerEye-DeepLearning/pull/437)) Upgrade to PyTorch-Lightning 1.2.8.
+- ([#437](https://github.com/microsoft/InnerEye-DeepLearning/pull/437)) Upgrade to PyTorch-Lightning 1.2.8.
 - ([#439](https://github.com/microsoft/InnerEye-DeepLearning/pull/439)) Recovery checkpoints are now
   named `recovery_epoch=x.ckpt` instead of `recovery.ckpt` or `recovery-v0.ckpt`.
 - ([#451](https://github.com/microsoft/InnerEye-DeepLearning/pull/451)) Change the signature for function `generate_custom_report`
6 changes: 6 additions & 0 deletions InnerEye/Azure/azure_config.py
@@ -75,6 +75,12 @@ class AzureConfig(GenericConfig):
                                          "('--pytest_mark gpu' will run all tests marked with 'pytest.mark.gpu')")
     run_recovery_id: str = param.String(doc="A run recovery id string in the form 'experiment name:run id'"
                                             " to use for inference or recovering a model training run.")
+    pretraining_run_recovery_id: str = param.String(default=None,
+                                                    allow_None=True,
+                                                    doc="Extra run recovery id to download checkpoints from, "
+                                                        "for custom modules (e.g. for loading pretrained weights). "
+                                                        "Warning: this argument will be ignored for InnerEyeContainer "
+                                                        "models.")
     experiment_name: str = param.String(doc="If provided, use this string as the name of the AzureML experiment. "
                                             "If not provided, create the experiment off the git branch name.")
     build_number: int = param.Integer(0, doc="The numeric ID of the Azure pipeline that triggered this training run.")
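The `doc` argument of `pretraining_run_recovery_id` is assembled from several adjacent string literals. Python joins adjacent literals with no separator, so each fragment has to carry its own trailing space. A minimal stdlib-only sketch of that behaviour follows; the variable names are illustrative and not part of InnerEye.

```python
# Adjacent string literals are concatenated at compile time with no separator.
# Both variables below are hypothetical and exist only for this illustration.
doc_without_spaces = ("Extra run recovery id to download checkpoints from,"
                      "for custom modules.")
doc_with_spaces = ("Extra run recovery id to download checkpoints from, "
                   "for custom modules.")

print(doc_without_spaces)  # the two fragments run together at the comma
print(doc_with_spaces)     # the trailing space keeps the words apart
```

This only affects help text that users read on the command line, but it is easy to miss in review because each source line looks correct on its own.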
42 changes: 28 additions & 14 deletions InnerEye/Azure/azure_runner.py
@@ -24,8 +24,7 @@
 from InnerEye.Azure import azure_util
 from InnerEye.Azure.azure_config import AzureConfig, ParserResult, SourceConfig
 from InnerEye.Azure.azure_util import CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, RUN_RECOVERY_FROM_ID_KEY_NAME, \
-    RUN_RECOVERY_ID_KEY_NAME, \
-    is_offline_run_context, merge_conda_dependencies
+    RUN_RECOVERY_ID_KEY_NAME, is_offline_run_context, merge_conda_dependencies
 from InnerEye.Azure.secrets_handling import read_all_settings
 from InnerEye.Azure.tensorboard_monitor import AMLTensorBoardMonitorConfig, monitor
 from InnerEye.Common.generic_parsing import GenericConfig
@@ -43,13 +42,17 @@

 def submit_to_azureml(azure_config: AzureConfig,
                       source_config: SourceConfig,
-                      azure_dataset_id: str) -> Run:
+                      azure_dataset_id: str,
+                      extra_azure_dataset_ids: List[str]) -> Run:
     """
     The main entry point. It creates an AzureML workspace if needed, submits an experiment using the code
     as specified in source_config, and waits for completion if needed.
     :param azure_config: azure related configurations to setup valid workspace
     :param source_config: The information about which code should be submitted, and which arguments should be used.
     :param azure_dataset_id: The name of the dataset on blob storage to be used for this run.
+    :param extra_azure_dataset_ids: A list of additional dataset names on blob storage to be used for this run. This
+    will be ignored for InnerEyeContainer models, and may only be used by custom LightningContainer models (see the
+    bring-your-own-model and self-supervised training documentation).
     """
     azure_run: Optional[Run] = None

@@ -66,7 +69,7 @@ def interrupt_handler(signal: int, _: Any) -> None:
     for s in [signal.SIGINT, signal.SIGTERM]:
         signal.signal(s, interrupt_handler)
     # create train/test experiment
-    azure_run = create_and_submit_experiment(azure_config, source_config, azure_dataset_id)
+    azure_run = create_and_submit_experiment(azure_config, source_config, azure_dataset_id, extra_azure_dataset_ids)

     if azure_config.wait_for_completion:
         # We want the job output to be visible on the console, but the program should not exit if the
@@ -121,18 +124,22 @@ def create_experiment_name(azure_config: AzureConfig) -> str:
 def create_and_submit_experiment(
         azure_config: AzureConfig,
         source_config: SourceConfig,
-        azure_dataset_id: str) -> Run:
+        azure_dataset_id: str,
+        extra_azure_dataset_ids: List[str]) -> Run:
     """
     Creates an AzureML experiment in the workspace and submits it for execution.
     :param azure_config: azure related configurations to setup valid workspace
     :param source_config: The information about which code should be submitted, and which arguments should be used.
     :param azure_dataset_id: The name of the dataset in blob storage to be used for this run.
+    :param extra_azure_dataset_ids: A list of additional dataset names on blob storage to be used for this run. This
+    will be ignored for InnerEyeContainer models, and may only be used by custom LightningContainer models (see the
+    bring-your-own-model and self-supervised training documentation).
     :returns: Run object for the submitted AzureML run
     """
     workspace = azure_config.get_workspace()
     experiment_name = create_experiment_name(azure_config)
     exp = Experiment(workspace=workspace, name=azure_util.to_azure_friendly_string(experiment_name))
-    script_run_config = create_run_config(azure_config, source_config, azure_dataset_id)
+    script_run_config = create_run_config(azure_config, source_config, azure_dataset_id, extra_azure_dataset_ids)

     # submit a training/testing run associated with the experiment
     run: Run = exp.submit(script_run_config)
@@ -273,40 +280,47 @@ def get_or_create_python_environment(azure_config: AzureConfig,
     return env


-def get_dataset_consumption(azure_config: AzureConfig, azure_dataset_id: str) -> DatasetConsumptionConfig:
+def get_dataset_consumption(azure_config: AzureConfig,
+                            azure_dataset_id: str,
+                            dataset_index: int = 0) -> DatasetConsumptionConfig:
     """
     Creates a configuration for using an AzureML dataset inside of an AzureML run. This will make the AzureML
     dataset with given name available as a named input, using INPUT_DATA_KEY as the key.
     :param azure_config: azure related configurations to use for model scale-out behaviour
     :param azure_dataset_id: The name of the dataset in blob storage to be used for this run. This can be an empty
     string to not use any datasets.
+    :param dataset_index: suffix for the dataset name, dataset name will be set to INPUT_DATA_KEY_idx
     """
     azureml_dataset = get_or_create_dataset(azure_config, azure_dataset_id=azure_dataset_id)
     if not azureml_dataset:
         raise ValueError(f"AzureML dataset {azure_dataset_id} could not be found or created.")
-    named_input = azureml_dataset.as_named_input(INPUT_DATA_KEY)
+    named_input = azureml_dataset.as_named_input(f"{INPUT_DATA_KEY}_{dataset_index}")
     return named_input.as_mount() if azure_config.use_dataset_mount else named_input.as_download()


 def create_run_config(azure_config: AzureConfig,
                       source_config: SourceConfig,
                       azure_dataset_id: str = "",
+                      extra_azure_dataset_ids: List[str] = [],
                       environment_name: str = "") -> ScriptRunConfig:
     """
     Creates a configuration to run the InnerEye training script in AzureML.
     :param azure_config: azure related configurations to use for model scale-out behaviour
     :param source_config: configurations for model execution, such as name and execution mode
     :param azure_dataset_id: The name of the dataset in blob storage to be used for this run. This can be an empty
     string to not use any datasets.
+    :param extra_azure_dataset_ids: List of extra datasets in blob storage to be used for this run. This can be empty.
     :param environment_name: If specified, try to retrieve the existing Python environment with this name. If that
     is not found, create one from the Conda files provided in `source_config`. This parameter is meant to be used
     when running inference for an existing model.
     :return: The configured script run.
     """
-    if azure_dataset_id:
-        dataset_consumption = get_dataset_consumption(azure_config, azure_dataset_id)
-    else:
-        dataset_consumption = None
+    dataset_consumptions = {}
+    all_dataset_ids = [azure_dataset_id] + extra_azure_dataset_ids if azure_dataset_id else extra_azure_dataset_ids
+    for i, dataset_id in enumerate(all_dataset_ids):
+        dataset_consumption = get_dataset_consumption(azure_config, dataset_id, i)
+        dataset_consumptions.update({dataset_consumption.name: dataset_consumption})
+
     # AzureML seems to sometimes expect the entry script path in Linux format, hence convert to posix path
     entry_script_relative_path = source_config.entry_script.relative_to(source_config.root_folder).as_posix()
     logging.info(f"Entry script {entry_script_relative_path} ({source_config.entry_script} relative to "
@@ -329,8 +343,8 @@ def create_run_config(azure_config: AzureConfig,
     run_config.framework = "Python"
     run_config.communicator = "IntelMpi"
     run_config.node_count = distributed_job_config.node_count
-    if dataset_consumption:
-        run_config.data = {dataset_consumption.name: dataset_consumption}
+    if len(dataset_consumptions) > 0:
+        run_config.data = dataset_consumptions
     # Use blob storage for storing the source, rather than the FileShares section of the storage account.
     run_config.source_directory_data_store = workspace.datastores.get(WORKSPACE_DEFAULT_BLOB_STORE_NAME).name
     script_run_config = ScriptRunConfig(
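The changes to `create_run_config` key every dataset by `INPUT_DATA_KEY` plus its position in the combined list of dataset IDs. The following is a rough standalone sketch of that naming scheme, not the InnerEye implementation: `FakeConsumption` and `build_consumptions` are hypothetical stand-ins, the value of `INPUT_DATA_KEY` is assumed, and no AzureML calls are made. The sketch also uses a `None` default in place of a list default, which is a choice made here rather than something taken from the diff.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

# Assumed value, for illustration only; the real constant is defined elsewhere in InnerEye.
INPUT_DATA_KEY = "input_data"


@dataclass
class FakeConsumption:
    """Hypothetical stand-in for azureml's DatasetConsumptionConfig."""
    name: str
    dataset_id: str


def build_consumptions(azure_dataset_id: str,
                       extra_azure_dataset_ids: Optional[List[str]] = None) -> Dict[str, FakeConsumption]:
    """Key every dataset by INPUT_DATA_KEY plus its position in the combined list."""
    extras = list(extra_azure_dataset_ids or [])
    # The main dataset, if given, comes first so that it receives index 0.
    all_ids = [azure_dataset_id] + extras if azure_dataset_id else extras
    consumptions: Dict[str, FakeConsumption] = {}
    for i, dataset_id in enumerate(all_ids):
        item = FakeConsumption(name=f"{INPUT_DATA_KEY}_{i}", dataset_id=dataset_id)
        consumptions[item.name] = item
    return consumptions


result = build_consumptions("main_ds", ["extra_a", "extra_b"])
print(sorted(result))
```

One consequence of this scheme is that when the main dataset ID is empty, the first extra dataset takes index 0, so consumers cannot assume that index 0 always refers to the main dataset.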
30 changes: 30 additions & 0 deletions InnerEye/ML/SSL/datamodules_and_datasets/cifar_datasets.py
@@ -0,0 +1,30 @@
+#  ------------------------------------------------------------------------------------------
+#  Copyright (c) Microsoft Corporation. All rights reserved.
+#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+#  ------------------------------------------------------------------------------------------
+
+from torchvision.datasets import CIFAR10, CIFAR100
+
+from InnerEye.ML.SSL.datamodules_and_datasets.dataset_cls_utils import InnerEyeDataClassBaseWithReturnIndex
+
+
+class InnerEyeCIFAR10(InnerEyeDataClassBaseWithReturnIndex, CIFAR10):
+    """
+    Wrapper class around torchvision CIFAR10 class to optionally return the
+    index on top of the image and the label in __getitem__ as well as defining num_classes property.
+    """
+
+    @property
+    def num_classes(self) -> int:
+        return 10
+
+
+class InnerEyeCIFAR100(InnerEyeDataClassBaseWithReturnIndex, CIFAR100):
+    """
+    Wrapper class around torchvision CIFAR100 class to optionally return the
+    index on top of the image and the label in __getitem__ as well as defining num_classes property.
+    """
+
+    @property
+    def num_classes(self) -> int:
+        return 100
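Both wrapper classes combine `InnerEyeDataClassBaseWithReturnIndex` with a torchvision dataset through double inheritance. The body of that base class is not part of this diff, so the following is only a guess at how such a mixin could work, written without torchvision: `ToyDataset`, `ReturnIndexMixin`, `ToyDatasetWithIndex` and the `return_index` flag are all hypothetical names, and the position of the index in the returned tuple is an assumption.

```python
from typing import Any, List, Tuple


class ToyDataset:
    """Hypothetical stand-in for a torchvision dataset such as CIFAR10."""

    def __init__(self, images: List[str], labels: List[int]) -> None:
        self.images = images
        self.labels = labels

    def __getitem__(self, index: int) -> Tuple[Any, ...]:
        return self.images[index], self.labels[index]

    def __len__(self) -> int:
        return len(self.images)


class ReturnIndexMixin:
    """Hypothetical mixin that optionally prepends the sample index to each item."""

    def __init__(self, *args: Any, return_index: bool = False, **kwargs: Any) -> None:
        self.return_index = return_index
        # Forward the remaining arguments to the next class in the MRO.
        super().__init__(*args, **kwargs)

    def __getitem__(self, index: int) -> Tuple[Any, ...]:
        item = super().__getitem__(index)
        return (index, *item) if self.return_index else item


class ToyDatasetWithIndex(ReturnIndexMixin, ToyDataset):
    """The mixin is listed first so that its __getitem__ wraps the dataset's."""

    @property
    def num_classes(self) -> int:
        return 2


plain = ToyDatasetWithIndex(["img0", "img1"], [0, 1])
indexed = ToyDatasetWithIndex(["img0", "img1"], [0, 1], return_index=True)
print(plain[1], indexed[1])
```

Listing the mixin before the dataset class in the bases matters: Python resolves `__getitem__` left to right, so the mixin's version runs first and delegates to the dataset via `super()`.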