This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Enable tiling non-PANDA WSI datasets #621

Merged
merged 16 commits into from Dec 16, 2021
Changes from 1 commit
Back-up PANDA tiling scripts
dccastro committed Dec 14, 2021
commit b8e7f5205103980f6cc4fd3052657738c48f2ffb
230 changes: 230 additions & 0 deletions InnerEye/ML/Histopathology/preprocessing/create_panda_tiles_dataset.py
@@ -0,0 +1,230 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

"""This script is specific to PANDA and is kept only for retrocompatibility.
`create_tiles_dataset.py` is the new supported way to process slide datasets.
"""
import functools
import os
import logging
import shutil
import traceback
import warnings
from pathlib import Path
from typing import Sequence, Tuple, Union

import numpy as np
import PIL
from monai.data import Dataset
from monai.data.image_reader import WSIReader
from tqdm import tqdm

from InnerEye.ML.Histopathology.preprocessing import tiling
from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId


CSV_COLUMNS = ['slide_id', 'tile_id', 'image', 'mask', 'tile_x', 'tile_y', 'occupancy',
'data_provider', 'slide_isup_grade', 'slide_gleason_score']
TMP_SUFFIX = "_tmp"

logging.basicConfig(format='%(asctime)s %(message)s', filemode='w')
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)


def select_tile(mask_tile: np.ndarray, occupancy_threshold: float) \
-> Union[Tuple[bool, float], Tuple[np.ndarray, np.ndarray]]:
if occupancy_threshold < 0. or occupancy_threshold > 1.:
raise ValueError("Tile occupancy threshold must be between 0 and 1")
foreground_mask = mask_tile > 0
occupancy = foreground_mask.mean(axis=(-2, -1))
return (occupancy > occupancy_threshold).squeeze(), occupancy.squeeze()
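
# Illustrative example (not part of the original script): a single 2x2 mask tile with three
# foreground pixels has occupancy 0.75, so it is kept at any threshold below that value:
#   select_tile(np.array([[[1, 1], [1, 0]]]), occupancy_threshold=0.05)  # -> (True, 0.75)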


def get_tile_descriptor(tile_location: Sequence[int]) -> str:
return f"{tile_location[0]:05d}x_{tile_location[1]:05d}y"


def get_tile_id(slide_id: str, tile_location: Sequence[int]) -> str:
return f"{slide_id}.{get_tile_descriptor(tile_location)}"


def save_image(array_chw: np.ndarray, path: Path) -> PIL.Image.Image:
path.parent.mkdir(parents=True, exist_ok=True)
array_hwc = np.moveaxis(array_chw, 0, -1).astype(np.uint8).squeeze()
pil_image = PIL.Image.fromarray(array_hwc)
pil_image.convert('RGB').save(path)
return pil_image
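
# Note: save_image expects a channels-first array; e.g. a uint8 array of shape (3, 224, 224)
# is converted to HWC layout and written to disk as an RGB PNG at the given path.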


def generate_tiles(sample: dict, tile_size: int, occupancy_threshold: float) \
-> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
image_tiles, tile_locations = tiling.tile_array_2d(sample['image'], tile_size=tile_size,
constant_values=255)
mask_tiles, _ = tiling.tile_array_2d(sample['mask'], tile_size=tile_size, constant_values=0)

selected: np.ndarray
occupancies: np.ndarray
selected, occupancies = select_tile(mask_tiles, occupancy_threshold)
n_discarded = (~selected).sum()
logging.info(f"Percentage tiles discarded: {round(selected.sum() / n_discarded * 100, 2)}")

image_tiles = image_tiles[selected]
mask_tiles = mask_tiles[selected]
tile_locations = tile_locations[selected]
occupancies = occupancies[selected]

abs_tile_locations = (sample['scale'] * tile_locations + sample['location']).astype(int)

return image_tiles, mask_tiles, abs_tile_locations, occupancies, n_discarded
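
# Illustrative note: tile_locations are in the coordinates of the loaded region and are mapped to
# absolute slide coordinates as abs_location = scale * tile_location + location, with 'scale' and
# 'location' taken from the loaded sample. For example, with scale = 4 and location = (10240, 20480),
# a tile at (224, 0) maps to (11136, 20480).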


# TODO refactor this to separate metadata identification from saving. We might want the metadata
# even if the saving fails
def save_tile(sample: dict, image_tile: np.ndarray, mask_tile: np.ndarray,
tile_location: Sequence[int], output_dir: Path) -> dict:
slide_id = sample['image_id']
descriptor = get_tile_descriptor(tile_location)
image_tile_filename = f"train_images/{descriptor}.png"
mask_tile_filename = f"train_label_masks/{descriptor}_mask.png"

save_image(image_tile, output_dir / image_tile_filename)
save_image(mask_tile, output_dir / mask_tile_filename)

tile_metadata = {
'slide_id': slide_id,
'tile_id': get_tile_id(slide_id, tile_location),
'image': image_tile_filename,
'mask': mask_tile_filename,
'tile_x': tile_location[0],
'tile_y': tile_location[1],
'data_provider': sample['data_provider'],
'slide_isup_grade': sample['isup_grade'],
'slide_gleason_score': sample['gleason_score'],
}

return tile_metadata


def process_slide(sample: dict, level: int, margin: int, tile_size: int, occupancy_threshold: float,
output_dir: Path, tile_progress: bool = False) -> None:
slide_id = sample['image_id']
slide_dir: Path = output_dir / (slide_id + "/")
logging.info(f">>> Slide dir {slide_dir}")
if slide_dir.exists(): # already processed slide - skip
logging.info(f">>> Skipping {slide_dir} - already processed")
return
else:
try:
slide_dir.mkdir(parents=True)

dataset_csv_path = slide_dir / "dataset.csv"
dataset_csv_file = dataset_csv_path.open('w')
dataset_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header

tiles_failure = 0
failed_tiles_csv_path = slide_dir / "failed_tiles.csv"
failed_tiles_file = failed_tiles_csv_path.open('w')
failed_tiles_file.write('tile_id' + '\n')

logging.info(f"Loading slide {slide_id} ...")
loader = LoadPandaROId(WSIReader(), level=level, margin=margin)
sample = loader(sample) # load 'image' and 'mask' from disk

logging.info(f"Tiling slide {slide_id} ...")
image_tiles, mask_tiles, tile_locations, occupancies, _ = \
generate_tiles(sample, tile_size, occupancy_threshold)
n_tiles = image_tiles.shape[0]

for i in tqdm(range(n_tiles), f"Tiles ({slide_id[:6]}…)", unit="img", disable=not tile_progress):
try:
tile_metadata = save_tile(sample, image_tiles[i], mask_tiles[i], tile_locations[i],
slide_dir)
tile_metadata['occupancy'] = occupancies[i]
tile_metadata['image'] = os.path.join(slide_dir.name, tile_metadata['image'])
tile_metadata['mask'] = os.path.join(slide_dir.name, tile_metadata['mask'])
dataset_row = ','.join(str(tile_metadata[column]) for column in CSV_COLUMNS)
dataset_csv_file.write(dataset_row + '\n')
except Exception as e:
tiles_failure += 1
descriptor = get_tile_descriptor(tile_locations[i]) + '\n'
failed_tiles_file.write(descriptor)
traceback.print_exc()
warnings.warn(f"An error occurred while saving tile "
f"{get_tile_id(slide_id, tile_locations[i])}: {e}")

dataset_csv_file.close()
failed_tiles_file.close()
if tiles_failure > 0:
                # TODO what do we want to do with slides that have some failed tiles?
logging.warning(f"{slide_id} is incomplete. {tiles_failure} tiles failed.")
except Exception as e:
traceback.print_exc()
warnings.warn(f"An error occurred while processing slide {slide_id}: {e}")

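# For each processed slide, the layout under output_dir looks like this (illustrative IDs):
#   <slide_id>/dataset.csv                            # one row per saved tile
#   <slide_id>/failed_tiles.csv                       # descriptors of tiles that failed to save
#   <slide_id>/train_images/01234x_00567y.png         # RGB tile images
#   <slide_id>/train_label_masks/01234x_00567y_mask.png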

def merge_dataset_csv_files(dataset_dir: Path) -> Path:
full_csv = dataset_dir / "dataset.csv"
    # TODO change how we retrieve these filenames; the operation is slow, probably because the
    # dataset is mounted, and it seems to find many more files than expected
# print("List of files")
# print([str(file) + '\n' for file in dataset_dir.glob("*/dataset.csv")])
with full_csv.open('w') as full_csv_file:
# full_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header
first_file = True
for slide_csv in tqdm(dataset_dir.glob("*/dataset.csv"), desc="Merging dataset.csv", unit='file'):
logging.info(f"Merging slide {slide_csv}")
content = slide_csv.read_text()
if not first_file:
content = content[content.index('\n') + 1:] # discard header row for all but the first file
full_csv_file.write(content)
first_file = False
return full_csv
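
# Merging behaviour: each per-slide dataset.csv already starts with the CSV header written in
# process_slide, so the merged file keeps the header of the first file only and strips it from
# every subsequent file before concatenating the rows.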


def main(panda_dir: Union[str, Path], root_output_dir: Union[str, Path], level: int, tile_size: int,
margin: int, occupancy_threshold: float, parallel: bool = False, overwrite: bool = False) -> None:

    # Ignoring some types here because mypy gets confused by the MONAI Dataset class
    # To select a subsample, use the keyword argument n_slides
dataset = Dataset(PandaDataset(panda_dir)) # type: ignore

output_dir = Path(root_output_dir) / f"panda_tiles_level{level}_{tile_size}"
logging.info(f"Creating dataset of level-{level} {tile_size}x{tile_size} PANDA tiles at: {output_dir}")

if overwrite and output_dir.exists():
shutil.rmtree(output_dir)
output_dir.mkdir(parents=True, exist_ok=not overwrite)

func = functools.partial(process_slide, level=level, margin=margin, tile_size=tile_size,
occupancy_threshold=occupancy_threshold, output_dir=output_dir,
tile_progress=not parallel)

if parallel:
import multiprocessing

pool = multiprocessing.Pool()
map_func = pool.imap_unordered # type: ignore
else:
map_func = map # type: ignore

list(tqdm(map_func(func, dataset), desc="Slides", unit="img", total=len(dataset))) # type: ignore

if parallel:
pool.close()

logging.info("Merging slide files in a single file")
merge_dataset_csv_files(output_dir)


if __name__ == '__main__':
main(panda_dir="/tmp/datasets/PANDA",
root_output_dir="/datadrive",
level=1,
tile_size=224,
margin=64,
occupancy_threshold=0.05,
parallel=True,
overwrite=False)
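
For local experimentation, a command-line wrapper along the following lines (a sketch, not part of this commit; the argument names are hypothetical) could expose the values hard-coded in the __main__ block above:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Tile the PANDA slides dataset (sketch)")
    parser.add_argument('--panda-dir', default="/tmp/datasets/PANDA")
    parser.add_argument('--output-dir', default="/datadrive")
    parser.add_argument('--level', type=int, default=1)
    parser.add_argument('--tile-size', type=int, default=224)
    parser.add_argument('--margin', type=int, default=64)
    parser.add_argument('--occupancy-threshold', type=float, default=0.05)
    parser.add_argument('--parallel', action='store_true')
    parser.add_argument('--overwrite', action='store_true')
    args = parser.parse_args()
    main(panda_dir=args.panda_dir, root_output_dir=args.output_dir, level=args.level,
         tile_size=args.tile_size, margin=args.margin,
         occupancy_threshold=args.occupancy_threshold,
         parallel=args.parallel, overwrite=args.overwrite)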
@@ -0,0 +1,64 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

"""
This script is an example of how to use the submit_to_azure_if_needed function from the hi-ml package to run the
main pre-processing function that creates tiles from slides in the PANDA dataset. The advantage of using this script
is the ability to submit to an AzureML compute cluster and to have the output files saved directly as a registered dataset.

This script is specific to PANDA and is kept only for backwards compatibility.
`azure_tiles_creation.py` is the new supported way to process slide datasets.

To run, execute the following from inside the pre-processing folder:
python azure_tiles_creation.py --azureml

A JSON configuration file containing the credentials for the Azure workspace and an environment.yml file are expected
as inputs.

This has been tested with hi-ml v0.1.4.
"""

from pathlib import Path
import sys
import time

current_file = Path(__file__)
radiomics_root = current_file.absolute().parent.parent.parent.parent.parent
sys.path.append(str(radiomics_root))
from health_azure.himl import submit_to_azure_if_needed, DatasetConfig # noqa
from InnerEye.ML.Histopathology.preprocessing.create_panda_tiles_dataset import main # noqa

# Pre-built environment file that contains all the requirements (RadiomicsNN + histo)
# Assuming ENV_NAME is a complete environment, `conda env export -n ENV_NAME -f ENV_NAME.yml` will create the desired file
ENVIRONMENT_FILE = radiomics_root.joinpath(Path("envs/innereyeprivatetiles.yml"))
DATASET_NAME = "PANDA_tiles"
timestr = time.strftime("%Y%m%d-%H%M%S")
folder_name = DATASET_NAME + '_' + timestr

if __name__ == '__main__':
print(f"Running {str(current_file)}")
input_dataset = DatasetConfig(name="PANDA", datastore="innereyedatasets", local_folder=Path("/tmp/datasets/PANDA"), use_mounting=True)
output_dataset = DatasetConfig(name=DATASET_NAME, datastore="innereyedatasets", local_folder=Path("/datadrive/"), use_mounting=True)
run_info = submit_to_azure_if_needed(entry_script=current_file,
snapshot_root_directory=radiomics_root,
workspace_config_file=Path("config.json"),
compute_cluster_name='training-pr-nc12', # training-nd24
default_datastore="innereyedatasets",
conda_environment_file=Path(ENVIRONMENT_FILE),
input_datasets=[input_dataset],
output_datasets=[output_dataset],
)
input_folder = run_info.input_datasets[0]
output_folder = Path(run_info.output_datasets[0], folder_name)
    print(f'This will be the final output folder {str(output_folder)}')

main(panda_dir=str(input_folder),
root_output_dir=str(output_folder),
level=1,
tile_size=224,
margin=64,
occupancy_threshold=0.05,
parallel=True,
overwrite=False)
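
# Resulting layout (illustrative timestamp): the tiles end up inside the registered output dataset
# under a timestamped folder, e.g.
#   PANDA_tiles_20211214-153000/panda_tiles_level1_224/<slide_id>/...
# i.e. the folder_name created above, containing the directory that main() names after the chosen
# level and tile size.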