From 517f0356dd06e5bcf813b0027272055e2fa836d3 Mon Sep 17 00:00:00 2001 From: zzachw Date: Wed, 14 Jun 2023 17:01:24 -0400 Subject: [PATCH 01/16] new base and sample dataset --- pyhealth/datasets/__init__.py | 10 +- pyhealth/datasets/base_dataset_v2.py | 74 +++++++++ pyhealth/datasets/featurizers/__init__.py | 2 + pyhealth/datasets/featurizers/image.py | 23 +++ pyhealth/datasets/featurizers/value.py | 14 ++ pyhealth/datasets/sample_dataset_v2.py | 175 ++++++++++++++++++++++ 6 files changed, 295 insertions(+), 3 deletions(-) create mode 100644 pyhealth/datasets/base_dataset_v2.py create mode 100644 pyhealth/datasets/featurizers/__init__.py create mode 100644 pyhealth/datasets/featurizers/image.py create mode 100644 pyhealth/datasets/featurizers/value.py create mode 100644 pyhealth/datasets/sample_dataset_v2.py diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index bd5c530e..0c0a9565 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -1,13 +1,17 @@ +from .base_dataset_v2 import BaseDataset from .base_ehr_dataset import BaseEHRDataset from .base_signal_dataset import BaseSignalDataset +from .covid19_cxr import COVID19CXRDataset from .eicu import eICUDataset +from .isruc import ISRUCDataset +from .medical_transriptions import MedicalTranscriptionsDataset from .mimic3 import MIMIC3Dataset from .mimic4 import MIMIC4Dataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset -from .sleepedf import SleepEDFDataset -from .isruc import ISRUCDataset -from .shhs import SHHSDataset from .sample_dataset import SampleBaseDataset, SampleSignalDataset, SampleEHRDataset +from .sample_dataset_v2 import SampleDataset +from .shhs import SHHSDataset +from .sleepedf import SleepEDFDataset from .splitter import split_by_patient, split_by_visit from .utils import collate_fn_dict, get_dataloader, strptime diff --git a/pyhealth/datasets/base_dataset_v2.py b/pyhealth/datasets/base_dataset_v2.py new file mode 100644 index 00000000..5cbf007e --- /dev/null +++ b/pyhealth/datasets/base_dataset_v2.py @@ -0,0 +1,74 @@ +import logging +from abc import ABC, abstractmethod +from typing import Optional, Dict + +from tqdm import tqdm + +from pyhealth.datasets.sample_dataset_v2 import SampleDataset +from pyhealth.tasks.task_template import TaskTemplate + +logger = logging.getLogger(__name__) + + +class BaseDataset(ABC): + """Abstract base dataset class.""" + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + **kwargs, + ): + if dataset_name is None: + dataset_name = self.__class__.__name__ + self.root = root + self.dataset_name = dataset_name + logger.debug(f"Processing {self.dataset_name} base dataset...") + self.patients = self.process() + # TODO: cache + return + + def __str__(self): + return f"Base dataset {self.dataset_name}" + + def __len__(self): + return len(self.patients) + + @abstractmethod + def process(self) -> Dict: + raise NotImplementedError + + @abstractmethod + def stat(self): + print(f"Statistics of {self.dataset_name}:") + return + + @property + def default_task(self) -> Optional[TaskTemplate]: + return None + + def set_task(self, task: Optional[TaskTemplate] = None) -> SampleDataset: + """Processes the base dataset to generate the task-specific sample dataset. + """ + # TODO: cache? + if task is None: + # assert default tasks exist in attr + assert self.default_task is not None, "No default tasks found" + task = self.default_task + + # load from raw data + logger.debug(f"Setting task for {self.dataset_name} base dataset...") + + samples = [] + for patient_id, patient in tqdm( + self.patients.items(), desc=f"Generating samples for {task.task_name}" + ): + samples.extend(task(patient)) + sample_dataset = SampleDataset( + samples, + input_schema=task.input_schema, + output_schema=task.output_schema, + dataset_name=self.dataset_name, + task_name=task, + ) + return sample_dataset diff --git a/pyhealth/datasets/featurizers/__init__.py b/pyhealth/datasets/featurizers/__init__.py new file mode 100644 index 00000000..7ad6268b --- /dev/null +++ b/pyhealth/datasets/featurizers/__init__.py @@ -0,0 +1,2 @@ +from .image import ImageFeaturizer +from .value import ValueFeaturizer \ No newline at end of file diff --git a/pyhealth/datasets/featurizers/image.py b/pyhealth/datasets/featurizers/image.py new file mode 100644 index 00000000..d867f8ef --- /dev/null +++ b/pyhealth/datasets/featurizers/image.py @@ -0,0 +1,23 @@ + +import PIL.Image +import torchvision.transforms as transforms + + +class ImageFeaturizer: + + def __init__(self): + self.transform = transforms.Compose([transforms.ToTensor()]) + + def encode(self, value): + image = PIL.Image.open(value) + image.load() # to avoid "Too many open files" errors + image = self.transform(image) + return image + + +if __name__ == "__main__": + sample_image = "/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset/Normal/images/Normal-6335.png" + featurizer = ImageFeaturizer() + print(featurizer) + print(type(featurizer)) + print(featurizer.encode(sample_image)) diff --git a/pyhealth/datasets/featurizers/value.py b/pyhealth/datasets/featurizers/value.py new file mode 100644 index 00000000..2bd4c584 --- /dev/null +++ b/pyhealth/datasets/featurizers/value.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + + +@dataclass +class ValueFeaturizer: + + def encode(self, value): + return value + + +if __name__ == "__main__": + featurizer = ValueFeaturizer() + print(featurizer) + print(featurizer.encode(2)) diff --git a/pyhealth/datasets/sample_dataset_v2.py b/pyhealth/datasets/sample_dataset_v2.py new file mode 100644 index 00000000..cd066682 --- /dev/null +++ b/pyhealth/datasets/sample_dataset_v2.py @@ -0,0 +1,175 @@ +from typing import Dict, List, Optional + +from torch.utils.data import Dataset + +from pyhealth.datasets.featurizers import ImageFeaturizer, ValueFeaturizer + + +class SampleDataset(Dataset): + """Sample dataset class. + """ + + def __init__( + self, + samples: List[Dict], + input_schema: Dict[str, str], + output_schema: Dict[str, str], + dataset_name: Optional[str] = None, + task_name: Optional[str] = None, + ): + if dataset_name is None: + dataset_name = "" + if task_name is None: + task_name = "" + self.samples = samples + self.input_schema = input_schema + self.output_schema = output_schema + self.dataset_name = dataset_name + self.task_name = task_name + self.transform = None + # TODO: get rid of input_info + self.input_info: Dict = self.validate() + self.build() + + def validate(self): + input_keys = set(self.input_schema.keys()) + output_keys = set(self.output_schema.keys()) + for s in self.samples: + assert input_keys.issubset(s.keys()), \ + "Input schema does not match samples." + assert output_keys.issubset(s.keys()), \ + "Output schema does not match samples." + input_info = {} + # get label signal info + input_info["label"] = {"type": str, "dim": 0} + return input_info + + def build(self): + for k, v in self.input_schema.items(): + if v == "image": + self.input_schema[k] = ImageFeaturizer() + else: + self.input_schema[k] = ValueFeaturizer() + for k, v in self.output_schema.items(): + if v == "image": + self.output_schema[k] = ImageFeaturizer() + else: + self.output_schema[k] = ValueFeaturizer() + return + + def __getitem__(self, index) -> Dict: + """Returns a sample by index. + + Returns: + Dict, a dict with patient_id, visit_id/record_id, and other task-specific + attributes as key. Conversion to index/tensor will be done + in the model. + """ + out = {} + for k, v in self.samples[index].items(): + if k in self.input_schema: + out[k] = self.input_schema[k].encode(v) + elif k in self.output_schema: + out[k] = self.output_schema[k].encode(v) + else: + out[k] = v + + if self.transform is not None: + out = self.transform(out) + + return out + + def set_transform(self, transform): + """Sets the transform for the dataset. + + Args: + transform: a callable transform function. + """ + self.transform = transform + return + + def get_all_tokens( + self, key: str, remove_duplicates: bool = True, sort: bool = True + ) -> List[str]: + """Gets all tokens with a specific key in the samples. + + Args: + key: the key of the tokens in the samples. + remove_duplicates: whether to remove duplicates. Default is True. + sort: whether to sort the tokens by alphabet order. Default is True. + + Returns: + tokens: a list of tokens. + """ + # TODO: get rid of this function + input_type = self.input_info[key]["type"] + input_dim = self.input_info[key]["dim"] + if input_type in [float, int]: + assert input_dim == 0, f"Cannot get tokens for vector with key {key}" + + tokens = [] + for sample in self.samples: + if input_dim == 0: + # a single value + tokens.append(sample[key]) + elif input_dim == 2: + # a list of codes + tokens.extend(sample[key]) + elif input_dim == 3: + # a list of list of codes + tokens.extend(flatten_list(sample[key])) + else: + raise NotImplementedError + if remove_duplicates: + tokens = list(set(tokens)) + if sort: + tokens.sort() + return tokens + + def __str__(self): + """Prints some information of the dataset.""" + return f"Sample dataset {self.dataset_name} {self.task_name}" + + def __len__(self): + """Returns the number of samples in the dataset.""" + return len(self.samples) + + +if __name__ == "__main__": + samples = [ + { + "id": "0", + "single_vector": [1, 2, 3], + "list_codes": ["505800458", "50580045810", "50580045811"], + "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]], + "list_list_codes": [ + ["A05B", "A05C", "A06A"], + ["A11D", "A11E"] + ], + "list_list_vectors": [ + [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]], + [[7.7, 8.5, 9.4]], + ], + "image": "/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset/Normal/images/Normal-6335.png", + "text": "This is a sample text", + "label": 1, + }, + ] + + dataset = SampleDataset( + samples=samples, + input_schema={ + "id": "str", + "single_vector": "vector", + "list_codes": "list", + "list_vectors": "list", + "list_list_codes": "list", + "list_list_vectors": "list", + "image": "image", + "text": "text", + }, + output_schema={ + "label": "label" + } + ) + print(dataset[0]) From ba0898b5606882e8284fe8c7bafe27d621fb4929 Mon Sep 17 00:00:00 2001 From: zzachw Date: Wed, 14 Jun 2023 17:01:47 -0400 Subject: [PATCH 02/16] task template --- pyhealth/tasks/__init__.py | 3 +++ pyhealth/tasks/task_template.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 pyhealth/tasks/task_template.py diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 70c2b848..d3c38e4e 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -1,3 +1,4 @@ +from .task_template import TaskTemplate from .drug_recommendation import ( drug_recommendation_eicu_fn, drug_recommendation_mimic3_fn, @@ -29,3 +30,5 @@ sleep_staging_isruc_fn, sleep_staging_shhs_fn, ) +from .covid19_cxr_classification import COVID19CXRClassification +from .medical_transcriptions_classification import MedicalTranscriptionsClassification diff --git a/pyhealth/tasks/task_template.py b/pyhealth/tasks/task_template.py new file mode 100644 index 00000000..0415b5b0 --- /dev/null +++ b/pyhealth/tasks/task_template.py @@ -0,0 +1,14 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, List + + +@dataclass(frozen=True) +class TaskTemplate(ABC): + task_name: str + input_schema: Dict[str, str] + output_schema: Dict[str, str] + + @abstractmethod + def __call__(self, patient) -> List[Dict]: + raise NotImplementedError From ba27b77a68b7271cfacd3139631ad955e212bce3 Mon Sep 17 00:00:00 2001 From: zzachw Date: Wed, 14 Jun 2023 17:01:59 -0400 Subject: [PATCH 03/16] add covid19_cxr dataset and task --- pyhealth/datasets/covid19_cxr.py | 144 +++++++++++++++++++ pyhealth/tasks/covid19_cxr_classification.py | 20 +++ 2 files changed, 164 insertions(+) create mode 100644 pyhealth/datasets/covid19_cxr.py create mode 100644 pyhealth/tasks/covid19_cxr_classification.py diff --git a/pyhealth/datasets/covid19_cxr.py b/pyhealth/datasets/covid19_cxr.py new file mode 100644 index 00000000..74e85d39 --- /dev/null +++ b/pyhealth/datasets/covid19_cxr.py @@ -0,0 +1,144 @@ +import os +from collections import Counter + +import pandas as pd + +from pyhealth.datasets.base_dataset_v2 import BaseDataset +from pyhealth.tasks.covid19_cxr_classification import COVID19CXRClassification + + +class COVID19CXRDataset(BaseDataset): + """Base image dataset for COVID-19 Radiography Database + + Dataset is available at https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database + + **COVID-19 data: + ----------------------- + COVID data are collected from different publicly accessible dataset, online sources and published papers. + -2473 CXR images are collected from padchest dataset[1]. + -183 CXR images from a Germany medical school[2]. + -559 CXR image from SIRM, Github, Kaggle & Tweeter[3,4,5,6] + -400 CXR images from another Github source[7]. + + ***Normal images: + ---------------------------------------- + 10192 Normal data are collected from from three different dataset. + -8851 RSNA [8] + -1341 Kaggle [9] + + ***Lung opacity images: + ---------------------------------------- + 6012 Lung opacity CXR images are collected from Radiological Society of North America (RSNA) CXR dataset [8] + + ***Viral Pneumonia images: + ---------------------------------------- + 1345 Viral Pneumonia data are collected from the Chest X-Ray Images (pneumonia) database [9] + + Please cite the follwoing two articles if you are using this dataset: + -M.E.H. Chowdhury, T. Rahman, A. Khandakar, R. Mazhar, M.A. Kadir, Z.B. Mahbub, K.R. Islam, M.S. Khan, A. Iqbal, N. Al-Emadi, M.B.I. Reaz, M. T. Islam, “Can AI help in screening Viral and COVID-19 pneumonia?” IEEE Access, Vol. 8, 2020, pp. 132665 - 132676. + -Rahman, T., Khandakar, A., Qiblawey, Y., Tahir, A., Kiranyaz, S., Kashem, S.B.A., Islam, M.T., Maadeed, S.A., Zughaier, S.M., Khan, M.S. and Chowdhury, M.E., 2020. Exploring the Effect of Image Enhancement Techniques on COVID-19 Detection using Chest X-ray Images. arXiv preprint arXiv:2012.02238. + + **Reference: + [1] https://bimcv.cipf.es/bimcv-projects/bimcv-covid19/#1590858128006-9e640421-6711 + [2] https://github.com/ml-workgroup/covid-19-image-repository/tree/master/png + [3] https://sirm.org/category/senza-categoria/covid-19/ + [4] https://eurorad.org + [5] https://github.com/ieee8023/covid-chestxray-dataset + [6] https://figshare.com/articles/COVID-19_Chest_X-Ray_Image_Repository/12580328 + [7] https://github.com/armiro/COVID-CXNet + [8] https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data + [9] https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia + + Args: + dataset_name: name of the dataset. + root: root directory of the raw data. *You can choose to use the path to Cassette portion or the Telemetry portion.* + dev: whether to enable dev mode (only use a small subset of the data). + Default is False. + refresh_cache: whether to refresh the cache; if true, the dataset will + be processed from scratch and the cache will be updated. Default is False. + + Attributes: + root: root directory of the raw data (should contain many csv files). + dataset_name: name of the dataset. Default is the name of the class. + dev: whether to enable dev mode (only use a small subset of the data). + Default is False. + refresh_cache: whether to refresh the cache; if true, the dataset will + be processed from scratch and the cache will be updated. Default is False. + + Examples: + >>> dataset = COVID19CXRDataset( + root="/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset", + ) + >>> print(dataset[0]) + >>> dataset.stat() + >>> dataset.info() + """ + + def process(self): + # process and merge raw xlsx files from the dataset + covid = pd.DataFrame( + pd.read_excel(f"{self.root}/COVID.metadata.xlsx") + ) + covid["FILE NAME"] = covid["FILE NAME"].apply( + lambda x: f"{self.root}/COVID/images/{x}.png" + ) + covid["label"] = "COVID" + lung_opacity = pd.DataFrame( + pd.read_excel(f"{self.root}/Lung_Opacity.metadata.xlsx") + ) + lung_opacity["FILE NAME"] = lung_opacity["FILE NAME"].apply( + lambda x: f"{self.root}/Lung_Opacity/images/{x}.png" + ) + lung_opacity["label"] = "Lung Opacity" + normal = pd.DataFrame( + pd.read_excel(f"{self.root}/Normal.metadata.xlsx") + ) + normal["FILE NAME"] = normal["FILE NAME"].apply( + lambda x: x.capitalize() + ) + normal["FILE NAME"] = normal["FILE NAME"].apply( + lambda x: f"{self.root}/Normal/images/{x}.png" + ) + normal["label"] = "Normal" + viral_pneumonia = pd.DataFrame( + pd.read_excel(f"{self.root}/Viral Pneumonia.metadata.xlsx") + ) + viral_pneumonia["FILE NAME"] = viral_pneumonia["FILE NAME"].apply( + lambda x: f"{self.root}/Viral Pneumonia/images/{x}.png" + ) + viral_pneumonia["label"] = "Viral Pneumonia" + df = pd.concat( + [covid, lung_opacity, normal, viral_pneumonia], + axis=0, + ignore_index=True + ) + df = df.drop(columns=["FORMAT", "SIZE"]) + df.columns = ["path", "url", "label"] + for path in df.path: + assert os.path.isfile(os.path.join(self.root, path)) + # create patient dict + patients = {} + for index, row in df.iterrows(): + patients[index] = row.to_dict() + return patients + + def stat(self): + super().stat() + print(f"Number of samples: {len(self.patients)}") + count = Counter([v['label'] for v in self.patients.values()]) + print(f"Number of classes: {len(count)}") + print(f"Class distribution: {count}") + + @property + def default_task(self): + return COVID19CXRClassification() + + +if __name__ == "__main__": + dataset = COVID19CXRDataset( + root="/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset", + ) + print(list(dataset.patients.items())[0]) + dataset.stat() + samples = dataset.set_task() + print(samples[0]) diff --git a/pyhealth/tasks/covid19_cxr_classification.py b/pyhealth/tasks/covid19_cxr_classification.py new file mode 100644 index 00000000..25c494f7 --- /dev/null +++ b/pyhealth/tasks/covid19_cxr_classification.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass, field +from typing import Dict + +from pyhealth.tasks import TaskTemplate + + +@dataclass(frozen=True) +class COVID19CXRClassification(TaskTemplate): + task_name: str = "COVID19CXRClassification" + input_schema: Dict[str, str] = field(default_factory=lambda: {"path": "image"}) + output_schema: Dict[str, str] = field(default_factory=lambda: {"label": "label"}) + + def __call__(self, patient): + return [patient] + + +if __name__ == "__main__": + task = COVID19CXRClassification() + print(task) + print(type(task)) From 831c495190790d11b197a749194f604e581c487a Mon Sep 17 00:00:00 2001 From: zzachw Date: Wed, 14 Jun 2023 17:02:03 -0400 Subject: [PATCH 04/16] add resnet model --- pyhealth/models/resnet.py | 109 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 pyhealth/models/resnet.py diff --git a/pyhealth/models/resnet.py b/pyhealth/models/resnet.py new file mode 100644 index 00000000..6b669ec1 --- /dev/null +++ b/pyhealth/models/resnet.py @@ -0,0 +1,109 @@ +from typing import List, Dict + +import torch +import torch.nn as nn +from torchvision import models + +from pyhealth.datasets import SampleDataset +from pyhealth.models import BaseModel + + +class ResNet(BaseModel): + """ResNet model for image data. + + Paper: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. + Deep Residual Learning for Image Recognition. CVPR 2016. + + Args: + dataset: the dataset to train the model. It is used to query certain + information such as the set of all tokens. + feature_keys: list of keys in samples to use as features, + e.g. ["conditions", "procedures"]. + label_key: key in samples to use as label (e.g., "drugs"). + mode: one of "binary", "multiclass", or "multilabel". + pretrained: whether to use pretrained weights. Default is False. + num_layers: number of resnet layers. Supported values are 18, 34, 50, 101, 152. + Default is 18. + """ + + def __init__( + self, + dataset: SampleDataset, + feature_keys: List[str], + label_key: str, + mode: str, + pretrained=False, + num_layers=18, + **kwargs, + ): + super(ResNet, self).__init__( + dataset=dataset, + feature_keys=feature_keys, + label_key=label_key, + mode=mode, + ) + cnn_name = f"resnet{num_layers}" + self.cnn = models.__dict__[cnn_name](pretrained=pretrained) + hidden_dim = self.cnn.fc.in_features + self.label_tokenizer = self.get_label_tokenizer() + output_size = self.get_output_size(self.label_tokenizer) + self.cnn.fc = nn.Linear(hidden_dim, output_size) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward propagation.""" + # concat the info within one batch (batch, channel, length) + x = kwargs[self.feature_keys[0]] + x = torch.stack(x, dim=0).to(self.device) + if x.shape[1] == 1: + x = x.repeat((1, 3, 1, 1)) + logits = self.cnn(x) + y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer) + loss = self.get_loss_function()(logits, y_true) + y_prob = self.prepare_y_prob(logits) + return { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + } + + +if __name__ == "__main__": + from pyhealth.datasets import COVID19CXRDataset, get_dataloader + from torchvision import transforms + + base_dataset = COVID19CXRDataset( + root="/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset", + ) + + sample_dataset = base_dataset.set_task() + + transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.Normalize(mean=[0.5862785803043838], std=[0.27950088968644304]) + ]) + def encode(sample): + sample["path"] = transform(sample["path"]) + return sample + + sample_dataset.set_transform(encode) + + train_loader = get_dataloader(sample_dataset, batch_size=16, shuffle=True) + + model = ResNet( + dataset=sample_dataset, + feature_keys=[ + "path", + ], + label_key="label", + mode="multiclass", + ) + + # data batch + data_batch = next(iter(train_loader)) + + # try the model + ret = model(**data_batch) + print(ret) + + # try loss backward + ret["loss"].backward() \ No newline at end of file From 2e2550dc0580ce8eb8809fb96dbbcbdd1d29a2b1 Mon Sep 17 00:00:00 2001 From: zzachw Date: Wed, 14 Jun 2023 17:02:17 -0400 Subject: [PATCH 05/16] add medical transcriptions dataset and task --- pyhealth/datasets/medical_transriptions.py | 68 +++++++++++++++++++ .../medical_transcriptions_classification.py | 26 +++++++ 2 files changed, 94 insertions(+) create mode 100644 pyhealth/datasets/medical_transriptions.py create mode 100644 pyhealth/tasks/medical_transcriptions_classification.py diff --git a/pyhealth/datasets/medical_transriptions.py b/pyhealth/datasets/medical_transriptions.py new file mode 100644 index 00000000..c809533b --- /dev/null +++ b/pyhealth/datasets/medical_transriptions.py @@ -0,0 +1,68 @@ +import os +from collections import Counter + +import pandas as pd + +from pyhealth.datasets.base_dataset_v2 import BaseDataset +from pyhealth.tasks.medical_transcriptions_classification import MedicalTranscriptionsClassification + + +class MedicalTranscriptionsDataset(BaseDataset): + """Medical transcription data scraped from mtsamples.com + + Dataset is available at https://www.kaggle.com/datasets/tboyle10/medicaltranscriptions + + Args: + dataset_name: name of the dataset. + root: root directory of the raw data. *You can choose to use the path to Cassette portion or the Telemetry portion.* + dev: whether to enable dev mode (only use a small subset of the data). + Default is False. + refresh_cache: whether to refresh the cache; if true, the dataset will + be processed from scratch and the cache will be updated. Default is False. + + Attributes: + root: root directory of the raw data (should contain many csv files). + dataset_name: name of the dataset. Default is the name of the class. + dev: whether to enable dev mode (only use a small subset of the data). + Default is False. + refresh_cache: whether to refresh the cache; if true, the dataset will + be processed from scratch and the cache will be updated. Default is False. + + Examples: + >>> dataset = MedicalTranscriptionsDataset( + root="/srv/local/data/zw12/raw_data/MedicalTranscriptions", + ) + >>> print(dataset[0]) + >>> dataset.stat() + >>> dataset.info() + """ + + def process(self): + df = pd.read_csv(f"{self.root}/mtsamples.csv", index_col=0) + + # create patient dict + patients = {} + for index, row in df.iterrows(): + patients[index] = row.to_dict() + return patients + + def stat(self): + super().stat() + print(f"Number of samples: {len(self.patients)}") + count = Counter([v['medical_specialty'] for v in self.patients.values()]) + print(f"Number of classes: {len(count)}") + print(f"Class distribution: {count}") + + @property + def default_task(self): + return MedicalTranscriptionsClassification() + + +if __name__ == "__main__": + dataset = MedicalTranscriptionsDataset( + root="/srv/local/data/zw12/raw_data/MedicalTranscriptions", + ) + print(list(dataset.patients.items())[0]) + dataset.stat() + samples = dataset.set_task() + print(samples[0]) diff --git a/pyhealth/tasks/medical_transcriptions_classification.py b/pyhealth/tasks/medical_transcriptions_classification.py new file mode 100644 index 00000000..8b87d218 --- /dev/null +++ b/pyhealth/tasks/medical_transcriptions_classification.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass, field +from typing import Dict + +from pyhealth.tasks import TaskTemplate + + +@dataclass(frozen=True) +class MedicalTranscriptionsClassification(TaskTemplate): + task_name: str = "MedicalTranscriptionsClassification" + input_schema: Dict[str, str] = field(default_factory=lambda: {"transcription": "text"}) + output_schema: Dict[str, str] = field(default_factory=lambda: {"label": "label"}) + + def __call__(self, patient): + if patient["transcription"] is None or patient["medical_specialty"] is None: + return [] + sample = { + "transcription": patient["transcription"], + "label": patient["medical_specialty"], + } + return [sample] + + +if __name__ == "__main__": + task = MedicalTranscriptionsClassification() + print(task) + print(type(task)) From c952c591c57c0770de05923c89075bcf611e8a6d Mon Sep 17 00:00:00 2001 From: zzachw Date: Wed, 14 Jun 2023 17:02:25 -0400 Subject: [PATCH 06/16] add connection to huggingface --- pyhealth/models/huggingface.py | 90 ++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 pyhealth/models/huggingface.py diff --git a/pyhealth/models/huggingface.py b/pyhealth/models/huggingface.py new file mode 100644 index 00000000..d5b4c9f7 --- /dev/null +++ b/pyhealth/models/huggingface.py @@ -0,0 +1,90 @@ +from typing import List, Dict + +import torch +from transformers import AutoModel, AutoTokenizer +import torch.nn as nn +from torchvision import models + +from pyhealth.datasets import SampleDataset +from pyhealth.models import BaseModel + + +class HuggingfaceAutoModel(BaseModel): + """AutoModel class for Huggingface models. + """ + + def __init__( + self, + model_name: str, + dataset: SampleDataset, + feature_keys: List[str], + label_key: str, + mode: str, + pretrained=False, + num_layers=18, + **kwargs, + ): + super(HuggingfaceAutoModel, self).__init__( + dataset=dataset, + feature_keys=feature_keys, + label_key=label_key, + mode=mode, + ) + self.model_name = model_name + self.model = AutoModel.from_pretrained(model_name) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.label_tokenizer = self.get_label_tokenizer() + output_size = self.get_output_size(self.label_tokenizer) + hidden_dim = self.model.config.hidden_size + self.fc = nn.Linear(hidden_dim, output_size) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward propagation.""" + # concat the info within one batch (batch, channel, length) + x = kwargs[self.feature_keys[0]] + x = self.tokenizer( + x, return_tensors="pt", padding=True, truncation=True, max_length=256 + ) + embeddings = self.model(**x).pooler_output + logits = self.fc(embeddings) + y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer) + loss = self.get_loss_function()(logits, y_true) + y_prob = self.prepare_y_prob(logits) + return { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + } + + +if __name__ == "__main__": + from pyhealth.datasets import MedicalTranscriptionsDataset, get_dataloader + from torchvision import transforms + + base_dataset = MedicalTranscriptionsDataset( + root="/srv/local/data/zw12/raw_data/MedicalTranscriptions" + ) + + sample_dataset = base_dataset.set_task() + + train_loader = get_dataloader(sample_dataset, batch_size=16, shuffle=True) + + model = HuggingfaceAutoModel( + model_name="emilyalsentzer/Bio_ClinicalBERT", + dataset=sample_dataset, + feature_keys=[ + "transcription", + ], + label_key="label", + mode="multiclass", + ) + + # data batch + data_batch = next(iter(train_loader)) + + # try the model + ret = model(**data_batch) + print(ret) + + # try loss backward + ret["loss"].backward() \ No newline at end of file From 8f13af145d8b76ce324bf82d43478c028b102621 Mon Sep 17 00:00:00 2001 From: zzachw Date: Sat, 17 Jun 2023 18:39:14 -0400 Subject: [PATCH 07/16] add resnet to models.init --- pyhealth/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index b0467853..2c1c7abb 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -23,3 +23,4 @@ from .stagenet import StageNet, StageNetLayer from .tcn import TCN, TCNLayer from .molerec import MoleRec, MoleRecLayer +from .resnet import ResNet From f29022edac7b88e5c0e473b4693f28f77ea5a66f Mon Sep 17 00:00:00 2001 From: zzachw Date: Sat, 17 Jun 2023 19:26:16 -0400 Subject: [PATCH 08/16] add huggingfaceautomodel to models.init --- pyhealth/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 2c1c7abb..e61ee653 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -24,3 +24,4 @@ from .tcn import TCN, TCNLayer from .molerec import MoleRec, MoleRecLayer from .resnet import ResNet +from .huggingface import HuggingfaceAutoModel From f4387540cee2a2b57b1a4eee875ed1edc95d0566 Mon Sep 17 00:00:00 2001 From: zzachw Date: Sat, 17 Jun 2023 19:43:12 -0400 Subject: [PATCH 09/16] nan bug fixed for MedicalTranscriptionsDataset --- .../medical_transcriptions_classification.py | 49 +++++++++++++++++++ pyhealth/models/huggingface.py | 1 + .../medical_transcriptions_classification.py | 5 +- 3 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 examples/medical_transcriptions_classification.py diff --git a/examples/medical_transcriptions_classification.py b/examples/medical_transcriptions_classification.py new file mode 100644 index 00000000..fb1fe324 --- /dev/null +++ b/examples/medical_transcriptions_classification.py @@ -0,0 +1,49 @@ +import numpy as np +import torch + +from pyhealth.datasets import MedicalTranscriptionsDataset +from pyhealth.datasets import get_dataloader +from pyhealth.models import HuggingfaceAutoModel +from pyhealth.trainer import Trainer + +root = "/srv/local/data/zw12/raw_data/MedicalTranscriptions" +base_dataset = MedicalTranscriptionsDataset(root) + +sample_dataset = base_dataset.set_task() + +ratios = [0.7, 0.1, 0.2] +index = np.arange(len(sample_dataset)) +np.random.shuffle(index) +s1 = int(len(sample_dataset) * ratios[0]) +s2 = int(len(sample_dataset) * (ratios[0] + ratios[1])) +train_index = index[: s1] +val_index = index[s1: s2] +test_index = index[s2:] +train_dataset = torch.utils.data.Subset(sample_dataset, train_index) +val_dataset = torch.utils.data.Subset(sample_dataset, val_index) +test_dataset = torch.utils.data.Subset(sample_dataset, test_index) + +train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) +val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) +test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) + +model = HuggingfaceAutoModel( + model_name="emilyalsentzer/Bio_ClinicalBERT", + dataset=sample_dataset, + feature_keys=["transcription"], + label_key="label", + mode="multiclass", +) + +trainer = Trainer(model=model) + +print(trainer.evaluate(test_dataloader)) + +trainer.train( + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + epochs=1, + monitor="accuracy" +) + +print(trainer.evaluate(test_dataloader)) diff --git a/pyhealth/models/huggingface.py b/pyhealth/models/huggingface.py index d5b4c9f7..d810f854 100644 --- a/pyhealth/models/huggingface.py +++ b/pyhealth/models/huggingface.py @@ -45,6 +45,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: x = self.tokenizer( x, return_tensors="pt", padding=True, truncation=True, max_length=256 ) + x = x.to(self.device) embeddings = self.model(**x).pooler_output logits = self.fc(embeddings) y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer) diff --git a/pyhealth/tasks/medical_transcriptions_classification.py b/pyhealth/tasks/medical_transcriptions_classification.py index 8b87d218..2f7fa5b3 100644 --- a/pyhealth/tasks/medical_transcriptions_classification.py +++ b/pyhealth/tasks/medical_transcriptions_classification.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field from typing import Dict +import pandas as pd from pyhealth.tasks import TaskTemplate @@ -11,7 +12,9 @@ class MedicalTranscriptionsClassification(TaskTemplate): output_schema: Dict[str, str] = field(default_factory=lambda: {"label": "label"}) def __call__(self, patient): - if patient["transcription"] is None or patient["medical_specialty"] is None: + if patient["transcription"] is None or pd.isna(patient["transcription"]): + return [] + if patient["medical_specialty"] is None or pd.isna(patient["medical_specialty"]): return [] sample = { "transcription": patient["transcription"], From 8e88da350283233239698746ba3a2d11c466dba6 Mon Sep 17 00:00:00 2001 From: Zijian Wu Date: Thu, 22 Jun 2023 06:19:13 +0000 Subject: [PATCH 10/16] add torchvision_classif.py --- pyhealth/models/resnet.py | 109 ------------- pyhealth/models/torchvision_classif.py | 214 +++++++++++++++++++++++++ 2 files changed, 214 insertions(+), 109 deletions(-) delete mode 100644 pyhealth/models/resnet.py create mode 100644 pyhealth/models/torchvision_classif.py diff --git a/pyhealth/models/resnet.py b/pyhealth/models/resnet.py deleted file mode 100644 index 6b669ec1..00000000 --- a/pyhealth/models/resnet.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import List, Dict - -import torch -import torch.nn as nn -from torchvision import models - -from pyhealth.datasets import SampleDataset -from pyhealth.models import BaseModel - - -class ResNet(BaseModel): - """ResNet model for image data. - - Paper: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. - Deep Residual Learning for Image Recognition. CVPR 2016. - - Args: - dataset: the dataset to train the model. It is used to query certain - information such as the set of all tokens. - feature_keys: list of keys in samples to use as features, - e.g. ["conditions", "procedures"]. - label_key: key in samples to use as label (e.g., "drugs"). - mode: one of "binary", "multiclass", or "multilabel". - pretrained: whether to use pretrained weights. Default is False. - num_layers: number of resnet layers. Supported values are 18, 34, 50, 101, 152. - Default is 18. - """ - - def __init__( - self, - dataset: SampleDataset, - feature_keys: List[str], - label_key: str, - mode: str, - pretrained=False, - num_layers=18, - **kwargs, - ): - super(ResNet, self).__init__( - dataset=dataset, - feature_keys=feature_keys, - label_key=label_key, - mode=mode, - ) - cnn_name = f"resnet{num_layers}" - self.cnn = models.__dict__[cnn_name](pretrained=pretrained) - hidden_dim = self.cnn.fc.in_features - self.label_tokenizer = self.get_label_tokenizer() - output_size = self.get_output_size(self.label_tokenizer) - self.cnn.fc = nn.Linear(hidden_dim, output_size) - - def forward(self, **kwargs) -> Dict[str, torch.Tensor]: - """Forward propagation.""" - # concat the info within one batch (batch, channel, length) - x = kwargs[self.feature_keys[0]] - x = torch.stack(x, dim=0).to(self.device) - if x.shape[1] == 1: - x = x.repeat((1, 3, 1, 1)) - logits = self.cnn(x) - y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer) - loss = self.get_loss_function()(logits, y_true) - y_prob = self.prepare_y_prob(logits) - return { - "loss": loss, - "y_prob": y_prob, - "y_true": y_true, - } - - -if __name__ == "__main__": - from pyhealth.datasets import COVID19CXRDataset, get_dataloader - from torchvision import transforms - - base_dataset = COVID19CXRDataset( - root="/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset", - ) - - sample_dataset = base_dataset.set_task() - - transform = transforms.Compose([ - transforms.Resize((224, 224)), - transforms.Normalize(mean=[0.5862785803043838], std=[0.27950088968644304]) - ]) - def encode(sample): - sample["path"] = transform(sample["path"]) - return sample - - sample_dataset.set_transform(encode) - - train_loader = get_dataloader(sample_dataset, batch_size=16, shuffle=True) - - model = ResNet( - dataset=sample_dataset, - feature_keys=[ - "path", - ], - label_key="label", - mode="multiclass", - ) - - # data batch - data_batch = next(iter(train_loader)) - - # try the model - ret = model(**data_batch) - print(ret) - - # try loss backward - ret["loss"].backward() \ No newline at end of file diff --git a/pyhealth/models/torchvision_classif.py b/pyhealth/models/torchvision_classif.py new file mode 100644 index 00000000..aec7ef20 --- /dev/null +++ b/pyhealth/models/torchvision_classif.py @@ -0,0 +1,214 @@ +from typing import List, Dict + +import torch +import torch.nn as nn +from torchvision import models + +from pyhealth.datasets.sample_dataset_v2 import SampleDataset +from pyhealth.models import BaseModel + + +class TorchvisionClassification(BaseModel): + """Torchvision model for image classification. + -----------------------------------ResNet--------------------------------------------- + Paper: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. + Deep Residual Learning for Image Recognition. CVPR 2016. + -----------------------------------DenseNet------------------------------------------- + Paper: Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger. + Densely Connected Convolutional Networks. CVPR 2017. + ----------------------------Vision Transformer (ViT)---------------------------------- + Paper: Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, + Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, + Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021. + ----------------------------Swin Transformer (and V2)--------------------------------- + Paper: Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, + Baining Guo. + Swin Transformer: Hierarchical Vision Transformer Using Shifted Windows. ICCV 2021. + Paper: Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, + Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. + Swin Transformer V2: Scaling Up Capacity and Resolution. CVPR 2022. + -------------------------------------------------------------------------------------- + Args: + dataset: the dataset to train the model. It is used to query certain + information such as the set of all tokens. + feature_keys: list of keys in samples to use as features, + e.g. ["conditions", "procedures"]. + label_key: key in samples to use as label (e.g., "drugs"). + mode: one of "binary", "multiclass", or "multilabel". + pretrained: whether to use pretrained weights. Default is False. + model_parameters: dict, {"name" : str, + "num_layers": int, + "model_size": str, + "patch_size": int} + + Note that for different models, the items in model_parameters vary! + model_parameters['name'] is one of "resnet", "densenet", "vit", "swin", "swin_v2" + For ResNet: + model_parameters = {"name": "resnet", + "num_layers": int} + "num_layers" is one of 18, 34, 50, 101, 152 + For DenseNet: + model_parameters = {"name": "densenet", + "num_layers": int} + "num_layers" is one of 121, 161, 169, 201 + For Vision Transformer: + model_parameters = {"name": "vit", + "model_config": str} + "model_config" is one of 'b_16', 'b_32', 'l_16', 'l_32', 'h_14' + For Swin Transformer: + model_parameters = {"name": "swin", + "model_size": str} + "model_config" is one of 't', 's', 'b' + For Swin Transformer V2: + model_parameters = {"name": "swin_v2", + "model_config": str} + "model_config" is one of 't', 's', 'b' + -------------------------------------------------------------------------------------- + Reference: + Torchvision: https://github.com/mlverse/torchvision + """ + + def __init__( + self, + dataset: SampleDataset, + feature_keys: List[str], + label_key: str, + mode: str, + model_parameters: dict, + pretrained=False, + **kwargs, + ): + super(TorchvisionClassification, self).__init__( + dataset=dataset, + feature_keys=feature_keys, + label_key=label_key, + mode=mode, + ) + self.model_name = model_parameters['name'] + + if self.model_name == 'resnet': + num_layers = model_parameters['num_layers'] + supported_num_layers = [18, 34, 50, 101, 152] + try: + supported_num_layers.index(num_layers) + dnn_name = f"{self.model_name}{num_layers}" + self.dnn = models.__dict__[dnn_name](pretrained=pretrained) + hidden_dim = self.dnn.fc.in_features + self.label_tokenizer = self.get_label_tokenizer() + output_size = self.get_output_size(self.label_tokenizer) + self.dnn.fc = nn.Linear(hidden_dim, output_size) + except: + raise SystemExit('PyTorch does not provide this number of learnable layers for ResNet\ + \nThe candidate number is one of 18, 34, 50, 101, 152') + elif self.model_name == 'densenet': + num_layers = model_parameters['num_layers'] + supported_num_layers = [121, 161, 169, 201] + try: + supported_num_layers.index(num_layers) + dnn_name = f"{self.model_name}{num_layers}" + self.dnn = models.__dict__[dnn_name](pretrained=pretrained) + num_ftrs = self.dnn.classifier.in_features + self.label_tokenizer = self.get_label_tokenizer() + output_size = self.get_output_size(self.label_tokenizer) + self.dnn.classifier = nn.Linear(num_ftrs, output_size) + except: + raise SystemExit('PyTorch does not provide this number of learnable layers for DenseNet\ + \nThe candidate number is one of 121, 161, 169, 201') + elif self.model_name == 'vit': + model_config = model_parameters['model_config'] + supported_model_config = ['b_16', 'b_32', 'l_16', 'l_32', 'h_14'] + try: + supported_model_config.index(model_config) + dnn_name = f"{self.model_name}_{model_config}" + self.dnn = models.__dict__[dnn_name](pretrained=pretrained) + num_ftrs = self.dnn.heads.head.in_features + self.label_tokenizer = self.get_label_tokenizer() + output_size = self.get_output_size(self.label_tokenizer) + self.dnn.heads.head = nn.Linear(num_ftrs, output_size) + except: + raise SystemExit('PyTorch does not provide this model configration for Vision Transformer\ + \nThe candidate is one of \'b_16\', \'b_32\', \'l_16\', \'l_32\', \'h_14\'') + elif self.model_name == 'swin' or self.model_name == 'swin_v2': + model_size = model_parameters['model_size'] + supported_model_size = ['t', 's', 'b'] + try: + supported_model_size.index(model_size) + dnn_name = f"{self.model_name}_{model_size}" + self.dnn = models.__dict__[dnn_name](pretrained=pretrained) + num_ftrs = self.dnn.head.in_features + self.label_tokenizer = self.get_label_tokenizer() + output_size = self.get_output_size(self.label_tokenizer) + self.dnn.head = nn.Linear(num_ftrs, output_size) + except: + raise SystemExit('PyTorch does not provide this model size for Swin Transformer and Swin Transformer V2\ + \nThe candidate is one of \'t\', \'s\', \'b\'') + else: + raise SystemExit(f'ERROR: PyHealth does not currently include {self.model_name} model!') + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward propagation.""" + # concat the info within one batch (batch, channel, length) + x = kwargs[self.feature_keys[0]] + x = torch.stack(x, dim=0).to(self.device) + if x.shape[1] == 1: + x = x.repeat((1, 3, 1, 1)) + logits = self.dnn(x) + y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer) + loss = self.get_loss_function()(logits, y_true) + y_prob = self.prepare_y_prob(logits) + return { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + } + + +if __name__ == "__main__": + from pyhealth.datasets.utils import get_dataloader + from torchvision import transforms + from pyhealth.datasets import COVID19CXRDataset + + base_dataset = COVID19CXRDataset( + root="/home/wuzijian1231/Datasets/COVID-19_Radiography_Dataset", + ) + + sample_dataset = base_dataset.set_task() + + transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.Normalize(mean=[0.5862785803043838], std=[0.27950088968644304]) + ]) + def encode(sample): + sample["path"] = transform(sample["path"]) + return sample + + sample_dataset.set_transform(encode) + + train_loader = get_dataloader(sample_dataset, batch_size=16, shuffle=True) + + # model_parameters = {'name':'resnet', 'num_layers':18} + # model_parameters = {'name':'densenet', 'num_layers':121} + # model_parameters = {'name':'vit', 'model_config':'b_16'} + # model_parameters = {'name':'swin', 'model_size':'t'} + model_parameters = {'name':'swin_v2', 'model_size':'t'} + + model = TorchvisionClassification( + dataset=sample_dataset, + feature_keys=[ + "path", + ], + label_key="label", + mode="multiclass", + model_parameters=model_parameters, + ) + + # data batch + data_batch = next(iter(train_loader)) + + # try the model + ret = model(**data_batch) + print(ret) + + # try loss backward + ret["loss"].backward() \ No newline at end of file From 8055fef173ba1be7706ecdacfedbfe8b91dbb52e Mon Sep 17 00:00:00 2001 From: zzachw Date: Fri, 23 Jun 2023 12:37:41 -0400 Subject: [PATCH 11/16] update torchvision model --- pyhealth/models/__init__.py | 4 +- pyhealth/models/torchvision_classif.py | 214 ------------------------- pyhealth/models/torchvision_model.py | 179 +++++++++++++++++++++ 3 files changed, 181 insertions(+), 216 deletions(-) delete mode 100644 pyhealth/models/torchvision_classif.py create mode 100644 pyhealth/models/torchvision_model.py diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index e61ee653..0ebbec30 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -23,5 +23,5 @@ from .stagenet import StageNet, StageNetLayer from .tcn import TCN, TCNLayer from .molerec import MoleRec, MoleRecLayer -from .resnet import ResNet -from .huggingface import HuggingfaceAutoModel +from .torchvision_model import TorchvisionModel +from .transformers_model import TransformersModel diff --git a/pyhealth/models/torchvision_classif.py b/pyhealth/models/torchvision_classif.py deleted file mode 100644 index aec7ef20..00000000 --- a/pyhealth/models/torchvision_classif.py +++ /dev/null @@ -1,214 +0,0 @@ -from typing import List, Dict - -import torch -import torch.nn as nn -from torchvision import models - -from pyhealth.datasets.sample_dataset_v2 import SampleDataset -from pyhealth.models import BaseModel - - -class TorchvisionClassification(BaseModel): - """Torchvision model for image classification. - -----------------------------------ResNet--------------------------------------------- - Paper: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. - Deep Residual Learning for Image Recognition. CVPR 2016. - -----------------------------------DenseNet------------------------------------------- - Paper: Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger. - Densely Connected Convolutional Networks. CVPR 2017. - ----------------------------Vision Transformer (ViT)---------------------------------- - Paper: Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, - Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, - Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. - An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021. - ----------------------------Swin Transformer (and V2)--------------------------------- - Paper: Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, - Baining Guo. - Swin Transformer: Hierarchical Vision Transformer Using Shifted Windows. ICCV 2021. - Paper: Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, - Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. - Swin Transformer V2: Scaling Up Capacity and Resolution. CVPR 2022. - -------------------------------------------------------------------------------------- - Args: - dataset: the dataset to train the model. It is used to query certain - information such as the set of all tokens. - feature_keys: list of keys in samples to use as features, - e.g. ["conditions", "procedures"]. - label_key: key in samples to use as label (e.g., "drugs"). - mode: one of "binary", "multiclass", or "multilabel". - pretrained: whether to use pretrained weights. Default is False. - model_parameters: dict, {"name" : str, - "num_layers": int, - "model_size": str, - "patch_size": int} - - Note that for different models, the items in model_parameters vary! - model_parameters['name'] is one of "resnet", "densenet", "vit", "swin", "swin_v2" - For ResNet: - model_parameters = {"name": "resnet", - "num_layers": int} - "num_layers" is one of 18, 34, 50, 101, 152 - For DenseNet: - model_parameters = {"name": "densenet", - "num_layers": int} - "num_layers" is one of 121, 161, 169, 201 - For Vision Transformer: - model_parameters = {"name": "vit", - "model_config": str} - "model_config" is one of 'b_16', 'b_32', 'l_16', 'l_32', 'h_14' - For Swin Transformer: - model_parameters = {"name": "swin", - "model_size": str} - "model_config" is one of 't', 's', 'b' - For Swin Transformer V2: - model_parameters = {"name": "swin_v2", - "model_config": str} - "model_config" is one of 't', 's', 'b' - -------------------------------------------------------------------------------------- - Reference: - Torchvision: https://github.com/mlverse/torchvision - """ - - def __init__( - self, - dataset: SampleDataset, - feature_keys: List[str], - label_key: str, - mode: str, - model_parameters: dict, - pretrained=False, - **kwargs, - ): - super(TorchvisionClassification, self).__init__( - dataset=dataset, - feature_keys=feature_keys, - label_key=label_key, - mode=mode, - ) - self.model_name = model_parameters['name'] - - if self.model_name == 'resnet': - num_layers = model_parameters['num_layers'] - supported_num_layers = [18, 34, 50, 101, 152] - try: - supported_num_layers.index(num_layers) - dnn_name = f"{self.model_name}{num_layers}" - self.dnn = models.__dict__[dnn_name](pretrained=pretrained) - hidden_dim = self.dnn.fc.in_features - self.label_tokenizer = self.get_label_tokenizer() - output_size = self.get_output_size(self.label_tokenizer) - self.dnn.fc = nn.Linear(hidden_dim, output_size) - except: - raise SystemExit('PyTorch does not provide this number of learnable layers for ResNet\ - \nThe candidate number is one of 18, 34, 50, 101, 152') - elif self.model_name == 'densenet': - num_layers = model_parameters['num_layers'] - supported_num_layers = [121, 161, 169, 201] - try: - supported_num_layers.index(num_layers) - dnn_name = f"{self.model_name}{num_layers}" - self.dnn = models.__dict__[dnn_name](pretrained=pretrained) - num_ftrs = self.dnn.classifier.in_features - self.label_tokenizer = self.get_label_tokenizer() - output_size = self.get_output_size(self.label_tokenizer) - self.dnn.classifier = nn.Linear(num_ftrs, output_size) - except: - raise SystemExit('PyTorch does not provide this number of learnable layers for DenseNet\ - \nThe candidate number is one of 121, 161, 169, 201') - elif self.model_name == 'vit': - model_config = model_parameters['model_config'] - supported_model_config = ['b_16', 'b_32', 'l_16', 'l_32', 'h_14'] - try: - supported_model_config.index(model_config) - dnn_name = f"{self.model_name}_{model_config}" - self.dnn = models.__dict__[dnn_name](pretrained=pretrained) - num_ftrs = self.dnn.heads.head.in_features - self.label_tokenizer = self.get_label_tokenizer() - output_size = self.get_output_size(self.label_tokenizer) - self.dnn.heads.head = nn.Linear(num_ftrs, output_size) - except: - raise SystemExit('PyTorch does not provide this model configration for Vision Transformer\ - \nThe candidate is one of \'b_16\', \'b_32\', \'l_16\', \'l_32\', \'h_14\'') - elif self.model_name == 'swin' or self.model_name == 'swin_v2': - model_size = model_parameters['model_size'] - supported_model_size = ['t', 's', 'b'] - try: - supported_model_size.index(model_size) - dnn_name = f"{self.model_name}_{model_size}" - self.dnn = models.__dict__[dnn_name](pretrained=pretrained) - num_ftrs = self.dnn.head.in_features - self.label_tokenizer = self.get_label_tokenizer() - output_size = self.get_output_size(self.label_tokenizer) - self.dnn.head = nn.Linear(num_ftrs, output_size) - except: - raise SystemExit('PyTorch does not provide this model size for Swin Transformer and Swin Transformer V2\ - \nThe candidate is one of \'t\', \'s\', \'b\'') - else: - raise SystemExit(f'ERROR: PyHealth does not currently include {self.model_name} model!') - - def forward(self, **kwargs) -> Dict[str, torch.Tensor]: - """Forward propagation.""" - # concat the info within one batch (batch, channel, length) - x = kwargs[self.feature_keys[0]] - x = torch.stack(x, dim=0).to(self.device) - if x.shape[1] == 1: - x = x.repeat((1, 3, 1, 1)) - logits = self.dnn(x) - y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer) - loss = self.get_loss_function()(logits, y_true) - y_prob = self.prepare_y_prob(logits) - return { - "loss": loss, - "y_prob": y_prob, - "y_true": y_true, - } - - -if __name__ == "__main__": - from pyhealth.datasets.utils import get_dataloader - from torchvision import transforms - from pyhealth.datasets import COVID19CXRDataset - - base_dataset = COVID19CXRDataset( - root="/home/wuzijian1231/Datasets/COVID-19_Radiography_Dataset", - ) - - sample_dataset = base_dataset.set_task() - - transform = transforms.Compose([ - transforms.Resize((224, 224)), - transforms.Normalize(mean=[0.5862785803043838], std=[0.27950088968644304]) - ]) - def encode(sample): - sample["path"] = transform(sample["path"]) - return sample - - sample_dataset.set_transform(encode) - - train_loader = get_dataloader(sample_dataset, batch_size=16, shuffle=True) - - # model_parameters = {'name':'resnet', 'num_layers':18} - # model_parameters = {'name':'densenet', 'num_layers':121} - # model_parameters = {'name':'vit', 'model_config':'b_16'} - # model_parameters = {'name':'swin', 'model_size':'t'} - model_parameters = {'name':'swin_v2', 'model_size':'t'} - - model = TorchvisionClassification( - dataset=sample_dataset, - feature_keys=[ - "path", - ], - label_key="label", - mode="multiclass", - model_parameters=model_parameters, - ) - - # data batch - data_batch = next(iter(train_loader)) - - # try the model - ret = model(**data_batch) - print(ret) - - # try loss backward - ret["loss"].backward() \ No newline at end of file diff --git a/pyhealth/models/torchvision_model.py b/pyhealth/models/torchvision_model.py new file mode 100644 index 00000000..e93b7d84 --- /dev/null +++ b/pyhealth/models/torchvision_model.py @@ -0,0 +1,179 @@ +from typing import List, Dict + +import torch +import torch.nn as nn +import torchvision + +from pyhealth.datasets.sample_dataset_v2 import SampleDataset +from pyhealth.models import BaseModel + +SUPPORTED_MODELS = [ + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "densenet121", + "densenet161", + "densenet169", + "densenet201", + "vit_b_16", + "vit_b_32", + "vit_l_16", + "vit_l_32", + "vit_h_14", + "swin_t", + "swin_s", + "swin_b", +] + +SUPPORTED_MODELS_FINAL_LAYER = {} +for model in SUPPORTED_MODELS: + if "resnet" in model: + SUPPORTED_MODELS_FINAL_LAYER[model] = "fc" + elif "densenet" in model: + SUPPORTED_MODELS_FINAL_LAYER[model] = "classifier" + elif "vit" in model: + SUPPORTED_MODELS_FINAL_LAYER[model] = "heads" + elif "swin" in model: + SUPPORTED_MODELS_FINAL_LAYER[model] = "head" + else: + raise NotImplementedError + + +class TorchvisionModel(BaseModel): + """Models from PyTorch's torchvision package. + + This class is a wrapper for models from torchvision. It will automatically load + the corresponding model and weights from torchvision. The final layer will be + replaced with a linear layer with the correct output size. + + -----------------------------------ResNet------------------------------------------ + Paper: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Deep Residual Learning + for Image Recognition. CVPR 2016. + -----------------------------------DenseNet---------------------------------------- + Paper: Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger. + Densely Connected Convolutional Networks. CVPR 2017. + ----------------------------Vision Transformer (ViT)------------------------------- + Paper: Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, + Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, + Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. An Image is Worth + 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021. + ----------------------------Swin Transformer (and V2)------------------------------ + Paper: Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, + Baining Guo. Swin Transformer: Hierarchical Vision Transformer Using Shifted + Windows. ICCV 2021. + + Paper: Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, + Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. Swin Transformer V2: Scaling + Up Capacity and Resolution. CVPR 2022. + ----------------------------------------------------------------------------------- + + Args: + dataset: the dataset to train the model. It is used to query certain + information such as the set of all tokens. + feature_keys: list of keys in samples to use as features, e.g., ["image"]. + Only one feature is supported. + label_key: key in samples to use as label, e.g., "drugs". + mode: one of "binary", "multiclass", or "multilabel". + model_name: str, name of the model to use, e.g., "resnet18". + See SUPPORTED_MODELS in the source code for the full list. + model_config: dict, kwargs to pass to the model constructor, + e.g., {"weights": "DEFAULT"}. See the torchvision documentation for the + set of supported kwargs for each model. + ----------------------------------------------------------------------------------- + """ + + def __init__( + self, + dataset: SampleDataset, + feature_keys: List[str], + label_key: str, + mode: str, + model_name: str, + model_config: dict, + ): + super(TorchvisionModel, self).__init__( + dataset=dataset, + feature_keys=feature_keys, + label_key=label_key, + mode=mode, + ) + + self.model_name = model_name + self.model_config = model_config + + assert len(feature_keys) == 1, "Only one feature is supported!" + assert model_name in SUPPORTED_MODELS_FINAL_LAYER.keys(), \ + f"PyHealth does not currently include {model_name} model!" + + self.model = torchvision.models.get_model(model_name, **model_config) + final_layer_name = SUPPORTED_MODELS_FINAL_LAYER[model_name] + hidden_dim = getattr(self.model, final_layer_name).in_features + self.label_tokenizer = self.get_label_tokenizer() + output_size = self.get_output_size(self.label_tokenizer) + setattr(self.model, final_layer_name, nn.Linear(hidden_dim, output_size)) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward propagation.""" + # concat the info within one batch (batch, channel, length) + x = kwargs[self.feature_keys[0]] + x = torch.stack(x, dim=0).to(self.device) + if x.shape[1] == 1: + x = x.repeat((1, 3, 1, 1)) + logits = self.model(x) + y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer) + loss = self.get_loss_function()(logits, y_true) + y_prob = self.prepare_y_prob(logits) + return { + "loss": loss, + "y_prob": y_prob, + "y_true": y_true, + } + + +if __name__ == "__main__": + from pyhealth.datasets.utils import get_dataloader + from torchvision import transforms + from pyhealth.datasets import COVID19CXRDataset + + base_dataset = COVID19CXRDataset( + root="/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset", + ) + + sample_dataset = base_dataset.set_task() + + transform = transforms.Compose([ + transforms.Grayscale(), + transforms.Resize((224, 224)), + transforms.Normalize(mean=[0.5862785803043838], std=[0.27950088968644304]) + ]) + + + def encode(sample): + sample["path"] = transform(sample["path"]) + return sample + + + sample_dataset.set_transform(encode) + + train_loader = get_dataloader(sample_dataset, batch_size=16, shuffle=True) + + model = TorchvisionModel( + dataset=sample_dataset, + feature_keys=["path"], + label_key="label", + mode="multiclass", + model_name="swin_t", + model_config={"weights": "DEFAULT"}, + ) + + # data batch + data_batch = next(iter(train_loader)) + + # try the model + ret = model(**data_batch) + print(ret) + + # try loss backward + ret["loss"].backward() From 5b1cb90398c8d75aeac75f5f9fd7019280858446 Mon Sep 17 00:00:00 2001 From: zzachw Date: Wed, 12 Jul 2023 16:45:33 -0400 Subject: [PATCH 12/16] tmp commit --- .../medical_transcriptions_classification.py | 4 +-- .../{huggingface.py => transformers_model.py} | 25 +++++++------------ 2 files changed, 11 insertions(+), 18 deletions(-) rename pyhealth/models/{huggingface.py => transformers_model.py} (85%) diff --git a/examples/medical_transcriptions_classification.py b/examples/medical_transcriptions_classification.py index fb1fe324..041c855c 100644 --- a/examples/medical_transcriptions_classification.py +++ b/examples/medical_transcriptions_classification.py @@ -3,7 +3,7 @@ from pyhealth.datasets import MedicalTranscriptionsDataset from pyhealth.datasets import get_dataloader -from pyhealth.models import HuggingfaceAutoModel +from pyhealth.models import TransformersModel from pyhealth.trainer import Trainer root = "/srv/local/data/zw12/raw_data/MedicalTranscriptions" @@ -27,7 +27,7 @@ val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) -model = HuggingfaceAutoModel( +model = TransformersModel( model_name="emilyalsentzer/Bio_ClinicalBERT", dataset=sample_dataset, feature_keys=["transcription"], diff --git a/pyhealth/models/huggingface.py b/pyhealth/models/transformers_model.py similarity index 85% rename from pyhealth/models/huggingface.py rename to pyhealth/models/transformers_model.py index d810f854..cee6f10e 100644 --- a/pyhealth/models/huggingface.py +++ b/pyhealth/models/transformers_model.py @@ -1,30 +1,26 @@ from typing import List, Dict import torch -from transformers import AutoModel, AutoTokenizer import torch.nn as nn -from torchvision import models +from transformers import AutoModel, AutoTokenizer from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel -class HuggingfaceAutoModel(BaseModel): - """AutoModel class for Huggingface models. +class TransformersModel(BaseModel): + """Transformers class for Huggingface models. """ def __init__( self, - model_name: str, dataset: SampleDataset, feature_keys: List[str], label_key: str, mode: str, - pretrained=False, - num_layers=18, - **kwargs, + model_name: str, ): - super(HuggingfaceAutoModel, self).__init__( + super(TransformersModel, self).__init__( dataset=dataset, feature_keys=feature_keys, label_key=label_key, @@ -60,7 +56,6 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: if __name__ == "__main__": from pyhealth.datasets import MedicalTranscriptionsDataset, get_dataloader - from torchvision import transforms base_dataset = MedicalTranscriptionsDataset( root="/srv/local/data/zw12/raw_data/MedicalTranscriptions" @@ -70,14 +65,12 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: train_loader = get_dataloader(sample_dataset, batch_size=16, shuffle=True) - model = HuggingfaceAutoModel( - model_name="emilyalsentzer/Bio_ClinicalBERT", + model = TransformersModel( dataset=sample_dataset, - feature_keys=[ - "transcription", - ], + feature_keys=["transcription"], label_key="label", mode="multiclass", + model_name="emilyalsentzer/Bio_ClinicalBERT", ) # data batch @@ -88,4 +81,4 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: print(ret) # try loss backward - ret["loss"].backward() \ No newline at end of file + ret["loss"].backward() From 47ee96fcac10708bda533d2b2385e1a77bbf947d Mon Sep 17 00:00:00 2001 From: zzachw Date: Sat, 15 Jul 2023 19:51:56 -0400 Subject: [PATCH 13/16] wrap up this pr - add split by sample --- pyhealth/datasets/splitter.py | 51 ++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/pyhealth/datasets/splitter.py b/pyhealth/datasets/splitter.py index 3ffc1df8..ccfb7b14 100644 --- a/pyhealth/datasets/splitter.py +++ b/pyhealth/datasets/splitter.py @@ -38,9 +38,10 @@ def split_by_visit( np.random.shuffle(index) train_index = index[: int(len(dataset) * ratios[0])] val_index = index[ - int(len(dataset) * ratios[0]) : int(len(dataset) * (ratios[0] + ratios[1])) - ] - test_index = index[int(len(dataset) * (ratios[0] + ratios[1])) :] + int(len(dataset) * ratios[0]): int( + len(dataset) * (ratios[0] + ratios[1])) + ] + test_index = index[int(len(dataset) * (ratios[0] + ratios[1])):] train_dataset = torch.utils.data.Subset(dataset, train_index) val_dataset = torch.utils.data.Subset(dataset, val_index) test_dataset = torch.utils.data.Subset(dataset, test_index) @@ -75,9 +76,10 @@ def split_by_patient( np.random.shuffle(patient_indx) train_patient_indx = patient_indx[: int(num_patients * ratios[0])] val_patient_indx = patient_indx[ - int(num_patients * ratios[0]) : int(num_patients * (ratios[0] + ratios[1])) - ] - test_patient_indx = patient_indx[int(num_patients * (ratios[0] + ratios[1])) :] + int(num_patients * ratios[0]): int( + num_patients * (ratios[0] + ratios[1])) + ] + test_patient_indx = patient_indx[int(num_patients * (ratios[0] + ratios[1])):] train_index = list( chain(*[dataset.patient_to_index[i] for i in train_patient_indx]) ) @@ -87,3 +89,40 @@ def split_by_patient( val_dataset = torch.utils.data.Subset(dataset, val_index) test_dataset = torch.utils.data.Subset(dataset, test_index) return train_dataset, val_dataset, test_dataset + + +def split_by_sample( + dataset: SampleBaseDataset, + ratios: Union[Tuple[float, float, float], List[float]], + seed: Optional[int] = None, +): + """Splits the dataset by sample + + Args: + dataset: a `SampleBaseDataset` object + ratios: a list/tuple of ratios for train / val / test + seed: random seed for shuffling the dataset + + Returns: + train_dataset, val_dataset, test_dataset: three subsets of the dataset of + type `torch.utils.data.Subset`. + + Note: + The original dataset can be accessed by `train_dataset.dataset`, + `val_dataset.dataset`, and `test_dataset.dataset`. + """ + if seed is not None: + np.random.seed(seed) + assert sum(ratios) == 1.0, "ratios must sum to 1.0" + index = np.arange(len(dataset)) + np.random.shuffle(index) + train_index = index[: int(len(dataset) * ratios[0])] + val_index = index[ + int(len(dataset) * ratios[0]): int( + len(dataset) * (ratios[0] + ratios[1])) + ] + test_index = index[int(len(dataset) * (ratios[0] + ratios[1])):] + train_dataset = torch.utils.data.Subset(dataset, train_index) + val_dataset = torch.utils.data.Subset(dataset, val_index) + test_dataset = torch.utils.data.Subset(dataset, test_index) + return train_dataset, val_dataset, test_dataset From 6f5cf5382faa7d5851f6a3d24d93d9ceca91325b Mon Sep 17 00:00:00 2001 From: zzachw Date: Sat, 15 Jul 2023 19:53:14 -0400 Subject: [PATCH 14/16] update init --- pyhealth/datasets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 0c0a9565..b4401ccd 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -13,5 +13,5 @@ from .sample_dataset_v2 import SampleDataset from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset -from .splitter import split_by_patient, split_by_visit +from .splitter import split_by_patient, split_by_visit, split_by_sample from .utils import collate_fn_dict, get_dataloader, strptime From 19e5abd69901096bd435300e5d5bf0cec1543eb3 Mon Sep 17 00:00:00 2001 From: zzachw Date: Sat, 15 Jul 2023 20:10:13 -0400 Subject: [PATCH 15/16] wrap up the pr: import bug fix --- pyhealth/models/torchvision_model.py | 9 ++++++--- pyhealth/tasks/covid19_cxr_classification.py | 2 +- pyhealth/tasks/medical_transcriptions_classification.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pyhealth/models/torchvision_model.py b/pyhealth/models/torchvision_model.py index e93b7d84..6173de35 100644 --- a/pyhealth/models/torchvision_model.py +++ b/pyhealth/models/torchvision_model.py @@ -34,7 +34,7 @@ elif "densenet" in model: SUPPORTED_MODELS_FINAL_LAYER[model] = "classifier" elif "vit" in model: - SUPPORTED_MODELS_FINAL_LAYER[model] = "heads" + SUPPORTED_MODELS_FINAL_LAYER[model] = "heads.head" elif "swin" in model: SUPPORTED_MODELS_FINAL_LAYER[model] = "head" else: @@ -109,7 +109,10 @@ def __init__( self.model = torchvision.models.get_model(model_name, **model_config) final_layer_name = SUPPORTED_MODELS_FINAL_LAYER[model_name] - hidden_dim = getattr(self.model, final_layer_name).in_features + final_layer = self.model + for name in final_layer_name.split("."): + final_layer = getattr(final_layer, name) + hidden_dim = final_layer.in_features self.label_tokenizer = self.get_label_tokenizer() output_size = self.get_output_size(self.label_tokenizer) setattr(self.model, final_layer_name, nn.Linear(hidden_dim, output_size)) @@ -164,7 +167,7 @@ def encode(sample): feature_keys=["path"], label_key="label", mode="multiclass", - model_name="swin_t", + model_name="vit_b_16", model_config={"weights": "DEFAULT"}, ) diff --git a/pyhealth/tasks/covid19_cxr_classification.py b/pyhealth/tasks/covid19_cxr_classification.py index 25c494f7..bbb5b364 100644 --- a/pyhealth/tasks/covid19_cxr_classification.py +++ b/pyhealth/tasks/covid19_cxr_classification.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from typing import Dict -from pyhealth.tasks import TaskTemplate +from pyhealth.tasks.task_template import TaskTemplate @dataclass(frozen=True) diff --git a/pyhealth/tasks/medical_transcriptions_classification.py b/pyhealth/tasks/medical_transcriptions_classification.py index 2f7fa5b3..1dbcff12 100644 --- a/pyhealth/tasks/medical_transcriptions_classification.py +++ b/pyhealth/tasks/medical_transcriptions_classification.py @@ -2,7 +2,7 @@ from typing import Dict import pandas as pd -from pyhealth.tasks import TaskTemplate +from pyhealth.tasks.task_template import TaskTemplate @dataclass(frozen=True) From ec3c1aa4b709cc166ca1ad074ef08cfb3ec78693 Mon Sep 17 00:00:00 2001 From: zzachw Date: Sat, 15 Jul 2023 20:14:20 -0400 Subject: [PATCH 16/16] wrap up the pr: bug fix --- pyhealth/models/torchvision_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/models/torchvision_model.py b/pyhealth/models/torchvision_model.py index 6173de35..a6d66e15 100644 --- a/pyhealth/models/torchvision_model.py +++ b/pyhealth/models/torchvision_model.py @@ -115,7 +115,7 @@ def __init__( hidden_dim = final_layer.in_features self.label_tokenizer = self.get_label_tokenizer() output_size = self.get_output_size(self.label_tokenizer) - setattr(self.model, final_layer_name, nn.Linear(hidden_dim, output_size)) + setattr(self.model, final_layer_name.split(".")[0], nn.Linear(hidden_dim, output_size)) def forward(self, **kwargs) -> Dict[str, torch.Tensor]: """Forward propagation."""