Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Add Covid configs (#456)
Browse files Browse the repository at this point in the history
This PR adds configs to train Covid detection models from Chest-Xray data. 

Co-authored-by: Shruthi42 <[email protected]>
  • Loading branch information
melanibe and Shruthi42 committed May 20, 2021
1 parent 8bae42e commit 55120d7
Show file tree
Hide file tree
Showing 8 changed files with 628 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ console for easier diagnostics.
Additionally, the `TrainHelloWorldAndHelloContainer` job in the PR build has been split into two jobs, `TrainHelloWorld` and
`TrainHelloContainer`. A pytest marker `after_training_hello_container` has been added to run tests after training is
finished in the `TrainHelloContainer` job.
- ([#456](https://github.com/microsoft/InnerEye-DeepLearning/pull/456)) Adding configs to train Covid detection models.

### Changed

Expand Down
39 changes: 38 additions & 1 deletion InnerEye/ML/SSL/datamodules_and_datasets/cxr_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
# ------------------------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Tuple

import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Subset
from torchvision.datasets import VisionDataset

from InnerEye.Common.type_annotations import PathOrString
Expand Down Expand Up @@ -175,3 +176,39 @@ def _prepare_dataset(self) -> None:
self.dataset_dataframe.Path = self.dataset_dataframe.Path.apply(lambda x: x[strip_n:])
self.indices = np.arange(len(self.dataset_dataframe))
self.filenames = [self.root / p for p in self.dataset_dataframe.Path.values]


class CovidDataset(InnerEyeCXRDatasetWithReturnIndex):
"""
Dataset class to load CovidDataset dataset as datamodule for monitoring SSL training quality directly on
CovidDataset data.
We use CVX03 against CVX12 as proxy task.
"""

def _prepare_dataset(self) -> None:
self.dataset_dataframe = pd.read_csv(self.root / "dataset.csv")
mapping = {0: 0, 3: 0, 1: 1, 2: 1}
# For monitoring purpose with use binary classification CV03vsCV12
self.dataset_dataframe["final_label"] = self.dataset_dataframe.final_label.apply(lambda x: mapping[x])
self.indices = np.arange(len(self.dataset_dataframe))
self.subject_ids = self.dataset_dataframe.subject.values
self.filenames = [self.root / file for file in self.dataset_dataframe.filepath.values]
self.targets = self.dataset_dataframe.final_label.values.astype(np.int64).reshape(-1)

@property
def num_classes(self) -> int:
return 2

def _split_dataset(self, val_split: float, seed: int) -> Tuple[Subset, Subset]:
"""
Implements val - train split.
:param val_split: proportion to use for validation
:param seed: random seed for splitting
:return: dataset_train, dataset_val
"""
shuffled_subject_ids = np.random.RandomState(seed).permutation(np.unique(self.subject_ids))
n_val = int(len(shuffled_subject_ids) * val_split)
val_subjects, train_subjects = shuffled_subject_ids[:n_val], shuffled_subject_ids[n_val:]
train_ids, val_ids = np.where(np.isin(self.subject_ids, train_subjects))[0], \
np.where(np.isin(self.subject_ids, val_subjects))[0]
return Subset(self, train_ids), Subset(self, val_ids)
30 changes: 19 additions & 11 deletions InnerEye/ML/SSL/lightning_containers/ssl_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from yacs.config import CfgNode

from InnerEye.ML.SSL.datamodules_and_datasets.cifar_datasets import InnerEyeCIFAR10, InnerEyeCIFAR100
from InnerEye.ML.SSL.datamodules_and_datasets.cxr_datasets import CheXpert, NIHCXR, RSNAKaggleCXR
from InnerEye.ML.SSL.datamodules_and_datasets.cxr_datasets import CheXpert, CovidDataset, NIHCXR, RSNAKaggleCXR
from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataModule, InnerEyeVisionDataModule
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import InnerEyeCIFARLinearHeadTransform, \
InnerEyeCIFARTrainTransform, \
Expand Down Expand Up @@ -42,11 +42,12 @@ class EncoderName(Enum):


class SSLDatasetName(Enum):
RSNAKaggleCXR = "RSNAKaggleCXR"
NIHCXR = "NIHCXR"
CIFAR10 = "CIFAR10"
CIFAR100 = "CIFAR100"
RSNAKaggleCXR = "RSNAKaggleCXR"
NIHCXR = "NIHCXR"
CheXpert = "CheXpert"
Covid = "CovidDataset"


InnerEyeDataModuleTypes = Union[InnerEyeVisionDataModule, CombinedDataModule]
Expand All @@ -62,11 +63,12 @@ class SSLContainer(LightningContainer):
Note that this container is also used as the base class for SSLImageClassifier (finetuning container) as they share
setup and datamodule methods.
"""
_SSLDataClassMappings = {SSLDatasetName.RSNAKaggleCXR.value: RSNAKaggleCXR,
SSLDatasetName.NIHCXR.value: NIHCXR,
SSLDatasetName.CIFAR10.value: InnerEyeCIFAR10,
_SSLDataClassMappings = {SSLDatasetName.CIFAR10.value: InnerEyeCIFAR10,
SSLDatasetName.CIFAR100.value: InnerEyeCIFAR100,
SSLDatasetName.CheXpert.value: CheXpert}
SSLDatasetName.RSNAKaggleCXR.value: RSNAKaggleCXR,
SSLDatasetName.NIHCXR.value: NIHCXR,
SSLDatasetName.CheXpert.value: CheXpert,
SSLDatasetName.Covid.value: CovidDataset}

ssl_augmentation_config = param.ClassSelector(class_=Path, allow_None=True,
doc="The path to the yaml config defining the parameters of the "
Expand All @@ -87,11 +89,13 @@ class SSLContainer(LightningContainer):
"Used for debugging and tests.")
linear_head_augmentation_config = param.ClassSelector(class_=Path,
doc="The path to the yaml config for the linear head "
"augmentations")
"augmentations")
linear_head_dataset_name = param.ClassSelector(class_=SSLDatasetName,
doc="Name of the dataset to use for the linear head training")
linear_head_batch_size = param.Integer(default=256, doc="Batch size for linear head tuning")
learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4, doc="Learning rate for linear head training during SSL training.")
learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4,
doc="Learning rate for linear head training during "
"SSL training.")

def setup(self) -> None:
from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer
Expand Down Expand Up @@ -173,7 +177,8 @@ def get_data_module(self) -> InnerEyeDataModuleTypes:
return self.data_module
encoder_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=True)
linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False)
return CombinedDataModule(encoder_data_module, linear_data_module, self.use_balanced_binary_loss_for_linear_head)
return CombinedDataModule(encoder_data_module, linear_data_module,
self.use_balanced_binary_loss_for_linear_head)

def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisionDataModule:
"""
Expand Down Expand Up @@ -220,7 +225,10 @@ def _get_transforms(self, augmentation_config: Optional[CfgNode],
applied on. If False, return only one transformation.
:return: training transformation pipeline and validation transformation pipeline.
"""
if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value, SSLDatasetName.NIHCXR.value, SSLDatasetName.CheXpert.value]:
if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value,
SSLDatasetName.NIHCXR.value,
SSLDatasetName.CheXpert.value,
SSLDatasetName.Covid.value]:
assert augmentation_config is not None
train_transforms, val_transforms = get_cxr_ssl_transforms(augmentation_config,
return_two_views_per_sample=is_ssl_encoder_module,
Expand Down
Loading

0 comments on commit 55120d7

Please sign in to comment.