Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow SlideData to use existing h5path files #337

Open
wants to merge 22 commits into
base: load-data-in-workers
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Enable running pipelines on existing tiles
  • Loading branch information
tddough98 committed Oct 20, 2022
commit 8b2b20265fde083df75087530cad203ecfde0a74
152 changes: 67 additions & 85 deletions pathml/core/slide_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class SlideData:
slide_type (pathml.core.SlideType, optional): slide type specification. Must be a
:class:`~pathml.core.SlideType` object. Alternatively, slide type can be specified by using the
parameters ``stain``, ``tma``, ``rgb``, ``volumetric``, and ``time_series``.
stain (str, optional): Flag indicating type of slide stain. Must be one of [‘HE’, ‘IHC’, ‘Fluor].
stain (str, optional): Flag indicating type of slide stain. Must be one of ["HE", "IHC", "Fluor"].
Defaults to ``None``. Ignored if ``slide_type`` is specified.
platform (str, optional): Flag indicating the imaging platform (e.g. CODEX, Vectra, etc.).
Defaults to ``None``. Ignored if ``slide_type`` is specified.
Expand All @@ -74,6 +74,14 @@ class SlideData:
time_series (bool, optional): Flag indicating whether the image is a time series.
Defaults to ``None``. Ignored if ``slide_type`` is specified.
counts (anndata.AnnData): object containing counts matrix associated with image quantification
tile_size (int, optional): Size of each tile. Defaults to 256px.
tile_stride (int, optional): Stride between tiles. If ``None``, uses ``tile_stride = tile_size``
for non-overlapping tiles. Defaults to ``None``.
tile_level (int, optional): Level to extract tiles from. Defaults to 0.
tile_pad (bool): How to handle chunks on the edges. If ``True``, these edge chunks will be zero-padded
symmetrically and yielded with the other chunks. If ``False``, incomplete edge chunks will be ignored.
Defaults to ``False``.
**tile_kwargs: Other arguments passed through to ``generate_tiles()`` method of the backend.
"""

def __init__(
Expand All @@ -93,6 +101,11 @@ def __init__(
time_series=None,
counts=None,
dtype=None,
tile_size=256,
tile_stride=None,
tile_level=0,
tile_pad=False,
**tile_kwargs,
):
# check inputs
assert masks is None or isinstance(
Expand Down Expand Up @@ -201,7 +214,38 @@ def __init__(
self.h5manager = pathml.core.h5managers.h5pathManager(slidedata=self)

self.masks = pathml.core.Masks(h5manager=self.h5manager, masks=masks)
self.tiles = pathml.core.Tiles(h5manager=self.h5manager, tiles=tiles)
self._tiles = (
pathml.core.Tiles(h5manager=self.h5manager, tiles=tiles)
if tiles is not None
else None
)

if tile_stride is None:
tile_stride = tile_size
elif isinstance(tile_stride, int):
tile_stride = (tile_stride, tile_stride)

self.tile_size = tile_size
self.tile_stride = tile_stride
self.tile_level = tile_level
self.tile_pad = tile_pad
self.tile_kwargs = tile_kwargs

# TODO: be careful here since we are modifying h5 outside of h5manager
# look into whether we can push this into h5manager

self.h5manager.h5["tiles"].attrs["tile_stride"] = tile_stride

@property
def tiles(self):
    """
    Generator over the tiles of this slide.

    If no tiles are stored yet (``self._tiles is None``), an empty
    ``pathml.core.Tiles`` container bound to ``self.h5manager`` is created and
    tiles are yielded from ``self._generate_tiles()``. Otherwise the stored
    tiles in ``self._tiles`` are yielded.

    Side effect: ``self._add_tiles`` records whether tiles had to be generated
    (``True``) or already existed (``False``).

    NOTE(review): because this property is a generator, none of the body runs
    until the first item is requested, so ``self._add_tiles`` is not assigned
    by merely accessing ``self.tiles``. The returned object is a plain
    generator, not a ``Tiles`` container.

    Yields:
        Tiles produced by ``self._generate_tiles()`` or stored in ``self._tiles``.
    """
    must_generate = self._tiles is None
    self._add_tiles = must_generate
    if must_generate:
        # create the (empty) container first, then stream freshly generated tiles
        self._tiles = pathml.core.Tiles(h5manager=self.h5manager)
        yield from self._generate_tiles()
    else:
        yield from self._tiles

def __repr__(self):
out = []
Expand Down Expand Up @@ -241,62 +285,24 @@ def run(
pipeline,
distributed=True,
client=None,
tile_size=256,
tile_stride=None,
level=0,
tile_pad=False,
overwrite_existing_tiles=False,
write_dir=None,
**kwargs,
):
"""
Run a preprocessing pipeline on SlideData.
Tiles are generated by calling self.generate_tiles() and pipeline is applied to each tile.
Run a preprocessing pipeline on all tiles in SlideData.

