Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Vsalva/deepmil panda (#619)
Browse files Browse the repository at this point in the history
* adding deepmilpanda container
  • Loading branch information
vale-salvatelli committed Dec 14, 2021
1 parent 212e65c commit 4aa84b9
Show file tree
Hide file tree
Showing 11 changed files with 390 additions and 127 deletions.
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,3 @@ repos:
rev: v1.5.7
hooks:
- id: autopep8

- repo: https://github.com/ambv/black
rev: 21.9b0
hooks:
- id: black
language_version: python3.7
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ jobs that run in AzureML.
- ([#554](https://github.com/microsoft/InnerEye-DeepLearning/pull/554)) Added a parameter `pretraining_dataset_id` to
`NIH_COVID_BYOL` to specify the name of the SSL training dataset.
- ([#560](https://github.com/microsoft/InnerEye-DeepLearning/pull/560)) Added pre-commit hooks.
- ([#619](https://github.com/microsoft/InnerEye-DeepLearning/pull/619)) Add DeepMIL PANDA
- ([#559](https://github.com/microsoft/InnerEye-DeepLearning/pull/559)) Adding the accompanying code for the ["Active label cleaning: Improving dataset quality under resource constraints"](https://arxiv.org/abs/2109.00574) paper. The code can be found in the [InnerEye-DataQuality](InnerEye-DataQuality/README.md) subfolder. It provides tools for training noise robust models, running label cleaning simulation and loading our label cleaning benchmark datasets.
- ([#589](https://github.com/microsoft/InnerEye-DeepLearning/pull/589)) Add `LightningContainer.update_azure_config()`
hook to enable overriding `AzureConfig` parameters from a container (e.g. `experiment_name`, `cluster`, `num_nodes`).
Expand Down
124 changes: 85 additions & 39 deletions InnerEye/Azure/azure_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,20 @@ def split_recovery_id(id: str) -> Tuple[str, str]:
"""
components = id.strip().split(EXPERIMENT_RUN_SEPARATOR)
if len(components) > 2:
raise ValueError("recovery_id must be in the format: 'experiment_name:run_id', but got: {}".format(id))
raise ValueError(
"recovery_id must be in the format: 'experiment_name:run_id', but got: {}".format(
id
)
)
elif len(components) == 2:
return components[0], components[1]
else:
recovery_id_regex = r"^(\w+)_\d+_[0-9a-f]+$|^(\w+)_\d+$"
match = re.match(recovery_id_regex, id)
if not match:
raise ValueError("The recovery ID was not in the expected format: {}".format(id))
raise ValueError(
"The recovery ID was not in the expected format: {}".format(id)
)
return (match.group(1) or match.group(2)), id


Expand All @@ -77,9 +83,15 @@ def fetch_run(workspace: Workspace, run_recovery_id: str) -> Run:
try:
experiment_to_recover = Experiment(workspace, experiment)
except Exception as ex:
raise Exception(f"Unable to retrieve run {run} in experiment {experiment}: {str(ex)}")
raise Exception(
f"Unable to retrieve run {run} in experiment {experiment}: {str(ex)}"
)
run_to_recover = fetch_run_for_experiment(experiment_to_recover, run)
logging.info("Fetched run #{} {} from experiment {}.".format(run, run_to_recover.number, experiment))
logging.info(
"Fetched run #{} {} from experiment {}.".format(
run, run_to_recover.number, experiment
)
)
return run_to_recover


Expand All @@ -94,9 +106,13 @@ def fetch_run_for_experiment(experiment_to_recover: Experiment, run_id: str) ->
except Exception:
available_runs = experiment_to_recover.get_runs()
available_ids = ", ".join([run.id for run in available_runs])
raise (Exception(
"Run {} not found for experiment: {}. Available runs are: {}".format(
run_id, experiment_to_recover.name, available_ids)))
raise (
Exception(
"Run {} not found for experiment: {}. Available runs are: {}".format(
run_id, experiment_to_recover.name, available_ids
)
)
)


def fetch_runs(experiment: Experiment, filters: List[str]) -> List[Run]:
Expand All @@ -116,8 +132,11 @@ def fetch_runs(experiment: Experiment, filters: List[str]) -> List[Run]:
return exp_runs


def fetch_child_runs(run: Run, status: Optional[str] = None,
expected_number_cross_validation_splits: int = 0) -> List[Run]:
def fetch_child_runs(
run: Run,
status: Optional[str] = None,
expected_number_cross_validation_splits: int = 0,
) -> List[Run]:
"""
Fetch child runs for the provided runs that have the provided AML status (or fetch all by default)
and have a run_recovery_id tag value set (this is to ignore superfluous AML infrastructure platform runs).
Expand All @@ -138,18 +157,25 @@ def fetch_child_runs(run: Run, status: Optional[str] = None,
if 0 < expected_number_cross_validation_splits != len(children_runs):
logging.warning(
f"The expected number of child runs was {expected_number_cross_validation_splits}."
f"Fetched only: {len(children_runs)} runs. Now trying to fetch them manually.")
run_ids_to_evaluate = [f"{create_run_recovery_id(run)}_{i}"
for i in range(expected_number_cross_validation_splits)]
children_runs = [fetch_run(run.experiment.workspace, id) for id in run_ids_to_evaluate]
f"Fetched only: {len(children_runs)} runs. Now trying to fetch them manually."
)
run_ids_to_evaluate = [
f"{create_run_recovery_id(run)}_{i}"
for i in range(expected_number_cross_validation_splits)
]
children_runs = [
fetch_run(run.experiment.workspace, id) for id in run_ids_to_evaluate
]
if status is not None:
children_runs = [child_run for child_run in children_runs if child_run.get_status() == status]
children_runs = [
child_run for child_run in children_runs if child_run.get_status() == status
]
return children_runs


def is_ensemble_run(run: Run) -> bool:
"""Checks if the run was an ensemble of multiple models"""
return run.get_tags().get(IS_ENSEMBLE_KEY_NAME) == 'True'
return run.get_tags().get(IS_ENSEMBLE_KEY_NAME) == "True"


def to_azure_friendly_string(x: Optional[str]) -> Optional[str]:
Expand All @@ -160,7 +186,7 @@ def to_azure_friendly_string(x: Optional[str]) -> Optional[str]:
if x is None:
return x
else:
return re.sub('_+', '_', re.sub(r'\W+', '_', x))
return re.sub("_+", "_", re.sub(r"\W+", "_", x))


def to_azure_friendly_container_path(path: Path) -> str:
Expand All @@ -178,7 +204,7 @@ def is_offline_run_context(run_context: Run) -> bool:
:param run_context: Context of the run to check
:return:
"""
return not hasattr(run_context, 'experiment')
return not hasattr(run_context, "experiment")


def get_run_context_or_default(run: Optional[Run] = None) -> Run:
Expand All @@ -199,7 +225,12 @@ def get_cross_validation_split_index(run: Run) -> int:
if is_offline_run_context(run):
return DEFAULT_CROSS_VALIDATION_SPLIT_INDEX
else:
return int(run.get_tags().get(CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, DEFAULT_CROSS_VALIDATION_SPLIT_INDEX))
return int(
run.get_tags().get(
CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY,
DEFAULT_CROSS_VALIDATION_SPLIT_INDEX,
)
)


def is_cross_validation_child_run(run: Run) -> bool:
Expand Down Expand Up @@ -256,9 +287,7 @@ def is_parent_run(run: Run) -> bool:
return PARENT_RUN_CONTEXT and run.id == PARENT_RUN_CONTEXT.id


def download_run_output_file(blob_path: Path,
destination: Path,
run: Run) -> Path:
def download_run_output_file(blob_path: Path, destination: Path, run: Run) -> Path:
"""
Downloads a single file from the run's default output directory: DEFAULT_AML_UPLOAD_DIR ("outputs").
For example, if blob_path = "foo/bar.csv", then the run result file "outputs/foo/bar.csv" will be downloaded
Expand All @@ -270,17 +299,21 @@ def download_run_output_file(blob_path: Path,
"""
blobs_prefix = str((fixed_paths.DEFAULT_AML_UPLOAD_DIR / blob_path).as_posix())
destination = destination / blob_path.name
logging.info(f"Downloading single file from run {run.id}: {blobs_prefix} -> {str(destination)}")
logging.info(
f"Downloading single file from run {run.id}: {blobs_prefix} -> {str(destination)}"
)
try:
run.download_file(blobs_prefix, str(destination), _validate_checksum=True)
except Exception as ex:
raise ValueError(f"Unable to download file '{blobs_prefix}' from run {run.id}") from ex
raise ValueError(
f"Unable to download file '{blobs_prefix}' from run {run.id}"
) from ex
return destination


def download_run_outputs_by_prefix(blobs_prefix: Path,
destination: Path,
run: Run) -> None:
def download_run_outputs_by_prefix(
blobs_prefix: Path, destination: Path, run: Run
) -> None:
"""
Download all the blobs from the run's default output directory: DEFAULT_AML_UPLOAD_DIR ("outputs") that
have a given prefix (folder structure). When saving, the prefix string will be stripped off. For example,
Expand All @@ -291,7 +324,9 @@ def download_run_outputs_by_prefix(blobs_prefix: Path,
:param destination: Local path to save the downloaded blobs to.
"""
prefix_str = str((fixed_paths.DEFAULT_AML_UPLOAD_DIR / blobs_prefix).as_posix())
logging.info(f"Downloading multiple files from run {run.id}: {prefix_str} -> {str(destination)}")
logging.info(
f"Downloading multiple files from run {run.id}: {prefix_str} -> {str(destination)}"
)
# There is a download_files function, but that can time out when downloading several large checkpoint files
# (120sec timeout for all files).
for file in run.get_file_names():
Expand All @@ -300,10 +335,14 @@ def download_run_outputs_by_prefix(blobs_prefix: Path,
if target_path.startswith("/"):
target_path = target_path[1:]
logging.info(f"Downloading {file}")
run.download_file(file, str(destination / target_path), _validate_checksum=True)
run.download_file(
file, str(destination / target_path), _validate_checksum=True
)
else:
logging.warning(f"Skipping file {file}, because the desired prefix {prefix_str} is not aligned with "
f"the folder structure")
logging.warning(
f"Skipping file {file}, because the desired prefix {prefix_str} is not aligned with "
f"the folder structure"
)


def is_running_on_azure_agent() -> bool:
Expand All @@ -314,10 +353,9 @@ def is_running_on_azure_agent() -> bool:
return bool(os.environ.get("AGENT_OS", None))


def get_comparison_baseline_paths(outputs_folder: Path,
blob_path: Path, run: Run,
dataset_csv_file_name: str) -> \
Tuple[Optional[Path], Optional[Path]]:
def get_comparison_baseline_paths(
outputs_folder: Path, blob_path: Path, run: Run, dataset_csv_file_name: str
) -> Tuple[Optional[Path], Optional[Path]]:
run_rec_id = run.id
# We usually find dataset.csv in the same directory as metrics.csv, but we sometimes
# have to look higher up.
Expand All @@ -328,21 +366,29 @@ def get_comparison_baseline_paths(outputs_folder: Path,
for blob_path_parent in step_up_directories(blob_path):
try:
comparison_dataset_path = download_run_output_file(
blob_path_parent / dataset_csv_file_name, destination_folder, run)
blob_path_parent / dataset_csv_file_name, destination_folder, run
)
break
except (ValueError, UserErrorException):
logging.warning(f"cannot find {dataset_csv_file_name} at {blob_path_parent} in {run_rec_id}")
logging.warning(
f"cannot find {dataset_csv_file_name} at {blob_path_parent} in {run_rec_id}"
)
except NotADirectoryError:
logging.warning(f"{blob_path_parent} is not a directory")
break
if comparison_dataset_path is None:
logging.warning(f"cannot find {dataset_csv_file_name} at or above {blob_path} in {run_rec_id}")
logging.warning(
f"cannot find {dataset_csv_file_name} at or above {blob_path} in {run_rec_id}"
)
# Look for epoch_NNN/Test/metrics.csv
try:
comparison_metrics_path = download_run_output_file(
blob_path / SUBJECT_METRICS_FILE_NAME, destination_folder, run)
blob_path / SUBJECT_METRICS_FILE_NAME, destination_folder, run
)
except (ValueError, UserErrorException):
logging.warning(f"cannot find {SUBJECT_METRICS_FILE_NAME} at {blob_path} in {run_rec_id}")
logging.warning(
f"cannot find {SUBJECT_METRICS_FILE_NAME} at {blob_path} in {run_rec_id}"
)
return (comparison_dataset_path, comparison_metrics_path)


Expand Down
5 changes: 4 additions & 1 deletion InnerEye/ML/Histopathology/datamodules/panda_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from typing import Tuple
from typing import Tuple, Any

from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset
Expand All @@ -15,6 +15,9 @@ class PandaTilesDataModule(TilesDataModule):
Method get_splits() returns the train, val, test splits from the PANDA dataset
"""

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

def get_splits(self) -> Tuple[PandaTilesDataset, PandaTilesDataset, PandaTilesDataset]:
dataset = PandaTilesDataset(self.root_path)
splits = DatasetSplits.from_proportions(dataset.dataset_df.reset_index(),
Expand Down
6 changes: 3 additions & 3 deletions InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ class PandaTilesDataset(TilesDataset):
SPLIT_COLUMN = None # PANDA does not have an official train/test split
N_CLASSES = 6

_RELATIVE_ROOT_FOLDER = "PANDA_tiles_20210926-135446/panda_tiles_level1_224"
_RELATIVE_ROOT_FOLDER = Path("PANDA_tiles_20210926-135446/panda_tiles_level1_224")

def __init__(self,
root: Union[str, Path],
root: Path,
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None) -> None:
super().__init__(root=Path(root) / self._RELATIVE_ROOT_FOLDER,
Expand All @@ -48,7 +48,7 @@ class PandaTilesDatasetReturnImageLabel(VisionDataset):
class label.
"""
def __init__(self,
root: Union[str, Path],
root: Path,
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None,
transform: Optional[Callable] = None,
Expand Down
13 changes: 11 additions & 2 deletions InnerEye/ML/Histopathology/models/deepmil.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import logging
from pathlib import Path
import pandas as pd
import numpy as np
Expand Down Expand Up @@ -166,7 +167,7 @@ def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> Dict[ResultsK
bag_labels_list = []
bag_logits_list = []
bag_attn_list = []
for bag_idx in range(len(batch[TilesDataset.LABEL_COLUMN])):
for bag_idx in range(len(batch[self.label_column])):
images = batch[TilesDataset.IMAGE_COLUMN][bag_idx]
labels = batch[self.label_column][bag_idx]
bag_labels_list.append(self.get_bag_label(labels))
Expand All @@ -177,7 +178,7 @@ def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> Dict[ResultsK
bag_labels = torch.stack(bag_labels_list).view(-1)

if self.n_classes > 1:
loss = self.loss_fn(bag_logits, bag_labels)
loss = self.loss_fn(bag_logits, bag_labels.long())
else:
loss = self.loss_fn(bag_logits.squeeze(1), bag_labels.float())

Expand All @@ -201,6 +202,14 @@ def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> Dict[ResultsK
ResultsKey.PROB: probs, ResultsKey.PRED_LABEL: preds,
ResultsKey.TRUE_LABEL: bag_labels, ResultsKey.BAG_ATTN: bag_attn_list,
ResultsKey.IMAGE: batch[TilesDataset.IMAGE_COLUMN]})

if (TilesDataset.TILE_X_COLUMN in batch.keys()) and (TilesDataset.TILE_Y_COLUMN in batch.keys()):
results.update({ResultsKey.TILE_X: batch[TilesDataset.TILE_X_COLUMN],
ResultsKey.TILE_Y: batch[TilesDataset.TILE_Y_COLUMN]}
)
else:
logging.warning("Coordinates not found in batch. If this is not expected check your input tiles dataset.")

return results

def training_step(self, batch: Dict, batch_idx: int) -> Tensor: # type: ignore
Expand Down
3 changes: 3 additions & 0 deletions InnerEye/ML/Histopathology/utils/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ class ResultsKey(str, Enum):
PRED_LABEL = 'pred_label'
TRUE_LABEL = 'true_label'
BAG_ATTN = 'bag_attn'
TILE_X = "x"
TILE_Y = "y"

Loading

0 comments on commit 4aa84b9

Please sign in to comment.