Add histopathology module and add hi-ml as submodule (#603)
* Add hi-ml submodule
* Add Histopathology module
* Add tests
* update azure scripts
* Address PR comments
* mypy
* Add changelog
* turn off logging outside of AML for now to get tests to pass
* flake8
* Bug in init file
* exclude histopathology datasets module from amlignore
* exclude histopathology datasets module from gitignore
* Minor change to requeue build
Showing 41 changed files with 2,401 additions and 7 deletions.
.gitmodules
@@ -1,3 +1,6 @@
[submodule "fastMRI"]
	path = fastMRI
	url = https://github.com/facebookresearch/fastMRI
[submodule "hi-ml"]
	path = hi-ml
	url = https://github.com/microsoft/hi-ml
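
For context, a submodule entry like the new hi-ml block above is normally produced by `git submodule add` rather than by hand-editing `.gitmodules`. A minimal sketch, driving git through Python's standard library (this helper is illustrative, not part of the commit; it assumes git is on PATH and the working directory is the repository root):

import subprocess

# Register hi-ml as a submodule; git writes the path/url pair into .gitmodules.
subprocess.run(["git", "submodule", "add",
                "https://github.com/microsoft/hi-ml", "hi-ml"], check=True)
# Fetch the submodule's contents (and any nested submodules).
subprocess.run(["git", "submodule", "update", "--init", "--recursive"], check=True)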
InnerEye/ML/Histopathology/datamodules/base_module.py
@@ -0,0 +1,149 @@
import pickle
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Optional, Sequence, Tuple, Union

from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from health_ml.utils.bag_utils import BagDataset, multibag_collate
from health_ml.utils.common_utils import _create_generator
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
from InnerEye.ML.Histopathology.models.transforms import LoadTilesBatchd


class CacheMode(Enum):
    NONE = 'none'
    MEMORY = 'memory'
    DISK = 'disk'


class TilesDataModule(LightningDataModule):
    """Base class to load the tiles of a dataset as train, val, test sets"""

    def __init__(self, root_path: Path, max_bag_size: int = 0, batch_size: int = 1,
                 seed: Optional[int] = None, transform: Optional[Callable] = None,
                 cache_mode: CacheMode = CacheMode.NONE, save_precache: bool = False,
                 cache_dir: Optional[Path] = None,
                 number_of_cross_validation_splits: int = 0,
                 cross_validation_split_index: int = 0) -> None:
        """
        :param root_path: Root directory of the source dataset.
        :param max_bag_size: Upper bound on the number of tiles in each loaded bag. If 0 (default),
        will return all samples in each bag. If > 0, bags larger than `max_bag_size` will yield
        random subsets of instances.
        :param batch_size: Number of slides to load per batch.
        :param seed: Pseudorandom number generator seed to use for shuffling instances and bags.
        Note that randomness in train/val/test splits is handled independently in `get_splits()`.
        (default: `None`)
        :param transform: A transform to apply to the source tiles dataset, or a composition of
        transforms using `monai.transforms.Compose`. By default (`None`), applies `LoadTilesBatchd`.
        :param cache_mode: The type of caching to perform, i.e. whether the results of all
        transforms up to the first randomised one should be computed only once and reused in
        subsequent iterations:
        - `MEMORY`: the entire transformed dataset is kept in memory for fastest access;
        - `DISK`: each transformed sample is saved to disk and loaded on-demand;
        - `NONE` (default): no caching is performed.
        :param save_precache: Whether to pre-cache the entire transformed dataset upfront and save
        it to disk. This is done once in `prepare_data()` only on the local rank-0 process, so
        multiple processes can afterwards access the same cache without contention in DDP settings.
        :param cache_dir: The directory onto which to cache data if caching is enabled.
        :param number_of_cross_validation_splits: Number of folds to perform.
        :param cross_validation_split_index: Index of the cross-validation split to be performed.
        """
        if save_precache and cache_mode is CacheMode.NONE:
            raise ValueError("Can only pre-cache if caching is enabled")
        if save_precache and cache_dir is None:
            raise ValueError("A cache directory is required for pre-caching")
        if cache_mode is CacheMode.DISK and cache_dir is None:
            raise ValueError("A cache directory is required for on-disk caching")
        super().__init__()

        self.root_path = root_path
        self.max_bag_size = max_bag_size
        self.transform = transform
        self.cache_mode = cache_mode
        self.save_precache = save_precache
        self.cache_dir = cache_dir
        self.batch_size = batch_size
        self.number_of_cross_validation_splits = number_of_cross_validation_splits
        self.cross_validation_split_index = cross_validation_split_index
        self.train_dataset, self.val_dataset, self.test_dataset = self.get_splits()
        self.class_weights = self.train_dataset.get_class_weights()
        self.seed = seed

    def get_splits(self) -> Tuple[TilesDataset, TilesDataset, TilesDataset]:
        """Create the training, validation, and test datasets"""
        raise NotImplementedError

    def prepare_data(self) -> None:
        if self.save_precache:
            self._load_dataset(self.train_dataset, stage='train', shuffle=True)
            self._load_dataset(self.val_dataset, stage='val', shuffle=True)
            self._load_dataset(self.test_dataset, stage='test', shuffle=True)

    def _dataset_pickle_path(self, stage: str) -> Optional[Path]:
        if self.cache_dir is None:
            return None
        return self.cache_dir / f"{stage}_dataset.pkl"

    def _load_dataset(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool) -> Dataset:
        dataset_pickle_path = self._dataset_pickle_path(stage)

        if dataset_pickle_path and dataset_pickle_path.exists():
            with dataset_pickle_path.open('rb') as f:
                return pickle.load(f)

        generator = _create_generator(self.seed)
        bag_dataset = BagDataset(tiles_dataset,  # type: ignore
                                 bag_ids=tiles_dataset.slide_ids,
                                 max_bag_size=self.max_bag_size,
                                 shuffle_samples=shuffle,
                                 generator=generator)
        transform = self.transform or LoadTilesBatchd(tiles_dataset.IMAGE_COLUMN)

        # Save and restore PRNG state for consistency across (pre-)caching options
        generator_state = generator.get_state()
        transformed_bag_dataset = self._get_transformed_dataset(bag_dataset, transform)  # type: ignore
        generator.set_state(generator_state)

        if dataset_pickle_path:
            dataset_pickle_path.parent.mkdir(parents=True, exist_ok=True)
            with dataset_pickle_path.open('wb') as f:
                pickle.dump(transformed_bag_dataset, f)

        return transformed_bag_dataset

    def _get_transformed_dataset(self, base_dataset: BagDataset,
                                 transform: Union[Sequence[Callable], Callable]) -> Dataset:
        if self.cache_mode is CacheMode.MEMORY:
            dataset = CacheDataset(base_dataset, transform, num_workers=1)  # type: ignore
        elif self.cache_mode is CacheMode.DISK:
            dataset = PersistentDataset(base_dataset, transform, cache_dir=self.cache_dir)  # type: ignore
            if self.save_precache:
                import tqdm  # TODO: Make optional

                for i in tqdm.trange(len(dataset), desc="Loading dataset"):
                    dataset[i]  # empty loop to pre-compute all transformed samples
        else:
            dataset = Dataset(base_dataset, transform)  # type: ignore
        return dataset

    def _get_dataloader(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool,
                        **dataloader_kwargs: Any) -> DataLoader:
        transformed_bag_dataset = self._load_dataset(tiles_dataset, stage=stage, shuffle=shuffle)
        bag_dataset: BagDataset = transformed_bag_dataset.data  # type: ignore
        generator = bag_dataset.bag_sampler.generator
        return DataLoader(transformed_bag_dataset, batch_size=self.batch_size,
                          collate_fn=multibag_collate, shuffle=shuffle, generator=generator,
                          pin_memory=False,  # disable pinning as loaded data may already be on GPU
                          **dataloader_kwargs)

    def train_dataloader(self) -> DataLoader:
        return self._get_dataloader(self.train_dataset, 'train', shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._get_dataloader(self.val_dataset, 'val', shuffle=True)

    def test_dataloader(self) -> DataLoader:
        return self._get_dataloader(self.test_dataset, 'test', shuffle=True)
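
As a usage sketch: a concrete subclass only needs to implement `get_splits()`, after which the caching options control how transformed bags are materialised. `MyTilesDataModule` and the paths below are hypothetical, not part of this commit:

from pathlib import Path

class MyTilesDataModule(TilesDataModule):
    def get_splits(self) -> Tuple[TilesDataset, TilesDataset, TilesDataset]:
        # Hypothetical: a real implementation would partition into disjoint splits.
        dataset = TilesDataset("/tmp/datasets/my_tiles")
        return dataset, dataset, dataset  # toy splits for illustration only

module = MyTilesDataModule(root_path=Path("/tmp/datasets/my_tiles"),
                           max_bag_size=100,
                           batch_size=2,
                           cache_mode=CacheMode.DISK,
                           save_precache=True,
                           cache_dir=Path("/tmp/cache/my_tiles"))
module.prepare_data()  # pre-computes and pickles the transformed train/val/test bags
train_loader = module.train_dataloader()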
InnerEye/ML/Histopathology/datamodules/panda_module.py
@@ -0,0 +1,23 @@
from typing import Tuple

from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset
from InnerEye.ML.utils.split_dataset import DatasetSplits


class PandaTilesDataModule(TilesDataModule):
    """PandaTilesDataModule is the child class of TilesDataModule specific to the PANDA dataset.
    Method get_splits() returns the train, val, test splits from the PANDA dataset.
    """

    def get_splits(self) -> Tuple[PandaTilesDataset, PandaTilesDataset, PandaTilesDataset]:
        dataset = PandaTilesDataset(self.root_path)
        splits = DatasetSplits.from_proportions(dataset.dataset_df.reset_index(),
                                                proportion_train=.8,
                                                proportion_test=.1,
                                                proportion_val=.1,
                                                subject_column=dataset.TILE_ID_COLUMN,
                                                group_column=dataset.SLIDE_ID_COLUMN)
        return (PandaTilesDataset(self.root_path, dataset_df=splits.train),
                PandaTilesDataset(self.root_path, dataset_df=splits.val),
                PandaTilesDataset(self.root_path, dataset_df=splits.test))
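
Note that `group_column=dataset.SLIDE_ID_COLUMN` makes the split grouped: all tiles from one slide land in the same partition, so no slide leaks between train, val, and test. A minimal usage sketch (the dataset path is illustrative):

from pathlib import Path

data_module = PandaTilesDataModule(root_path=Path("/tmp/datasets/PANDA_tiles"),
                                   max_bag_size=100,
                                   batch_size=2)
for batch in data_module.train_dataloader():
    ...  # each batch holds up to 2 bags (slides) of at most 100 tiles each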
33 changes: 33 additions & 0 deletions
InnerEye/ML/Histopathology/datamodules/tcga_crck_module.py
@@ -0,0 +1,33 @@
from typing import Tuple, Any

from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from InnerEye.ML.utils.split_dataset import DatasetSplits


class TcgaCrckTilesDataModule(TilesDataModule):
    """TcgaCrckTilesDataModule is the child class of TilesDataModule specific to the TCGA-CRCk dataset.
    Method get_splits() returns the train, val, test splits from the TCGA-CRCk dataset.
    Methods train_dataloader(), val_dataloader() and test_dataloader() override the base class methods for bag loading.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def get_splits(self) -> Tuple[TcgaCrck_TilesDataset, TcgaCrck_TilesDataset, TcgaCrck_TilesDataset]:
        trainval_dataset = TcgaCrck_TilesDataset(self.root_path, train=True)
        splits = DatasetSplits.from_proportions(trainval_dataset.dataset_df.reset_index(),
                                                proportion_train=0.8,
                                                proportion_test=0.0,
                                                proportion_val=0.2,
                                                subject_column=trainval_dataset.TILE_ID_COLUMN,
                                                group_column=trainval_dataset.SLIDE_ID_COLUMN,
                                                random_seed=5)

        if self.number_of_cross_validation_splits > 1:
            # Function get_k_fold_cross_validation_splits() will concatenate train and val splits
            splits = splits.get_k_fold_cross_validation_splits(
                self.number_of_cross_validation_splits)[self.cross_validation_split_index]

        return (TcgaCrck_TilesDataset(self.root_path, dataset_df=splits.train),
                TcgaCrck_TilesDataset(self.root_path, dataset_df=splits.val),
                TcgaCrck_TilesDataset(self.root_path, train=False))
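
When more than one cross-validation split is requested, the train/val pool is re-partitioned into k folds and the fold at `cross_validation_split_index` is selected, while the test set is always the dataset's own held-out split (`train=False`). A sketch of selecting fold 2 of a 5-fold run (the path is illustrative):

from pathlib import Path

data_module = TcgaCrckTilesDataModule(root_path=Path("/tmp/datasets/TCGA-CRCk"),
                                      number_of_cross_validation_splits=5,
                                      cross_validation_split_index=2)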
InnerEye/ML/Histopathology/datasets/base_dataset.py
@@ -0,0 +1,107 @@
from pathlib import Path
from typing import Any, Dict, Optional, Union

import numpy as np
import pandas as pd
import torch
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import Dataset


class TilesDataset(Dataset):
    """Base class for datasets of WSI tiles, iterating dictionaries of image paths and metadata.

    :param TILE_ID_COLUMN: CSV column name for tile ID.
    :param SLIDE_ID_COLUMN: CSV column name for slide ID.
    :param IMAGE_COLUMN: CSV column name for relative path to image file.
    :param PATH_COLUMN: CSV column name for relative path to image file. Replicated to propagate the path to the batch.
    :param LABEL_COLUMN: CSV column name for tile label.
    :param SPLIT_COLUMN: CSV column name for train/test split (optional).
    :param TILE_X_COLUMN: CSV column name for horizontal tile coordinate (optional).
    :param TILE_Y_COLUMN: CSV column name for vertical tile coordinate (optional).
    :param TRAIN_SPLIT_LABEL: Value used to indicate the training split in `SPLIT_COLUMN`.
    :param TEST_SPLIT_LABEL: Value used to indicate the test split in `SPLIT_COLUMN`.
    :param DEFAULT_CSV_FILENAME: Default name of the dataset CSV at the dataset root directory.
    :param N_CLASSES: Number of classes indexed in `LABEL_COLUMN`.
    """
    TILE_ID_COLUMN: str = 'tile_id'
    SLIDE_ID_COLUMN: str = 'slide_id'
    IMAGE_COLUMN: str = 'image'
    PATH_COLUMN: str = 'image_path'
    LABEL_COLUMN: str = 'label'
    SPLIT_COLUMN: Optional[str] = 'split'
    TILE_X_COLUMN: Optional[str] = 'tile_x'
    TILE_Y_COLUMN: Optional[str] = 'tile_y'

    TRAIN_SPLIT_LABEL: str = 'train'
    TEST_SPLIT_LABEL: str = 'test'

    DEFAULT_CSV_FILENAME: str = "dataset.csv"

    N_CLASSES: int = 1  # binary classification by default

    def __init__(self,
                 root: Union[str, Path],
                 dataset_csv: Optional[Union[str, Path]] = None,
                 dataset_df: Optional[pd.DataFrame] = None,
                 train: Optional[bool] = None) -> None:
        """
        :param root: Root directory of the dataset.
        :param dataset_csv: Full path to a dataset CSV file, containing at least
        `TILE_ID_COLUMN`, `SLIDE_ID_COLUMN`, and `IMAGE_COLUMN`. If omitted, the CSV will be read
        from `"{root}/{DEFAULT_CSV_FILENAME}"`.
        :param dataset_df: A potentially pre-processed dataframe in the same format as would be read
        from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
        :param train: If `True`, loads only the training split (resp. `False` for test split). By
        default (`None`), loads the entire dataset as-is.
        """
        if self.SPLIT_COLUMN is None and train is not None:
            raise ValueError("Train/test split was specified but dataset has no split column")

        self.root_dir = Path(root)

        if dataset_df is not None:
            self.dataset_csv = None
        else:
            self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
            dataset_df = pd.read_csv(self.dataset_csv)

        columns = [self.SLIDE_ID_COLUMN, self.IMAGE_COLUMN, self.LABEL_COLUMN,
                   self.SPLIT_COLUMN, self.TILE_X_COLUMN, self.TILE_Y_COLUMN]
        for column in columns:
            if column is not None and column not in dataset_df.columns:
                raise ValueError(f"Expected column '{column}' not found in the dataframe")

        dataset_df = dataset_df.set_index(self.TILE_ID_COLUMN)
        if train is None:
            self.dataset_df = dataset_df
        else:
            split = self.TRAIN_SPLIT_LABEL if train else self.TEST_SPLIT_LABEL
            self.dataset_df = dataset_df[dataset_df[self.SPLIT_COLUMN] == split]

    def __len__(self) -> int:
        return self.dataset_df.shape[0]

    def __getitem__(self, index: int) -> Dict[str, Any]:
        tile_id = self.dataset_df.index[index]
        sample = {
            self.TILE_ID_COLUMN: tile_id,
            **self.dataset_df.loc[tile_id].to_dict()
        }
        sample[self.IMAGE_COLUMN] = str(self.root_dir / sample.pop(self.IMAGE_COLUMN))
        # we're replicating this column because we want to propagate the path to the batch
        sample[self.PATH_COLUMN] = sample[self.IMAGE_COLUMN]
        return sample

    @property
    def slide_ids(self) -> pd.Series:
        return self.dataset_df[self.SLIDE_ID_COLUMN]

    def get_slide_labels(self) -> pd.Series:
        return self.dataset_df.groupby(self.SLIDE_ID_COLUMN)[self.LABEL_COLUMN].agg(pd.Series.mode)

    def get_class_weights(self) -> torch.Tensor:
        slide_labels = self.get_slide_labels()
        classes = np.unique(slide_labels)
        class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=slide_labels)
        return torch.as_tensor(class_weights)
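
To make the expected inputs concrete, a minimal sketch of a dataframe with the required columns and the sample dictionary `__getitem__` yields (all values and paths are illustrative):

import pandas as pd

df = pd.DataFrame({
    "tile_id": ["t0", "t1"],
    "slide_id": ["s0", "s0"],
    "image": ["s0/t0.png", "s0/t1.png"],  # paths relative to the dataset root
    "label": [0, 1],
    "split": ["train", "train"],
    "tile_x": [0, 224],
    "tile_y": [0, 0],
})
dataset = TilesDataset(root="/tmp/datasets/toy", dataset_df=df)
sample = dataset[0]
# sample["image"] == "/tmp/datasets/toy/s0/t0.png" (absolute path under root)
# sample["image_path"] replicates it so the path survives batch collation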
@@ -0,0 +1,8 @@
PANDA_TILES_DATASET_ID = "PANDA_tiles"
TCGA_CRCK_DATASET_ID = "TCGA-CRCk"
TCGA_PRAD_DATASET_ID = "TCGA-PRAD"

DEFAULT_DATASET_LOCATION = "/tmp/datasets/"
PANDA_TILES_DATASET_DIR = DEFAULT_DATASET_LOCATION + PANDA_TILES_DATASET_ID
TCGA_CRCK_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_CRCK_DATASET_ID
TCGA_PRAD_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_PRAD_DATASET_ID