Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core structure updates: image and text featurizer #172

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
Draft
Next Next commit
new base and sample dataset
  • Loading branch information
zzachw committed Jun 14, 2023
commit 517f0356dd06e5bcf813b0027272055e2fa836d3
10 changes: 7 additions & 3 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from .base_dataset_v2 import BaseDataset
from .base_ehr_dataset import BaseEHRDataset
from .base_signal_dataset import BaseSignalDataset
from .covid19_cxr import COVID19CXRDataset
from .eicu import eICUDataset
from .isruc import ISRUCDataset
from .medical_transriptions import MedicalTranscriptionsDataset
from .mimic3 import MIMIC3Dataset
from .mimic4 import MIMIC4Dataset
from .mimicextract import MIMICExtractDataset
from .omop import OMOPDataset
from .sleepedf import SleepEDFDataset
from .isruc import ISRUCDataset
from .shhs import SHHSDataset
from .sample_dataset import SampleBaseDataset, SampleSignalDataset, SampleEHRDataset
from .sample_dataset_v2 import SampleDataset
from .shhs import SHHSDataset
from .sleepedf import SleepEDFDataset
from .splitter import split_by_patient, split_by_visit
from .utils import collate_fn_dict, get_dataloader, strptime
74 changes: 74 additions & 0 deletions pyhealth/datasets/base_dataset_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import logging
from abc import ABC, abstractmethod
from typing import Optional, Dict

from tqdm import tqdm

from pyhealth.datasets.sample_dataset_v2 import SampleDataset
from pyhealth.tasks.task_template import TaskTemplate

logger = logging.getLogger(__name__)


class BaseDataset(ABC):
    """Abstract base dataset class.

    Subclasses implement :meth:`process` to parse the raw data under
    ``root`` into a dict of patients, and :meth:`stat` to print dataset
    statistics. A task-specific :class:`SampleDataset` is derived via
    :meth:`set_task`.

    Args:
        root: root directory of the raw data.
        dataset_name: name of the dataset. Defaults to the class name.
        **kwargs: ignored; accepted for subclass forwarding convenience.
    """

    def __init__(
        self,
        root: str,
        dataset_name: Optional[str] = None,
        **kwargs,
    ):
        if dataset_name is None:
            dataset_name = self.__class__.__name__
        self.root = root
        self.dataset_name = dataset_name
        logger.debug(f"Processing {self.dataset_name} base dataset...")
        # process() is supplied by the subclass and returns {patient_id: patient}
        self.patients = self.process()
        # TODO: cache

    def __str__(self):
        return f"Base dataset {self.dataset_name}"

    def __len__(self):
        return len(self.patients)

    @abstractmethod
    def process(self) -> Dict:
        """Parses the raw data and returns a dict mapping patient id to patient."""
        raise NotImplementedError

    @abstractmethod
    def stat(self):
        """Prints dataset statistics; subclasses extend via ``super().stat()``."""
        print(f"Statistics of {self.dataset_name}:")

    @property
    def default_task(self) -> Optional[TaskTemplate]:
        """Default task to use in :meth:`set_task`; None unless overridden."""
        return None

    def set_task(self, task: Optional[TaskTemplate] = None) -> SampleDataset:
        """Processes the base dataset to generate the task-specific sample dataset.

        Args:
            task: the task to apply to each patient. Falls back to
                :attr:`default_task` when None.

        Returns:
            SampleDataset: the task-specific sample dataset.

        Raises:
            AssertionError: if ``task`` is None and no default task exists.
        """
        # TODO: cache?
        if task is None:
            # assert default tasks exist in attr
            assert self.default_task is not None, "No default tasks found"
            task = self.default_task

        # load from raw data
        logger.debug(f"Setting task for {self.dataset_name} base dataset...")

        samples = []
        for patient_id, patient in tqdm(
            self.patients.items(), desc=f"Generating samples for {task.task_name}"
        ):
            samples.extend(task(patient))
        sample_dataset = SampleDataset(
            samples,
            input_schema=task.input_schema,
            output_schema=task.output_schema,
            dataset_name=self.dataset_name,
            # fix: SampleDataset expects the task *name* (str), not the task object
            task_name=task.task_name,
        )
        return sample_dataset
