Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Fix DeepMIL for TCGA CRCK dataset (#659)
Browse files Browse the repository at this point in the history
While we updated DeepMIL for the Panda dataset to work with the latest changes, we did not update DeepMIL for the TCGA CRCK dataset.

This PR updates how the caching of the encoded tiles is done and how the checkpoints of the DeepMIL model is saved and loaded.

No additional tests are required since these are the same functions that we use for the Panda dataset. For all of them a test already exists.

Last, the PR updates the cudatoolkit version, Anton and I found that this is the root cause for all our problems with ddp
  • Loading branch information
maxilse committed Feb 16, 2022
1 parent 914a893 commit 1600ef3
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 25 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ created.
## Upcoming

### Added
- ([#649](https://github.com/microsoft/InnerEye-DeepLearning/pull/649)) Fix for the _convert_to_tensor_if_necessary method so that PIL.Image as well as np.array get converted to torch.Tensor.
- ([#643](https://github.com/microsoft/InnerEye-DeepLearning/pull/643)) Test for recovery of SSL job. Tracks learning rate and train
loss.
- ([#594](https://github.com/microsoft/InnerEye-DeepLearning/pull/594)) When supplying a "--tag" argument, the AzureML jobs use that value as the display name, to more easily distinguish run.
Expand Down Expand Up @@ -51,6 +50,7 @@ jobs that run in AzureML.
- ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task.

### Changed
- ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Update cudatoolkit version from 11.1 to 11.3.
- ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
- ([#585](https://github.com/microsoft/InnerEye-DeepLearning/pull/585)) Switching to PyTorch 1.10.0 and torchvision 0.11.1
- ([#576](https://github.com/microsoft/InnerEye-DeepLearning/pull/576)) The console output is no longer written to stdout.txt because AzureML handles that better now
Expand Down Expand Up @@ -87,6 +87,8 @@ gets uploaded to AzureML, by skipping all test folders.
- ([#632](https://github.com/microsoft/InnerEye-DeepLearning/pull/632)) Nifti test data is no longer stored in Git LFS

### Fixed
- ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Fix caching and checkpointing for TCGA CRCk dataset.
- ([#649](https://github.com/microsoft/InnerEye-DeepLearning/pull/649)) Fix for the _convert_to_tensor_if_necessary method so that PIL.Image as well as np.array get converted to torch.Tensor.
- ([#606](https://github.com/microsoft/InnerEye-DeepLearning/pull/606)) Bug fix: registered models do not include the hi-ml submodule
- ([#646](https://github.com/microsoft/InnerEye-DeepLearning/pull/646)) Workaround for bug in PL: CombinedLoader cannot be used for training data when using DDP
- ([#593](https://github.com/microsoft/InnerEye-DeepLearning/pull/593)) Bug fix for hi-ml 0.1.11 issue (#130): empty mount point is turned into ".", which fails the AML job
Expand Down
57 changes: 35 additions & 22 deletions InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,50 +13,51 @@
- Schirris (2021). DeepSMILE: Self-supervised heterogeneity-aware multiple instance learning for DNA
damage response defect classification directly from H&E whole-slide images. arXiv:2107.09405
"""
from pathlib import Path
from typing import Any, List
from pathlib import Path
import os

from monai.transforms import Compose
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import Callback

from health_ml.networks.layers.attention_layers import GatedAttentionLayer
from health_azure.utils import get_workspace
from health_azure.utils import CheckpointDownloader
from health_azure.utils import get_workspace
from health_ml.networks.layers.attention_layers import AttentionLayer
from InnerEye.Common import fixed_paths
from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation
from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
from InnerEye.ML.Histopathology.datamodules.tcga_crck_module import (
TcgaCrckTilesDataModule,
from InnerEye.ML.Histopathology.datamodules.tcga_crck_module import TcgaCrckTilesDataModule
from InnerEye.ML.common import get_best_checkpoint_path

from InnerEye.ML.Histopathology.models.transforms import (
EncodeTilesBatchd,
LoadTilesBatchd,
)
from InnerEye.ML.Histopathology.models.encoders import (
HistoSSLEncoder,
ImageNetEncoder,
ImageNetSimCLREncoder,
InnerEyeSSLEncoder,
)
from InnerEye.ML.Histopathology.models.transforms import (
EncodeTilesBatchd,
LoadTilesBatchd,
)
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import (
TcgaCrck_TilesDataset,
)
from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset


class DeepSMILECrck(BaseMIL):
def __init__(self, **kwargs: Any) -> None:
# Define dictionary with default params that can be overriden from subclasses or CLI
default_kwargs = dict(
# declared in BaseMIL:
pooling_type=GatedAttentionLayer.__name__,
pooling_type=AttentionLayer.__name__,
encoding_chunk_size=60,
cache_mode=CacheMode.MEMORY,
precache_location=CacheLocation.CPU,
# declared in DatasetParams:
local_dataset=Path("/tmp/datasets/TCGA-CRCk"),
azure_dataset_id="TCGA-CRCk",
# To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI
# declared in TrainerParams:
num_epochs=16,
num_epochs=50,
# declared in WorkflowParams:
number_of_cross_validation_splits=5,
cross_validation_split_index=0,
Expand Down Expand Up @@ -120,7 +121,7 @@ def get_data_module(self) -> TilesDataModule:
batch_size=self.batch_size,
transform=transform,
cache_mode=self.cache_mode,
save_precache=self.save_precache,
precache_location=self.precache_location,
cache_dir=self.cache_dir,
number_of_cross_validation_splits=self.number_of_cross_validation_splits,
cross_validation_split_index=self.cross_validation_split_index,
Expand All @@ -135,11 +136,23 @@ def get_path_to_best_checkpoint(self) -> Path:
was applied there.
"""
# absolute path is required for registering the model.
return (
fixed_paths.repository_root_directory()
/ self.checkpoint_folder_path
/ self.best_checkpoint_filename_with_suffix
)
absolute_checkpoint_path = Path(fixed_paths.repository_root_directory(),
self.checkpoint_folder_path,
self.best_checkpoint_filename_with_suffix)
if absolute_checkpoint_path.is_file():
return absolute_checkpoint_path

absolute_checkpoint_path_parent = Path(fixed_paths.repository_parent_directory(),
self.checkpoint_folder_path,
self.best_checkpoint_filename_with_suffix)
if absolute_checkpoint_path_parent.is_file():
return absolute_checkpoint_path_parent

checkpoint_path = get_best_checkpoint_path(Path(self.checkpoint_folder_path))
if checkpoint_path.is_file():
return checkpoint_path

raise ValueError("Path to best checkpoint not found")


class TcgaCrckImageNetMIL(DeepSMILECrck):
Expand Down
2 changes: 1 addition & 1 deletion azure-pipelines/build-pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ jobs:
- name: tag
value: 'TrainEnsemble'
- name: more_switches
value: '--pl_deterministic --log_level=DEBUG --regression_test_folder=RegressionTestResults/PR_TrainEnsemble'
value: '--pl_deterministic --log_level=DEBUG --regression_test_folder=RegressionTestResults/PR_TrainEnsemble --regression_test_csv_tolerance=1e-5'
pool:
vmImage: 'ubuntu-20.04'
steps:
Expand Down
3 changes: 2 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- pytorch
- conda-forge
dependencies:
- cudatoolkit=11.1
- cudatoolkit=11.3
- pip=20.1.1
- python=3.7.3
- pytorch=1.10.0
Expand All @@ -28,6 +28,7 @@ dependencies:
- gputil==1.4.0
- h5py==2.10.0
- ipython==7.31.1
- imageio==2.15.0
- InnerEye-DICOM-RT==1.0.1
- joblib==0.16.0
- jupyter==1.0.0
Expand Down

0 comments on commit 1600ef3

Please sign in to comment.