Args:
pipeline (pathml.preprocessing.pipeline.Pipeline): Preprocessing pipeline.
distributed (bool): Whether to distribute model using client. Defaults to True.
client: dask.distributed client
tile_size (int, optional): Size of each tile. Defaults to 256px
tile_stride (int, optional): Stride between tiles. If ``None``, uses ``tile_stride = tile_size``
for non-overlapping tiles. Defaults to ``None``.
level (int, optional): Level to extract tiles from. Defaults to ``None``.
tile_pad (bool): How to handle chunks on the edges. If ``True``, these edge chunks will be zero-padded
symmetrically and yielded with the other chunks. If ``False``, incomplete edge chunks will be ignored.
Defaults to ``False``.
overwrite_existing_tiles (bool): Whether to overwrite existing tiles. If ``False``, running a pipeline will
fail if ``tiles is not None``. Defaults to ``False``.
write_dir (str): Path to directory to write the processed slide to. The processed SlideData object
will be written to the directory immediately after the pipeline has completed running.
The filepath will default to "<write_dir>/<slide.name>.h5path". Defaults to ``None``.
**kwargs: Other arguments passed through to ``generate_tiles()`` method of the backend.
"""
assert isinstance(
pipeline, pathml.preprocessing.pipeline.Pipeline
), f"pipeline is of type {type(pipeline)} but must be of type pathml.preprocessing.pipeline.Pipeline"
assert self.slide is not None, "cannot run pipeline because self.slide is None"

if len(self.tiles) != 0:
# in this case, tiles already exist
if not overwrite_existing_tiles:
raise Exception(
f"Slide already has tiles. Running the pipeline will overwrite the existing tiles. Use overwrite_existing_tiles=True to force overwriting existing tiles."
)
else:
# delete all existing tiles
for tile_key in self.tiles.keys:
self.tiles.remove(tile_key)

# TODO: be careful here since we are modifying h5 outside of h5manager
# look into whether we can push this into h5manager

if tile_stride is None:
tile_stride = tile_size
elif isinstance(tile_stride, int):
tile_stride = (tile_stride, tile_stride)

self.h5manager.h5["tiles"].attrs["tile_stride"] = tile_stride

shutdown_after = False

if distributed:
Expand All @@ -308,26 +314,20 @@ def run(
)

# map pipeline application onto each tile
futures = [
client.submit(pipeline.apply, tile)
for tile in self.generate_tiles(
level=level,
shape=tile_size,
stride=tile_stride,
pad=tile_pad,
**kwargs,
)
]
futures = [client.submit(pipeline.apply, tile) for tile in self.tiles]

# After a worker processes a tile, add the tile to h5
for future, result in dask.distributed.as_completed(
futures, with_results=True, raise_errors=False
):
if future.status == "finished":
self.tiles.add(result)
if self._add_tiles:
self.tiles.add(result)
if future.status == "error":
typ, exc, tb = result
if typ is DropTileException:
# TODO: remove tile if it already is in Tiles
# TODO: figure out how to access tile.coords; need to get input that led to exception...
pass
else:
raise exc.with_traceback(tb)
Expand All @@ -341,25 +341,23 @@ def run(
# future.cancel()
# del result
# del future
# del futures
# del futures

if shutdown_after:
client.shutdown()
else:
pass
# Stopgap to free unmanaged memory on client before processing another slide
client.restart()

else:
for tile in self.generate_tiles(
level=level,
shape=tile_size,
stride=tile_stride,
pad=tile_pad,
**kwargs,
):
pipeline.apply(tile)
self.tiles.add(tile)
for tile in self.tiles:
try:
pipeline.apply(tile)
except DropTileException:
if not self._add_tiles:
self.tiles.remove(self.tile.coords)
if self._add_tiles:
self.tiles.add(tile)

if write_dir:
self.write(Path(write_dir) / f"{self.name}.h5path")
Expand Down Expand Up @@ -402,35 +400,21 @@ def extract_region(self, location, size, *args, **kwargs):

return self.slide.extract_region(location, size, *args, **kwargs)

def generate_tiles(self, shape=3000, stride=None, pad=False, **kwargs):
def _generate_tiles(self):
"""
Generator over Tile objects containing regions of the image.
Calls ``generate_tiles()`` method of the backend.
Tries to add the corresponding slide-level masks to each tile, if possible.
Adds slide-level labels to each tile, if possible.

Args:
shape (int or tuple(int)): Size of each tile. May be a tuple of (height, width) or a single integer,
in which case square tiles of that size are generated. Defaults to 256px.
stride (int): stride between chunks. If ``None``, uses ``stride = size`` for non-overlapping chunks.
Defaults to ``None``.
pad (bool): How to handle tiles on the edges. If ``True``, these edge tiles will be zero-padded
and yielded with the other chunks. If ``False``, incomplete edge chunks will be ignored.
Defaults to ``False``.
**kwargs: Other arguments passed through to ``generate_tiles()`` method of the backend.

Yields:
pathml.core.tile.Tile: Extracted Tile object
"""
for tile in self.slide.generate_tiles(shape, stride, pad, **kwargs):
# TODO: move to worker!! (forces loading data on main thread)

for tile in self.slide.generate_tiles(
self.tile_shape, self.tile_stride, self.tile_pad, level=self.level
):
# add masks for tile, if possible
# i.e. if the SlideData has a Masks object, and the tile has coordinates
if self.masks is not None and tile.coords is not None:
# masks not supported if pad=True
# to implement, need to update Mask.slice to support slices that go beyond the full mask
if not pad:
# TODO: update Mask.slice to support slices that go beyond the full mask
if not self.tile_pad:
i, j = tile.coords
# Accessing image loads data on main thread
# dask.delayed waits until compute is called on worker
Expand All @@ -448,8 +432,6 @@ def generate_tiles(self, shape=3000, stride=None, pad=False, **kwargs):
tile_slices = [slice(i, i + di), slice(j, j + dj)]
tile.masks = self.masks.slice(tile_slices)

# TODO: end move to worker

# add slide-level labels to each tile, if possible
if self.labels is not None:
tile.labels = self.labels
Expand Down