2 changes: 2 additions & 0 deletions pyhealth/datasets/featurizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .image import ImageFeaturizer
from .value import ValueFeaturizer
23 changes: 23 additions & 0 deletions pyhealth/datasets/featurizers/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

import PIL.Image
import torchvision.transforms as transforms


class ImageFeaturizer:
    """Featurizer that loads an image file and converts it to a tensor."""

    def __init__(self):
        # Only ToTensor is applied for now; Compose keeps the pipeline extensible.
        self.transform = transforms.Compose([transforms.ToTensor()])

    def encode(self, value):
        """Open the image at path *value* and return it as a tensor."""
        img = PIL.Image.open(value)
        # Force the pixel data to be read immediately so the file handle is
        # released (avoids "Too many open files" errors).
        img.load()
        return self.transform(img)


if __name__ == "__main__":
    # Smoke-test the featurizer on a sample chest X-ray image.
    sample_image = "/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset/Normal/images/Normal-6335.png"
    featurizer = ImageFeaturizer()
    for item in (featurizer, type(featurizer), featurizer.encode(sample_image)):
        print(item)
14 changes: 14 additions & 0 deletions pyhealth/datasets/featurizers/value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from dataclasses import dataclass


@dataclass
class ValueFeaturizer:
    """Identity featurizer: passes raw values through unchanged."""

    def encode(self, value):
        """Return *value* as-is (no transformation applied)."""
        return value


if __name__ == "__main__":
    # Smoke-test the identity featurizer.
    fz = ValueFeaturizer()
    for item in (fz, fz.encode(2)):
        print(item)
175 changes: 175 additions & 0 deletions pyhealth/datasets/sample_dataset_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from typing import Dict, List, Optional

from torch.utils.data import Dataset

from pyhealth.datasets.featurizers import ImageFeaturizer, ValueFeaturizer


class SampleDataset(Dataset):
    """Sample dataset class.

    Holds task-generated samples (one dict per sample) and featurizes each
    field on access according to the input/output schemas.

    Args:
        samples: list of sample dicts.
        input_schema: maps input key -> featurizer type string (e.g. "image").
            NOTE: :meth:`build` replaces the strings with featurizer
            instances in place.
        output_schema: maps output key -> featurizer type string.
        dataset_name: name of the source dataset. Defaults to "".
        task_name: name of the task. Defaults to "".
    """

    def __init__(
        self,
        samples: List[Dict],
        input_schema: Dict[str, str],
        output_schema: Dict[str, str],
        dataset_name: Optional[str] = None,
        task_name: Optional[str] = None,
    ):
        if dataset_name is None:
            dataset_name = ""
        if task_name is None:
            task_name = ""
        self.samples = samples
        self.input_schema = input_schema
        self.output_schema = output_schema
        self.dataset_name = dataset_name
        self.task_name = task_name
        self.transform = None
        # TODO: get rid of input_info
        self.input_info: Dict = self.validate()
        self.build()

    def validate(self) -> Dict:
        """Checks that every sample contains all schema keys.

        Returns:
            input_info: per-key type/dim metadata. Currently only the
                "label" entry is populated (hard-coded); other keys are
                not described here yet.

        Raises:
            AssertionError: if a sample is missing a schema key.
        """
        input_keys = set(self.input_schema.keys())
        output_keys = set(self.output_schema.keys())
        for s in self.samples:
            assert input_keys.issubset(s.keys()), \
                "Input schema does not match samples."
            assert output_keys.issubset(s.keys()), \
                "Output schema does not match samples."
        input_info = {}
        # get label signal info
        input_info["label"] = {"type": str, "dim": 0}
        return input_info

    def build(self):
        """Replaces schema type strings with featurizer instances, in place."""
        for k, v in self.input_schema.items():
            if v == "image":
                self.input_schema[k] = ImageFeaturizer()
            else:
                self.input_schema[k] = ValueFeaturizer()
        for k, v in self.output_schema.items():
            if v == "image":
                self.output_schema[k] = ImageFeaturizer()
            else:
                self.output_schema[k] = ValueFeaturizer()
        return

    def __getitem__(self, index) -> Dict:
        """Returns a sample by index.

        Returns:
            Dict, a dict with patient_id, visit_id/record_id, and other task-specific
            attributes as key. Conversion to index/tensor will be done
            in the model.
        """
        out = {}
        for k, v in self.samples[index].items():
            if k in self.input_schema:
                out[k] = self.input_schema[k].encode(v)
            elif k in self.output_schema:
                out[k] = self.output_schema[k].encode(v)
            else:
                # keys not covered by either schema are passed through untouched
                out[k] = v

        if self.transform is not None:
            out = self.transform(out)

        return out

    def set_transform(self, transform):
        """Sets the transform for the dataset.

        Args:
            transform: a callable transform function applied to each
                featurized sample dict in __getitem__.
        """
        self.transform = transform
        return

    def get_all_tokens(
        self, key: str, remove_duplicates: bool = True, sort: bool = True
    ) -> List[str]:
        """Gets all tokens with a specific key in the samples.

        Args:
            key: the key of the tokens in the samples. Must be described
                in ``self.input_info``.
            remove_duplicates: whether to remove duplicates. Default is True.
            sort: whether to sort the tokens by alphabet order. Default is True.

        Returns:
            tokens: a list of tokens.
        """
        # TODO: get rid of this function
        input_type = self.input_info[key]["type"]
        input_dim = self.input_info[key]["dim"]
        if input_type in [float, int]:
            assert input_dim == 0, f"Cannot get tokens for vector with key {key}"

        tokens = []
        for sample in self.samples:
            if input_dim == 0:
                # a single value
                tokens.append(sample[key])
            elif input_dim == 2:
                # a list of codes
                tokens.extend(sample[key])
            elif input_dim == 3:
                # a list of list of codes; flatten one level
                # (fix: the original called an undefined `flatten_list` helper,
                # which raised NameError for any dim-3 key)
                tokens.extend(code for codes in sample[key] for code in codes)
            else:
                raise NotImplementedError
        if remove_duplicates:
            tokens = list(set(tokens))
        if sort:
            tokens.sort()
        return tokens

    def __str__(self):
        """Prints some information of the dataset."""
        return f"Sample dataset {self.dataset_name} {self.task_name}"

    def __len__(self):
        """Returns the number of samples in the dataset."""
        return len(self.samples)


if __name__ == "__main__":
    # Demo: one sample containing scalars, nested code/vector lists, an
    # image path, free text, and a label.
    demo_sample = {
        "id": "0",
        "single_vector": [1, 2, 3],
        "list_codes": ["505800458", "50580045810", "50580045811"],
        "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
        "list_list_codes": [
            ["A05B", "A05C", "A06A"],
            ["A11D", "A11E"]
        ],
        "list_list_vectors": [
            [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
            [[7.7, 8.5, 9.4]],
        ],
        "image": "/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset/Normal/images/Normal-6335.png",
        "text": "This is a sample text",
        "label": 1,
    }

    demo_input_schema = {
        "id": "str",
        "single_vector": "vector",
        "list_codes": "list",
        "list_vectors": "list",
        "list_list_codes": "list",
        "list_list_vectors": "list",
        "image": "image",
        "text": "text",
    }
    demo_output_schema = {
        "label": "label"
    }

    dataset = SampleDataset(
        samples=[demo_sample],
        input_schema=demo_input_schema,
        output_schema=demo_output_schema,
    )
    print(dataset[0])