From 293dc69bf4b167ee49259cc279d11eb6b26ee5a3 Mon Sep 17 00:00:00 2001 From: samarthkeshari Date: Mon, 17 Apr 2023 22:22:00 -0500 Subject: [PATCH 01/11] added imagecaption dataset + reportgeneration task --- .../datasets/base_image_caption_dataset.py | 145 ++++++++++++++++++ pyhealth/tasks/xray_report_generation.py | 18 +++ 2 files changed, 163 insertions(+) create mode 100644 pyhealth/datasets/base_image_caption_dataset.py create mode 100644 pyhealth/tasks/xray_report_generation.py diff --git a/pyhealth/datasets/base_image_caption_dataset.py b/pyhealth/datasets/base_image_caption_dataset.py new file mode 100644 index 00000000..24799430 --- /dev/null +++ b/pyhealth/datasets/base_image_caption_dataset.py @@ -0,0 +1,145 @@ +import logging +import os +from abc import ABC +from typing import Optional, Callable + +import pandas as pd +from tqdm import tqdm + +#from pyhealth.datasets.sample_dataset import SampleImageDataset + +logger = logging.getLogger(__name__) + +INFO_MSG = """ +dataset.patients: + - key: patient id + - value: a dict of image paths, captions, and other information +""" + + +class BaseImageCaptionGenerationDataset(ABC): + """Abstract base Image Caption Generation dataset class. + + This abstract class defines a uniform interface for all + image caption generation datasets. + + Each specific dataset will be a subclass of this abstract class, which can then + be converted to samples dataset for different tasks by calling `self.set_task()`. + + Args: + root: root directory of the raw data (should contain many csv files). + dataset_name: name of the dataset. Default is the name of the class. + dev: whether to enable dev mode (only use a small subset of the data). + Default is False. + refresh_cache: whether to refresh the cache; if true, the dataset will + be processed from scratch and the cache will be updated. Default is False. + """ + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + dev: bool = False, + refresh_cache: bool = False, + ): + # base attributes + self.dataset_name = ( + self.__class__.__name__ if dataset_name is None else dataset_name + ) + self.root = root + # TODO: dev seems unnecessary for image and signal? + self.dev = dev + if dev: + logger.warning("WARNING: dev has no effect for image caption generation datasets.") + # TODO: refresh_cache seems unnecessary for image and signal? + self.refresh_cache = refresh_cache + if refresh_cache: + logger.warning("WARNING: refresh_cache has no effect for image caption generation datasets.") + + self.metadata = pd.read_json(os.path.join(root, "metadata.jsonl"), lines=True) + #self.metadata["path"] = self.metadata["path"].apply( + # lambda x: os.path.join(root, x) + #) + if "patient_id" not in self.metadata.columns: + # no patient_id in metadata, sequentially assign patient_id + self.metadata["patient_id"] = self.metadata.index + + # group by patient_id + self.patients = dict() + for patient_id, group in self.metadata.groupby("patient_id"): + self.patients[patient_id] = group.to_dict(orient="records") + + return + + def __len__(self): + return len(self.patients) + + def __str__(self): + """Prints some information of the dataset.""" + return f"Base dataset {self.dataset_name}" + + def stat(self) -> str: + """Returns some statistics of the base dataset.""" + lines = list() + lines.append("") + lines.append(f"Statistics of base dataset (dev={self.dev}):") + lines.append(f"\t- Dataset: {self.dataset_name}") + lines.append(f"\t- Number of images: {len(self)}") + lines.append("") + print("\n".join(lines)) + return "\n".join(lines) + + @staticmethod + def info(): + """Prints the output format.""" + print(INFO_MSG) + + def set_task( + self, + task_fn: Callable, + task_name: Optional[str] = None, + ): #-> SampleImageDataset: + """Processes the base dataset to generate the task-specific sample dataset. + + This function should be called by the user after the base dataset is + initialized. It will iterate through all patients in the base dataset + and call `task_fn` which should be implemented by the specific task. + + Args: + task_fn: a function that takes a single patient and returns a + list of samples (each sample is a dict with patient_id, visit_id, + and other task-specific attributes as key). The samples will be + concatenated to form the sample dataset. + task_name: the name of the task. If None, the name of the task + function will be used. + + Returns: + sample_dataset: the task-specific sample (Base) dataset. + + Note: + In `task_fn`, a patient may have one or multiple images associated + to a caption, for e.g. a +patient can have single report associated + to multiple xrays from diffrent views that may be combined to have + a single sample ({'patient_id':1, 'frontal_image':'img1', + 'lateral_image': 'img2','report':'text}) + Patients can also be excluded from the task dataset by returning + an empty list. + """ + if task_name is None: + task_name = task_fn.__name__ + + # load from raw data + logger.debug(f"Processing {self.dataset_name} base dataset...") + + samples = [] + for patient_id, patient in tqdm( + self.patients.items(), desc=f"Generating samples for {task_name}"): + samples.extend(task_fn(patient)) + + #sample_dataset = SampleImageDataset( + # samples, + # dataset_name=self.dataset_name, + # task_name=task_name, + #) + #return sample_dataset + return samples \ No newline at end of file diff --git a/pyhealth/tasks/xray_report_generation.py b/pyhealth/tasks/xray_report_generation.py new file mode 100644 index 00000000..c42b88bd --- /dev/null +++ b/pyhealth/tasks/xray_report_generation.py @@ -0,0 +1,18 @@ +import os + +def biview_onesent_fn(patient): + """ Processes single patient for xray report generation""" + sample = {} + sample['frontal_image_path'] = None + sample['lateral_image_path'] = None + img_root = '/srv/local/data/IU_XRay/images/images_normalized' + + for data in patient: + sample['patient_id'] = data['patient_id'] + sample['report'] = [data['impression']+data['findings']] + if data['view'] == 'frontal': + sample['frontal_image_path'] = os.path.join(img_root, data['path']) + if data['view'] == 'lateral': + sample['lateral_image_path'] = os.path.join(img_root, data['path']) + patient = [sample] + return patient \ No newline at end of file From ae2d25821b2d6ce9659fabdcceda349fa76e7528 Mon Sep 17 00:00:00 2001 From: samarthkeshari Date: Mon, 17 Apr 2023 22:55:10 -0500 Subject: [PATCH 02/11] updated __init__ files --- pyhealth/datasets/__init__.py | 1 + pyhealth/datasets/base_image_caption_dataset.py | 2 +- pyhealth/tasks/__init__.py | 3 +++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 650178e0..c64f16e0 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -1,4 +1,5 @@ from .base_ehr_dataset import BaseEHRDataset +from .base_image_caption_dataset import BaseImageCaptionDataset from .base_signal_dataset import BaseSignalDataset from .eicu import eICUDataset from .mimic3 import MIMIC3Dataset diff --git a/pyhealth/datasets/base_image_caption_dataset.py b/pyhealth/datasets/base_image_caption_dataset.py index 24799430..755d22d5 100644 --- a/pyhealth/datasets/base_image_caption_dataset.py +++ b/pyhealth/datasets/base_image_caption_dataset.py @@ -17,7 +17,7 @@ """ -class BaseImageCaptionGenerationDataset(ABC): +class BaseImageCaptionDataset(ABC): """Abstract base Image Caption Generation dataset class. This abstract class defines a uniform interface for all diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 6d7ffb02..b8207478 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -26,3 +26,6 @@ sleep_staging_sleepedf_fn, sleep_staging_isruc_fn ) +from .xray_report_generation import ( + biview_onesent_fn +) From ac5e583d72ad8e370823dcb92d0f338ad73d0cf8 Mon Sep 17 00:00:00 2001 From: samarthkeshari Date: Sat, 22 Apr 2023 10:38:12 -0500 Subject: [PATCH 03/11] added task, wordsat modell and enhanced trainer --- test.py | 111 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 test.py diff --git a/test.py b/test.py new file mode 100644 index 00000000..9262d93c --- /dev/null +++ b/test.py @@ -0,0 +1,111 @@ +import pickle +from torchvision import transforms +from pyhealth.datasets import BaseImageCaptionDataset +from pyhealth.tasks.xray_report_generation import biview_multisent_fn +from pyhealth.datasets import split_by_patient, get_dataloader +from pyhealth.tokenizer import Tokenizer +from pyhealth.models import WordSAT +from pyhealth.trainer import Trainer +from pyhealth.datasets.utils import list_nested_levels, flatten_list + +import torch +from collections import OrderedDict +def extract_state_dict(chkpt_pth): + checkpoint = torch.load(chkpt_pth + 'model_ones_3epoch_densenet.tar') + new_state_dict = OrderedDict() + for k,v in checkpoint['state_dict'].items(): + if 'classifier' in k: continue + if k[:7] == 'module.' : + name = k[7:] + else: + name = k + + name = name.replace('classifier.0.','classifier.') + new_state_dict[name] = v + return new_state_dict + +chkpt_pth = '/home/keshari2/ChestXrayReporting/IU_XRay/src/models/pretrained/' +state_dict = extract_state_dict(chkpt_pth) + +##### +def seed_everything(seed: int): + import random, os + import numpy as np + import torch + + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + +seed_everything(42) + +root = '/home/keshari2/ChestXrayReporting/IU_XRay/src/data' +sample_dataset = BaseImageCaptionDataset(root=root,dataset_name='IU_XRay') +sample_dataset = sample_dataset.set_task(biview_multisent_fn) +transform = transforms.Compose([ + transforms.RandomAffine(degrees=30), + transforms.Resize((512,512)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485,0.456,0.406], + std=[0.229,0.224,0.225]), + ]) +sample_dataset.set_transform(transform) + +""" +special_tokens = ['','','',''] +tokenizer = Tokenizer( + sample_dataset.get_all_tokens(key='caption'), + special_tokens=special_tokens, + ) + +with open(root+'/pyhealth_tokenizer.pkl', 'wb') as f: + pickle.dump(tokenizer, f) +""" +with open(root+'/pyhealth_tokenizer.pkl', 'rb') as f: + tokenizer = pickle.load(f) +#print(sample_dataset[0]['caption']) + +train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset,[0.8,0.1,0.1] +) + +train_dataloader = get_dataloader(train_dataset,batch_size=8,shuffle=True) +val_dataloader = get_dataloader(val_dataset,batch_size=1,shuffle=False) +test_dataloader = get_dataloader(test_dataset,batch_size=1,shuffle=False) + +print(len(train_dataset),len(val_dataset),len(test_dataset)) + +model=WordSAT(dataset=sample_dataset, + feature_keys=['image_1','image_2'], + label_key='caption', + tokenizer=tokenizer, + mode='sequence', + encoder_pretrained_weights=state_dict, + save_generated_caption = True + ) +#model.eval() +#data = next(iter(val_dataloader)) +#print(model(**data)) +""" +output_path = '/home/keshari2/ChestXrayReporting/IU_XRay/src/output/pyhealth' +ckpt_path = '/home/keshari2/ChestXrayReporting/IU_XRay/src/output/pyhealth/20230422-005011/best.ckpt' + +trainer = Trainer( + model=model, + output_path=output_path, + checkpoint_path = ckpt_path + ) +trainer.train( + train_dataloader = train_dataloader, + val_dataloader = val_dataloader, + optimizer_params = {"lr": 1e-4}, + weight_decay = 1e-5, + max_grad_norm = 1, + epochs = 5, + monitor = 'Bleu_1' +) +""" \ No newline at end of file From d690419b75bfe3180512fa61e3a377ab2ba2ce8d Mon Sep 17 00:00:00 2001 From: samarthkeshari Date: Sat, 22 Apr 2023 10:38:40 -0500 Subject: [PATCH 04/11] added task, wordsat model and enhanced trainer --- pyhealth/datasets/__init__.py | 3 +- .../datasets/base_image_caption_dataset.py | 17 +- pyhealth/datasets/sample_dataset.py | 91 ++++++ pyhealth/metrics/__init__.py | 1 + pyhealth/metrics/sequence.py | 31 ++ pyhealth/models/__init__.py | 1 + pyhealth/models/base_model.py | 5 +- pyhealth/models/wordsat.py | 275 ++++++++++++++++++ pyhealth/tasks/__init__.py | 2 +- pyhealth/tasks/xray_report_generation.py | 45 ++- pyhealth/trainer.py | 37 ++- 11 files changed, 482 insertions(+), 26 deletions(-) create mode 100644 pyhealth/metrics/sequence.py create mode 100644 pyhealth/models/wordsat.py diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index c64f16e0..28ef3a71 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -8,6 +8,7 @@ from .sleepedf import SleepEDFDataset from .isruc import ISRUCDataset from .shhs import SHHSDataset -from .sample_dataset import SampleBaseDataset, SampleSignalDataset, SampleEHRDataset +from .sample_dataset import SampleBaseDataset, SampleSignalDataset, SampleEHRDataset,\ + SampleImageCaptionDataset from .splitter import split_by_patient, split_by_visit from .utils import collate_fn_dict, get_dataloader, strptime diff --git a/pyhealth/datasets/base_image_caption_dataset.py b/pyhealth/datasets/base_image_caption_dataset.py index 755d22d5..2e472100 100644 --- a/pyhealth/datasets/base_image_caption_dataset.py +++ b/pyhealth/datasets/base_image_caption_dataset.py @@ -6,7 +6,7 @@ import pandas as pd from tqdm import tqdm -#from pyhealth.datasets.sample_dataset import SampleImageDataset +from pyhealth.datasets.sample_dataset import SampleImageCaptionDataset logger = logging.getLogger(__name__) @@ -98,7 +98,7 @@ def set_task( self, task_fn: Callable, task_name: Optional[str] = None, - ): #-> SampleImageDataset: + ) -> SampleImageCaptionDataset: """Processes the base dataset to generate the task-specific sample dataset. This function should be called by the user after the base dataset is @@ -136,10 +136,9 @@ def set_task( self.patients.items(), desc=f"Generating samples for {task_name}"): samples.extend(task_fn(patient)) - #sample_dataset = SampleImageDataset( - # samples, - # dataset_name=self.dataset_name, - # task_name=task_name, - #) - #return sample_dataset - return samples \ No newline at end of file + sample_dataset = SampleImageCaptionDataset( + samples, + dataset_name=self.dataset_name, + task_name=task_name, + ) + return sample_dataset \ No newline at end of file diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 5c83c0da..95d6c1e9 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -2,7 +2,9 @@ from typing import Dict, List import pickle +from PIL import Image from torch.utils.data import Dataset +from torchvision import transforms from pyhealth.datasets.utils import list_nested_levels, flatten_list @@ -498,6 +500,95 @@ def stat(self) -> str: print("\n".join(lines)) return "\n".join(lines) +class SampleImageCaptionDataset(SampleBaseDataset): + """Sample image caption dataset class. + + This class the takes a list of samples as input (either from + `BaseDataset.set_task()` or user-provided input), and provides + a uniform interface for accessing the samples. + + Args: + samples: a list of samples + dataset_name: the name of the dataset. Default is None. + task_name: the name of the task. Default is None. + """ + + def __init__( + self, + samples: List[Dict], + dataset_name: str = "", + task_name: str = "" + ): + super().__init__(samples, dataset_name, task_name) + self.patient_to_index: Dict[str, List[int]] = self._index_patient() + self.type_ = "image_text" + self.input_info: Dict = self._validate() + self.img_transforms = transforms.Compose([ + transforms.ToTensor(), + ]) + + def set_transform(self, img_transforms): + self.img_transforms = img_transforms + + def _index_patient(self) -> Dict[str, List[int]]: + """Helper function which indexes the samples by patient_id. + + Will be called in `self.__init__()`. + Returns: + patient_to_index: Dict[str, int], a dict mapping patient_id to a list + of sample indices. + """ + patient_to_index = {} + for idx, sample in enumerate(self.samples): + patient_to_index.setdefault(sample["patient_id"], []).append(idx) + return patient_to_index + + def _validate(self) -> Dict: + """Helper function which gets the input information of each attribute. + + Will be called in `self.__init__()`. + + Returns: + input_info: Dict, a dict whose keys are the same as the keys in the + samples, and values are the corresponding input information + """ + input_info = {} + # get info + input_info["image_path"] = {"type": str, "dim": 2} + input_info["caption"] = {"type": str, "dim": 3} + return input_info + + def __getitem__(self, index) -> Dict: + """Returns a sample by index. + + Returns: + Dict, a dict with patient_id, image_{number}, caption, and other task-specific + attributes as key. Conversion of caption to index/tensor will be done + in the model. + """ + sample = self.samples[index] + for i in range(len(sample['image_path'])): + image_key = f'image_{i+1}' + image = Image.open(sample["image_path"][i]).convert("RGB") + image = self.img_transforms(image) + sample[image_key] = image + return sample + + def stat(self) -> str: + """Returns some statistics of the task-specific dataset.""" + lines = list() + lines.append(f"Statistics of sample dataset:") + lines.append(f"\t- Dataset: {self.dataset_name}") + lines.append(f"\t- Task: {self.task_name}") + lines.append(f"\t- Number of samples: {len(self)}") + num_patients = len(set([sample["patient_id"] for sample in self.samples])) + lines.append(f"\t- Number of patients: {num_patients}") + lines.append(f"\t- Number of records: {len(self)}") + lines.append( + f"\t- Number of samples per patient: {len(self) / num_patients:.4f}" + ) + print("\n".join(lines)) + return "\n".join(lines) if __name__ == "__main__": samples = [ diff --git a/pyhealth/metrics/__init__.py b/pyhealth/metrics/__init__.py index c07c63e4..15e5cb69 100644 --- a/pyhealth/metrics/__init__.py +++ b/pyhealth/metrics/__init__.py @@ -2,3 +2,4 @@ from .drug_recommendation import ddi_rate_score from .multiclass import multiclass_metrics_fn from .multilabel import multilabel_metrics_fn +from .sequence import sequence_metrics_fn diff --git a/pyhealth/metrics/sequence.py b/pyhealth/metrics/sequence.py new file mode 100644 index 00000000..d59ee693 --- /dev/null +++ b/pyhealth/metrics/sequence.py @@ -0,0 +1,31 @@ +from typing import Dict, List, Optional + +from collections import OrderedDict +from pycocoevalcap.bleu.bleu import Bleu +from pycocoevalcap.rouge.rouge import Rouge +from pycocoevalcap.cider.cider import Cider +from pycocoevalcap.meteor.meteor import Meteor + +def sequence_metrics_fn( + y_true: Dict[int,str], + y_generated: Dict[int,str], + metrics: Optional[List[str]] = None +) -> Dict[str, float]: + """ + """ + scorers = [ + (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), + (Meteor(), "METEOR"), + (Rouge(), "ROUGE_L"), + (Cider(), "CIDEr"), + ] + output = {} + for scorer, method in scorers: + score, scores = scorer.compute_score(y_true, y_generated) + if type(score) == list: + for m, s in zip(method, score): + output[m] = s + else: + output[method] = score + + return output \ No newline at end of file diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index af8c8192..4ea4ec84 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -22,3 +22,4 @@ from .grasp import GRASP, GRASPLayer from .stagenet import StageNet, StageNetLayer from .tcn import TCN, TCNLayer +from .wordsat import WordSAT \ No newline at end of file diff --git a/pyhealth/models/base_model.py b/pyhealth/models/base_model.py index a6d844ff..6ee6c25d 100644 --- a/pyhealth/models/base_model.py +++ b/pyhealth/models/base_model.py @@ -10,7 +10,7 @@ from pyhealth.tokenizer import Tokenizer # TODO: add support for regression -VALID_MODE = ["binary", "multiclass", "multilabel"] +VALID_MODE = ["binary", "multiclass", "multilabel","sequence"] class BaseModel(ABC, nn.Module): @@ -232,6 +232,7 @@ def get_loss_function(self) -> Callable: - binary: `F.binary_cross_entropy_with_logits` - multiclass: `F.cross_entropy` - multilabel: `F.binary_cross_entropy_with_logits` + - sequence: `F.cross_entropy` Returns: The default loss function. @@ -242,6 +243,8 @@ def get_loss_function(self) -> Callable: return F.cross_entropy elif self.mode == "multilabel": return F.binary_cross_entropy_with_logits + elif self.mode == "sequence": + return F.cross_entropy else: raise ValueError("Invalid mode: {}".format(self.mode)) diff --git a/pyhealth/models/wordsat.py b/pyhealth/models/wordsat.py new file mode 100644 index 00000000..2407067c --- /dev/null +++ b/pyhealth/models/wordsat.py @@ -0,0 +1,275 @@ +from typing import List, Tuple, Dict, Optional + +import torch +import torch.nn as nn +from torchvision import models +import torch.nn.functional as F + +from pyhealth.datasets import SampleImageCaptionDataset +from pyhealth.models import BaseModel +from pyhealth.datasets.utils import flatten_list + +class WordSATEncoder(nn.Module): + """ + """ + def __init__(self): + super().__init__() + self.densenet121 = models.densenet121(weights='DEFAULT') + self.densenet121.classifier = nn.Identity() + + def forward(self, x): + x = self.densenet121.features(x) + x = F.relu(x) + return x + +class Attention(nn.Module): + """ + """ + def __init__(self, k_size, v_size, affine_size=512): + super().__init__() + self.affine_k = nn.Linear(k_size, affine_size, bias=False) + self.affine_v = nn.Linear(v_size, affine_size, bias=False) + self.affine = nn.Linear(affine_size, 1, bias=False) + + def forward(self, k, v): + # k: batch size x hidden size + # v: batch size x spatial size x hidden size + # z: batch size x spatial size + # TODO other ways of attention? + content_v = self.affine_k(k).unsqueeze(1) + self.affine_v(v) + z = self.affine(torch.tanh(content_v)).squeeze(2) + alpha = torch.softmax(z, dim=1) + context = (v * alpha.unsqueeze(2)).sum(dim=1) + return context, alpha + +class WordSATDecoder(nn.Module): + """ + """ + def __init__( + self, + vocab_size: int, + n_encoder_inputs: int, + feature_dim: int, + embedding_dim: int, + hidden_dim: int, + dropout: int = 0.5 + ): + super().__init__() + self.vocab_size = vocab_size + self.n_encoder_inputs = n_encoder_inputs + self.feature_dim = feature_dim + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + self.dropout = dropout + + self.atten = Attention(self.hidden_dim, self.feature_dim) + self.embed = nn.Embedding(self.vocab_size, self.embedding_dim) + self.init_h = nn.Linear(self.n_encoder_inputs * self.feature_dim, + self.hidden_dim) + self.init_c = nn.Linear(self.n_encoder_inputs * feature_dim, hidden_dim) + self.lstmcell = nn.LSTMCell(self.embedding_dim + + self.n_encoder_inputs * feature_dim, + hidden_dim) + self.fc = nn.Linear(self.hidden_dim, self.vocab_size) + self.dropout = nn.Dropout(self.dropout) + + def forward(self, cnn_features, captions=None, max_len=100): + batch_size = cnn_features[0].size(0) + if captions is not None: + seq_len = captions.size(1) + else: + seq_len = max_len + + cnn_feats_t = [ cnn_feat.view(batch_size, self.feature_dim, -1) \ + .permute(0, 2, 1) + for cnn_feat in cnn_features ] + global_feats = [cnn_feat.mean(dim=(2, 3)) for cnn_feat in cnn_features] + + h = self.init_h(torch.cat(global_feats, dim=1)) + c = self.init_c(torch.cat(global_feats, dim=1)) + + logits = cnn_features[0].new_zeros((batch_size, + seq_len, + self.vocab_size), dtype=torch.float) + + if captions is not None: + embeddings = self.embed(captions) + for t in range(seq_len): + contexts = [self.atten(h, cnn_feat_t)[0] + for cnn_feat_t in cnn_feats_t] + context = torch.cat(contexts, dim=1) + h, c = self.lstmcell(torch.cat((embeddings[:, t], context), + dim=1), + (h, c)) + logits[:, t] = self.fc(self.dropout(h)) + + return logits + + else: + x_t = cnn_features[0].new_full((batch_size,), 1, dtype=torch.long) + for t in range(seq_len): + embedding = self.embed(x_t) + contexts = [self.atten(h, cnn_feat_t)[0] + for cnn_feat_t in cnn_feats_t] + context = torch.cat(contexts, dim=1) + h, c = self.lstmcell(torch.cat((embedding, context), dim=1), + (h, c)) + logit =self.fc(h) + x_t = logit.argmax(dim=1) + logits[:, t] = logit + + return logits.argmax(dim=2) + +class WordSAT(BaseModel): + """Word Show Attend & Tell model. + Argument list of class + """ + + def __init__( + self, + dataset: SampleImageCaptionDataset, + feature_keys: List[str], + label_key: str, + tokenizer: object, + mode: str, + encoder_pretrained_weights: object = None, + encoder_freeze_weights: bool = True, + decoder_maxlen: int = 100, + decoder_embed_dim: int = 256, + decoder_hidden_dim: int = 512, + decoder_feature_dim: int = 1024, + decoder_dropout: float = 0.5, + save_generated_caption: bool = False, + **kwargs + ): + super(WordSAT, self).__init__( + dataset=dataset, + feature_keys=feature_keys, + label_key=label_key, + mode=mode, + ) + + self.encoder = WordSATEncoder() + self.save_generated_caption = save_generated_caption + + if encoder_pretrained_weights: + print(f'Loading encoder pretrained model') + self.encoder.load_state_dict(encoder_pretrained_weights) + + if encoder_freeze_weights: + self.encoder.eval() + + self.decoder_maxlen = decoder_maxlen + self.caption_tokenizer = tokenizer + vocab_size = self.caption_tokenizer.get_vocabulary_size() + self.decoder = WordSATDecoder( + vocab_size, + len(feature_keys), + decoder_feature_dim, + decoder_embed_dim, + decoder_hidden_dim, + decoder_dropout + ) + + def _prepare_batch_images(self,kwargs): + """Prepare images for input. + """ + print(self.n_input_image) + + return images + + def _prepare_batch_captions(self,captions): + """Prepare caption for input. + """ + samples = [] + for caption in captions: + tokens = [] + tokens.extend(flatten_list(caption)) + text = ' '.join(tokens).replace('. .','.') + samples.append([text.split()]) + #print(caption) + x = self.caption_tokenizer.batch_encode_3d(samples) + captions = torch.tensor(x, dtype=torch.long, device=self.device) + masks = torch.sum(captions,dim=1) !=0 + captions = captions.squeeze(1) + + return captions,masks + + def forward(self, **kwargs): + """Forward propagation. + """ + patient_ids = kwargs['patient_id'] + + image_features = [ feature for feature in self.feature_keys + if 'image_' in feature] + + images = [ torch.stack(kwargs[image_feature], 0) + for image_feature in image_features + ] + + cnn_features = [self.encoder(image.to(self.device)) for image in images] + output = {} + if self.training: + captions,masks = self._prepare_batch_captions(kwargs[self.label_key]) + logits = self.decoder(cnn_features, captions[:,:-1], + self.decoder_maxlen) + logits = logits.permute(0, 2, 1).contiguous() + captions = captions[:, 1:].contiguous() + masks = masks[:, 1:].contiguous() + + loss = self.get_loss_function()(logits, captions) + loss = loss.masked_select(masks).mean() + + output["loss"] = loss + else: + output["y_generated"] = self._forward_inference(patient_ids, + cnn_features + ) + output["y_true"] = self._forward_get_ground_truths(patient_ids, + kwargs[self.label_key] + ) + return output + + def _forward_inference(self,patient_ids,cnn_features): + """ + """ + generated_results = {} + for idx, patient_id in enumerate(patient_ids): + generated_results[patient_id] = [""] + cnn_feature = [cnn_feat[idx].unsqueeze(0) + for cnn_feat in cnn_features] + pred = self.decoder(cnn_feature, None, self.decoder_maxlen)[0] + pred = pred.detach().cpu() + pred_tokens = self.caption_tokenizer \ + .convert_indices_to_tokens(pred.tolist()) + generated_results[patient_id] = [""] + words = [] + for token in pred_tokens: + if token == '' or token == '': + continue + if token == '': + break + words.append(token) + + generated_results[patient_id][0] = " ".join(words) + + return generated_results + + def _forward_get_ground_truths(self,patient_ids,captions): + """ + """ + ground_truths = {} + for idx, caption in enumerate(captions): + ground_truths[patient_ids[idx]] = [""] + tokens = [] + tokens.extend(flatten_list(caption)) + ground_truths[patient_ids[idx]][0] = ' '.join(tokens) \ + .replace('. .','.') \ + .replace("","") \ + .replace("","") \ + .strip() + + return ground_truths + + + diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index b8207478..e98057c4 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -27,5 +27,5 @@ sleep_staging_isruc_fn ) from .xray_report_generation import ( - biview_onesent_fn + biview_multisent_fn ) diff --git a/pyhealth/tasks/xray_report_generation.py b/pyhealth/tasks/xray_report_generation.py index c42b88bd..ee1a8677 100644 --- a/pyhealth/tasks/xray_report_generation.py +++ b/pyhealth/tasks/xray_report_generation.py @@ -1,18 +1,39 @@ import os +import string -def biview_onesent_fn(patient): +def biview_multisent_fn(patient): """ Processes single patient for xray report generation""" sample = {} - sample['frontal_image_path'] = None - sample['lateral_image_path'] = None - img_root = '/srv/local/data/IU_XRay/images/images_normalized' + image_path = [] + img_root = "/srv/local/data/IU_XRay/images/images_normalized" for data in patient: - sample['patient_id'] = data['patient_id'] - sample['report'] = [data['impression']+data['findings']] - if data['view'] == 'frontal': - sample['frontal_image_path'] = os.path.join(img_root, data['path']) - if data['view'] == 'lateral': - sample['lateral_image_path'] = os.path.join(img_root, data['path']) - patient = [sample] - return patient \ No newline at end of file + patient_id = data["patient_id"] + if data["view"] == "frontal": + image_path.insert(0,os.path.join(img_root, data["path"])) + if data["view"] == "lateral": + image_path.append(os.path.join(img_root, data["path"])) + + impression = data["impression"] + findings = data["findings"] + report = f"{impression} . {findings}" + + sample["patient_id"] = patient_id + sample["image_path"] = image_path + + sents = report.lower().split(".") + sents = [sent for sent in sents if len(sent.strip()) > 1] + sample["caption"] = [] + for isent, sent in enumerate(sents): + tokens = sent.translate(str.maketrans("", "", string.punctuation)) \ + .strip() \ + .split() + sample["caption"].append([".", *[token for token in tokens],"."]) + + if sample["caption"] == []: + sample["caption"] = [[" "," "]] + + sample["caption"][0][0] = "" + sample["caption"][-1].append("") + + return [sample] \ No newline at end of file diff --git a/pyhealth/trainer.py b/pyhealth/trainer.py index fd290fc0..d901f07c 100644 --- a/pyhealth/trainer.py +++ b/pyhealth/trainer.py @@ -12,7 +12,7 @@ from tqdm.autonotebook import trange from pyhealth.metrics import (binary_metrics_fn, multiclass_metrics_fn, - multilabel_metrics_fn) + multilabel_metrics_fn,sequence_metrics_fn) from pyhealth.utils import create_directory logger = logging.getLogger(__name__) @@ -44,6 +44,8 @@ def get_metrics_fn(mode: str) -> Callable: return multiclass_metrics_fn elif mode == "multilabel": return multilabel_metrics_fn + elif mode == "sequence": + return sequence_metrics_fn else: raise ValueError(f"Mode {mode} is not supported") @@ -176,6 +178,7 @@ def train( # epoch training loop for epoch in range(epochs): + self.current_epoch = epoch training_loss = [] self.model.zero_grad() self.model.train() @@ -257,6 +260,8 @@ def inference(self, dataloader, additional_outputs=None) -> Dict[str, float]: loss_mean: Mean loss over batches. additional_outputs (only if requested): Dict of additional results. """ + if self.model.mode == "sequence": + return self.inference_sequence(dataloader) loss_all = [] y_true_all = [] y_prob_all = [] @@ -298,7 +303,8 @@ def evaluate(self, dataloader) -> Dict[str, float]: mode = self.model.mode metrics_fn = get_metrics_fn(mode) scores = metrics_fn(y_true_all, y_prob_all, metrics=self.metrics) - scores["loss"] = loss_mean + if mode != "sequence": + scores["loss"] = loss_mean return scores def save_ckpt(self, ckpt_path: str) -> None: @@ -313,6 +319,33 @@ def load_ckpt(self, ckpt_path: str) -> None: self.model.load_state_dict(state_dict) return + def inference_sequence(self, dataloader) -> Dict[int, str]: + """Model inference + """ + y_true_all = {} + y_pred_all = {} + for data in tqdm(dataloader, desc="Evaluation"): + self.model.eval() + with torch.no_grad(): + output = self.model(**data) + y_true = output["y_true"] + y_generated = output["y_generated"] + for key in y_generated.keys(): + y_true_all[key] = y_true[key] + y_pred_all[key] = y_generated[key] + + if self.model.save_generated_caption: + with open(os.path.join(self.exp_path, + f'val_e{self.current_epoch}.csv'),'w') as f1: + with open(os.path.join(self.exp_path, + 'val_gts.csv'), 'w') as f2: + for patient_id in y_pred_all.keys(): + f1.write(y_pred_all[patient_id][0] + '\n') + f2.write(y_true_all[patient_id][0] + '\n') + + + return y_true_all, y_pred_all, 0 + if __name__ == "__main__": import torch From 4247674e853a53618a29bcab440f623fb73dcd2c Mon Sep 17 00:00:00 2001 From: samarthkeshari Date: Sun, 23 Apr 2023 09:54:22 -0700 Subject: [PATCH 05/11] refactored code and added documentation --- .vscode/settings.json | 5 + .../datasets/base_image_caption_dataset.py | 52 ++-- pyhealth/datasets/sample_dataset.py | 26 +- pyhealth/models/wordsat.py | 244 ++++++++++++------ pyhealth/tasks/xray_report_generation.py | 77 ++++-- 5 files changed, 259 insertions(+), 145 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..10434d4c --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "editor.rulers": [ + 79 + ] +} \ No newline at end of file diff --git a/pyhealth/datasets/base_image_caption_dataset.py b/pyhealth/datasets/base_image_caption_dataset.py index 2e472100..508c11d4 100644 --- a/pyhealth/datasets/base_image_caption_dataset.py +++ b/pyhealth/datasets/base_image_caption_dataset.py @@ -13,18 +13,19 @@ INFO_MSG = """ dataset.patients: - key: patient id - - value: a dict of image paths, captions, and other information + - value: a dict of image paths, caption, and other information """ class BaseImageCaptionDataset(ABC): """Abstract base Image Caption Generation dataset class. - This abstract class defines a uniform interface for all + This abstract class defines a uniform interface for all image caption generation datasets. - Each specific dataset will be a subclass of this abstract class, which can then - be converted to samples dataset for different tasks by calling `self.set_task()`. + Each specific dataset will be a subclass of this abstract class, which can + then be converted to samples dataset for different tasks by calling + `self.set_task()`. Args: root: root directory of the raw data (should contain many csv files). @@ -32,7 +33,8 @@ class BaseImageCaptionDataset(ABC): dev: whether to enable dev mode (only use a small subset of the data). Default is False. refresh_cache: whether to refresh the cache; if true, the dataset will - be processed from scratch and the cache will be updated. Default is False. + be processed from scratch and the cache will be updated. + Default is False. """ def __init__( @@ -50,16 +52,16 @@ def __init__( # TODO: dev seems unnecessary for image and signal? self.dev = dev if dev: - logger.warning("WARNING: dev has no effect for image caption generation datasets.") + logger.warning("WARNING: dev has no effect \ + for image caption generation datasets.") # TODO: refresh_cache seems unnecessary for image and signal? self.refresh_cache = refresh_cache if refresh_cache: - logger.warning("WARNING: refresh_cache has no effect for image caption generation datasets.") + logger.warning("WARNING: refresh_cache has no effect \ + for image caption generation datasets.") - self.metadata = pd.read_json(os.path.join(root, "metadata.jsonl"), lines=True) - #self.metadata["path"] = self.metadata["path"].apply( - # lambda x: os.path.join(root, x) - #) + self.metadata = pd.read_json(os.path.join(root, + "metadata.jsonl"), lines=True) if "patient_id" not in self.metadata.columns: # no patient_id in metadata, sequentially assign patient_id self.metadata["patient_id"] = self.metadata.index @@ -99,7 +101,8 @@ def set_task( task_fn: Callable, task_name: Optional[str] = None, ) -> SampleImageCaptionDataset: - """Processes the base dataset to generate the task-specific sample dataset. + """Processes the base dataset to generate the task-specific + sample dataset. This function should be called by the user after the base dataset is initialized. It will iterate through all patients in the base dataset @@ -107,9 +110,10 @@ def set_task( Args: task_fn: a function that takes a single patient and returns a - list of samples (each sample is a dict with patient_id, visit_id, - and other task-specific attributes as key). The samples will be - concatenated to form the sample dataset. + list of samples (each sample is a dict with patient_id, + image_path_list, caption and other task-specific attributes + as key). The samples will be concatenated to form the + sample dataset. task_name: the name of the task. If None, the name of the task function will be used. @@ -118,12 +122,16 @@ def set_task( Note: In `task_fn`, a patient may have one or multiple images associated - to a caption, for e.g. a +patient can have single report associated - to multiple xrays from diffrent views that may be combined to have - a single sample ({'patient_id':1, 'frontal_image':'img1', - 'lateral_image': 'img2','report':'text}) - Patients can also be excluded from the task dataset by returning - an empty list. + to a caption, for e.g. a patient can have single report + for xrays taken from diffrent views that may be combined to + have a single sample such as + ( + {'patient_id': 1, + 'image_path_list': [frontal_img_path,lateral_img_path], + 'caption': 'report_text'} + ) + Patients can also be excluded from the task dataset by + returning an empty list. """ if task_name is None: task_name = task_fn.__name__ @@ -135,7 +143,7 @@ def set_task( for patient_id, patient in tqdm( self.patients.items(), desc=f"Generating samples for {task_name}"): samples.extend(task_fn(patient)) - + sample_dataset = SampleImageCaptionDataset( samples, dataset_name=self.dataset_name, diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 95d6c1e9..79ae5401 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -309,7 +309,7 @@ def _validate(self) -> Dict: - a list of vectors - a list of list of codes - a list of list of vectors - Note that a value is either float, int, or str; a vector is a list of float + Note that a value is either float, int, or str; a vector is a list of float or int; and a code is str. """ # record input information for each key @@ -349,7 +349,7 @@ def _validate(self) -> Dict: ] """ - 4.2. Check type: the basic type of each element should be float, + 4.2. Check type: the basic type of each element should be float, int, or str. """ types = set([type(v) for v in flattened_values]) @@ -514,9 +514,9 @@ class SampleImageCaptionDataset(SampleBaseDataset): """ def __init__( - self, - samples: List[Dict], - dataset_name: str = "", + self, + samples: List[Dict], + dataset_name: str = "", task_name: str = "" ): super().__init__(samples, dataset_name, task_name) @@ -535,8 +535,8 @@ def _index_patient(self) -> Dict[str, List[int]]: Will be called in `self.__init__()`. Returns: - patient_to_index: Dict[str, int], a dict mapping patient_id to a list - of sample indices. + patient_to_index: Dict[str, int], a dict mapping patient_id to a + list of sample indices. """ patient_to_index = {} for idx, sample in enumerate(self.samples): @@ -554,7 +554,7 @@ def _validate(self) -> Dict: """ input_info = {} # get info - input_info["image_path"] = {"type": str, "dim": 2} + #input_info["image_path"] = {"type": str, "dim": 2} input_info["caption"] = {"type": str, "dim": 3} return input_info @@ -562,14 +562,14 @@ def __getitem__(self, index) -> Dict: """Returns a sample by index. Returns: - Dict, a dict with patient_id, image_{number}, caption, and other task-specific - attributes as key. Conversion of caption to index/tensor will be done - in the model. + Dict, a dict with patient_id, image_{number}, caption, and other + task-specific attributes as key. Conversion of caption to + index/tensor will be done in the model. """ sample = self.samples[index] - for i in range(len(sample['image_path'])): + for i in range(len(sample['image_path_list'])): image_key = f'image_{i+1}' - image = Image.open(sample["image_path"][i]).convert("RGB") + image = Image.open(sample["image_path_list"][i]).convert("RGB") image = self.img_transforms(image) sample[image_key] = image return sample diff --git a/pyhealth/models/wordsat.py b/pyhealth/models/wordsat.py index 2407067c..9883fcad 100644 --- a/pyhealth/models/wordsat.py +++ b/pyhealth/models/wordsat.py @@ -7,6 +7,7 @@ from pyhealth.datasets import SampleImageCaptionDataset from pyhealth.models import BaseModel +from pyhealth.tokenizer import Tokenizer from pyhealth.datasets.utils import flatten_list class WordSATEncoder(nn.Module): @@ -22,14 +23,14 @@ def forward(self, x): x = F.relu(x) return x -class Attention(nn.Module): +class WordSATAttention(nn.Module): """ """ - def __init__(self, k_size, v_size, affine_size=512): + def __init__(self, k_size, v_size, affine_dim=512): super().__init__() - self.affine_k = nn.Linear(k_size, affine_size, bias=False) - self.affine_v = nn.Linear(v_size, affine_size, bias=False) - self.affine = nn.Linear(affine_size, 1, bias=False) + self.affine_k = nn.Linear(k_size, affine_dim, bias=False) + self.affine_v = nn.Linear(v_size, affine_dim, bias=False) + self.affine = nn.Linear(affine_dim, 1, bias=False) def forward(self, k, v): # k: batch size x hidden size @@ -46,11 +47,12 @@ class WordSATDecoder(nn.Module): """ """ def __init__( - self, + self, + attention: WordSATAttention, vocab_size: int, n_encoder_inputs: int, - feature_dim: int, - embedding_dim: int, + feature_dim: int, + embedding_dim: int, hidden_dim: int, dropout: int = 0.5 ): @@ -61,14 +63,14 @@ def __init__( self.embedding_dim = embedding_dim self.hidden_dim = hidden_dim self.dropout = dropout - - self.atten = Attention(self.hidden_dim, self.feature_dim) + self.attend = attention self.embed = nn.Embedding(self.vocab_size, self.embedding_dim) self.init_h = nn.Linear(self.n_encoder_inputs * self.feature_dim, self.hidden_dim) - self.init_c = nn.Linear(self.n_encoder_inputs * feature_dim, hidden_dim) - self.lstmcell = nn.LSTMCell(self.embedding_dim + - self.n_encoder_inputs * feature_dim, + self.init_c = nn.Linear(self.n_encoder_inputs * feature_dim, + hidden_dim) + self.lstmcell = nn.LSTMCell(self.embedding_dim + + self.n_encoder_inputs * feature_dim, hidden_dim) self.fc = nn.Linear(self.hidden_dim, self.vocab_size) self.dropout = nn.Dropout(self.dropout) @@ -82,23 +84,23 @@ def forward(self, cnn_features, captions=None, max_len=100): cnn_feats_t = [ cnn_feat.view(batch_size, self.feature_dim, -1) \ .permute(0, 2, 1) - for cnn_feat in cnn_features ] + for cnn_feat in cnn_features ] global_feats = [cnn_feat.mean(dim=(2, 3)) for cnn_feat in cnn_features] - + h = self.init_h(torch.cat(global_feats, dim=1)) c = self.init_c(torch.cat(global_feats, dim=1)) - logits = cnn_features[0].new_zeros((batch_size, - seq_len, - self.vocab_size), dtype=torch.float) + logits = cnn_features[0].new_zeros((batch_size, + seq_len, + self.vocab_size),dtype=torch.float) - if captions is not None: + if captions: embeddings = self.embed(captions) for t in range(seq_len): - contexts = [self.atten(h, cnn_feat_t)[0] + contexts = [self.attend(h, cnn_feat_t)[0] for cnn_feat_t in cnn_feats_t] context = torch.cat(contexts, dim=1) - h, c = self.lstmcell(torch.cat((embeddings[:, t], context), + h, c = self.lstmcell(torch.cat((embeddings[:, t], context), dim=1), (h, c)) logits[:, t] = self.fc(self.dropout(h)) @@ -109,12 +111,12 @@ def forward(self, cnn_features, captions=None, max_len=100): x_t = cnn_features[0].new_full((batch_size,), 1, dtype=torch.long) for t in range(seq_len): embedding = self.embed(x_t) - contexts = [self.atten(h, cnn_feat_t)[0] + contexts = [self.atten(h, cnn_feat_t)[0] for cnn_feat_t in cnn_feats_t] context = torch.cat(contexts, dim=1) - h, c = self.lstmcell(torch.cat((embedding, context), dim=1), + h, c = self.lstmcell(torch.cat((embedding, context), dim=1), (h, c)) - logit =self.fc(h) + logit = self.fc(h) x_t = logit.argmax(dim=1) logits[:, t] = logit @@ -122,60 +124,165 @@ def forward(self, cnn_features, captions=None, max_len=100): class WordSAT(BaseModel): """Word Show Attend & Tell model. - Argument list of class + This model is based on Show, Attend, and Tell (SAT) paper. The model uses + convolutional neural networks (CNNs) to encode the image and a + recurrent neural network (RNN) with an attention mechanism to generate + the corresponding caption. + + The model consists of three main components: + - Encoder: The encoder is a CNN that extracts a fixed-length feature + vectors from the input image. + - Attention Mechanism: The attention mechanism is used to select + relevant parts of the image at each time step of the RNN, to + generate the next word in the caption. It computes a set of + attention weights based on the current hidden state of the RNN and + the feature vectors from the CNN, which are then used to compute a + weighted average of the feature vectors. + - Decoder: The decoder is a language model implemented as an RNN that + takes as input the attention-weighted feature vector and generates + a sequence of words, one at a time. + + Args: + dataset: the dataset to train the model. + n_input_images: number of images passed as input to each sample. + label_key: key in the samples to use as label (e.g., "caption"). + tokenizer: pyhealth tokenizer instance created using sample texts. + encoder_pretrained_weights: pretrained state dictionary for encoder. + Default is None. + encoder_freeze_weights: freeze encoder weights so that they are not + updated during training. This is useful when the encoder is trained + separately as a classifier. Default is True. + decder_maxlen: maximum caption length used during training or generated + during inference. Default is 100. + decoder_embed_dim: decoder embedding dimesion. Default is 256. + decoder_hidden_dim: decoder hidden state dimension. Default is 512. + decoder_feaure_dim: decoder input cell state dimension. + Default is 1024 + decoder_dropout: decoder dropout rate [0,1]. Default is 0.5 + attention_affine_dim: output dimension of affine layer in attention. + Default is 512. + save_generated_caption: save the generated caption during training. + This is used for evaluating the quality of generated captions. + Default is False. """ - + def __init__( self, dataset: SampleImageCaptionDataset, - feature_keys: List[str], + n_input_images: int, label_key: str, - tokenizer: object, - mode: str, - encoder_pretrained_weights: object = None, + tokenizer: Tokenizer, + encoder_pretrained_weights: Dict[str,float] = None, encoder_freeze_weights: bool = True, decoder_maxlen: int = 100, decoder_embed_dim: int = 256, decoder_hidden_dim: int = 512, decoder_feature_dim: int = 1024, decoder_dropout: float = 0.5, + attention_affine_dim: int = 512, save_generated_caption: bool = False, **kwargs ): super(WordSAT, self).__init__( dataset=dataset, - feature_keys=feature_keys, + feature_keys=[f'image_{i+1}' for i in range(n_input_images)], label_key=label_key, - mode=mode, + mode="sequence", ) - - self.encoder = WordSATEncoder() + self.n_input_images = n_input_images self.save_generated_caption = save_generated_caption + # Encoder component + self.encoder = WordSATEncoder() if encoder_pretrained_weights: print(f'Loading encoder pretrained model') self.encoder.load_state_dict(encoder_pretrained_weights) - if encoder_freeze_weights: self.encoder.eval() - + + # Attention component + self.attention = WordSATAttention(decoder_hidden_dim, + decoder_feature_dim, + attention_affine_dim) + + # Decoder component self.decoder_maxlen = decoder_maxlen self.caption_tokenizer = tokenizer - vocab_size = self.caption_tokenizer.get_vocabulary_size() - self.decoder = WordSATDecoder( - vocab_size, - len(feature_keys), - decoder_feature_dim, - decoder_embed_dim, - decoder_hidden_dim, - decoder_dropout - ) + vocab_size = self.caption_tokenizer.get_vocabulary_size() + self.decoder = WordSATDecoder(self.attention, + vocab_size, + n_input_images, + decoder_feature_dim, + decoder_embed_dim, + decoder_hidden_dim, + decoder_dropout + ) + + def forward(self, **kwargs): + """Forward propagation. + + The features `kwargs[self.feature_keys]` is a list of feature keys + associated to every input image. + + The label `kwargs[self.label_key]` is a key of the report caption + for each patient. + + Args: + **kwargs: keyword arguments for the model. The keys must contain + all the feature keys and the label key. + + Returns: + A dictionary with the following keys: + loss: a scalar tensor representing the loss. + y_generated: a dictionary list of text representing the generated caption. + The list contains only one element. + y_true: a list of text representing the true caption. + The list contains only one element. + """ + # Initialize the output + output = {"loss": None,"y_generated": "","y_true": ""} + + # Get list of patient_ids + patient_ids = kwargs["patient_id"] + + # Get CNN features + images = self._prepare_batch_image(kwargs) + cnn_features = [self.encoder(image.to(self.device)) + for image in images] + + if self.training: + # Get caption indexes and masks + captions,masks=self._prepare_batch_captions(kwargs[self.label_key]) + + # Perform predictions + logits = self.decoder(cnn_features, captions[:,:-1], + self.decoder_maxlen) + logits = logits.permute(0, 2, 1).contiguous() + captions = captions[:, 1:].contiguous() + masks = masks[:, 1:].contiguous() + + # Compute loss + loss = self.get_loss_function()(logits, captions) + loss = loss.masked_select(masks).mean() + output["loss"] = loss + else: + output["y_generated"] = self._forward_inference(patient_ids, + cnn_features) + output["y_true"] = self._forward_ground_truths(patient_ids, + kwargs[self.label_key]) + return output def _prepare_batch_images(self,kwargs): """Prepare images for input. + Args: + kwargs: keyword arguments for the model. + Returns: + images: a list of input images represented as tensors. Every tensor + in the list has shape [batch_size,3,image_size,image_size] """ - print(self.n_input_image) - + images = [torch.stack(kwargs[image_feature], 0) + for image_feature in self.feature_keys] + return images def _prepare_batch_captions(self,captions): @@ -192,43 +299,8 @@ def _prepare_batch_captions(self,captions): captions = torch.tensor(x, dtype=torch.long, device=self.device) masks = torch.sum(captions,dim=1) !=0 captions = captions.squeeze(1) - - return captions,masks - def forward(self, **kwargs): - """Forward propagation. - """ - patient_ids = kwargs['patient_id'] - - image_features = [ feature for feature in self.feature_keys - if 'image_' in feature] - - images = [ torch.stack(kwargs[image_feature], 0) - for image_feature in image_features - ] - - cnn_features = [self.encoder(image.to(self.device)) for image in images] - output = {} - if self.training: - captions,masks = self._prepare_batch_captions(kwargs[self.label_key]) - logits = self.decoder(cnn_features, captions[:,:-1], - self.decoder_maxlen) - logits = logits.permute(0, 2, 1).contiguous() - captions = captions[:, 1:].contiguous() - masks = masks[:, 1:].contiguous() - - loss = self.get_loss_function()(logits, captions) - loss = loss.masked_select(masks).mean() - - output["loss"] = loss - else: - output["y_generated"] = self._forward_inference(patient_ids, - cnn_features - ) - output["y_true"] = self._forward_get_ground_truths(patient_ids, - kwargs[self.label_key] - ) - return output + return captions,masks def _forward_inference(self,patient_ids,cnn_features): """ @@ -236,7 +308,7 @@ def _forward_inference(self,patient_ids,cnn_features): generated_results = {} for idx, patient_id in enumerate(patient_ids): generated_results[patient_id] = [""] - cnn_feature = [cnn_feat[idx].unsqueeze(0) + cnn_feature = [cnn_feat[idx].unsqueeze(0) for cnn_feat in cnn_features] pred = self.decoder(cnn_feature, None, self.decoder_maxlen)[0] pred = pred.detach().cpu() @@ -250,12 +322,12 @@ def _forward_inference(self,patient_ids,cnn_features): if token == '': break words.append(token) - + generated_results[patient_id][0] = " ".join(words) return generated_results - def _forward_get_ground_truths(self,patient_ids,captions): + def _forward_ground_truths(self,patient_ids,captions): """ """ ground_truths = {} diff --git a/pyhealth/tasks/xray_report_generation.py b/pyhealth/tasks/xray_report_generation.py index ee1a8677..753e7327 100644 --- a/pyhealth/tasks/xray_report_generation.py +++ b/pyhealth/tasks/xray_report_generation.py @@ -2,38 +2,67 @@ import string def biview_multisent_fn(patient): - """ Processes single patient for xray report generation""" + """ Processes single patient for X-ray report generation task + + Xray report generation aims a automatically generating the diagnosis report + from X-ray images taken from two viewpoints namely - frontal and lateral. + + An X-ray report generally consists of following elements + - History, describing the reason for the X-ray + - Technique, describing the type of X-ray performed - frontal,lateral + - Findings, describing the observations made by the radiologist who + interpreted the X-ray images + - Impression, describing the overall interpretation of X-ray results + + Args: + patient: a dictionary of patient X-ray report with following keys + - patient_id: type(int) + unique identifier for patient + - frontal_img_path: type(str) + path to frontal X-ray image + - lateral_img_path: type(str) + path to lateral X-ray image + - findings: type(str) + text of X-ray report findings + - impression: type(str) + text of X-ray report impression + + Returns: + sample: a list of one sample, each sample is a dict with following keys + - patient_id: type(int) + unique identifier for patient + - image_path_list: type(List) + list of frontal and lateral image paths + - caption: type(List[List]) + nested list of sentences,where each inner list represents a + single sentence in the X-ray report text(formed by + concatenating impression and findings). + + Note: special tokens "" and "", are added to the begining of + first and end of last sentence. These are mandatory for the task. + """ sample = {} - image_path = [] - img_root = "/srv/local/data/IU_XRay/images/images_normalized" - - for data in patient: - patient_id = data["patient_id"] - if data["view"] == "frontal": - image_path.insert(0,os.path.join(img_root, data["path"])) - if data["view"] == "lateral": - image_path.append(os.path.join(img_root, data["path"])) - - impression = data["impression"] - findings = data["findings"] - report = f"{impression} . {findings}" - - sample["patient_id"] = patient_id - sample["image_path"] = image_path + report = f"{patient['impression']} . {patient['findings']}" + caption = [] sents = report.lower().split(".") sents = [sent for sent in sents if len(sent.strip()) > 1] - sample["caption"] = [] + for isent, sent in enumerate(sents): tokens = sent.translate(str.maketrans("", "", string.punctuation)) \ .strip() \ .split() - sample["caption"].append([".", *[token for token in tokens],"."]) - - if sample["caption"] == []: - sample["caption"] = [[" "," "]] + caption.append([".", *[token for token in tokens],"."]) - sample["caption"][0][0] = "" - sample["caption"][-1].append("") + if caption == []: + caption = [["",""]] + else: + caption[0][0] = "" + caption[-1].append("") + sample["patient_id"] = int(patient["patient_id"]) + sample["image_path_list"] = [ patient["frontal_img_path"], + patient["lateral_img_path"], + ] + sample["caption"] = caption return [sample] \ No newline at end of file From 0a2e74147d8bc0c6e2b3b0860892d443cc714c67 Mon Sep 17 00:00:00 2001 From: samarthkeshari Date: Sun, 23 Apr 2023 10:40:07 -0700 Subject: [PATCH 06/11] updated documentation --- pyhealth/models/wordsat.py | 40 +++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/pyhealth/models/wordsat.py b/pyhealth/models/wordsat.py index 9883fcad..158fd29c 100644 --- a/pyhealth/models/wordsat.py +++ b/pyhealth/models/wordsat.py @@ -234,13 +234,18 @@ def forward(self, **kwargs): Returns: A dictionary with the following keys: loss: a scalar tensor representing the loss. - y_generated: a dictionary list of text representing the generated caption. - The list contains only one element. - y_true: a list of text representing the true caption. - The list contains only one element. + y_generated: a dictionary with following key + - "patient_id": list of text representing the generated + text. + e.g. + {123: ["generated text"], 456: ["generated text"]} + y_true: a dictionary with following key + - "patient_id": list of text representing the true text. + e.g. + {123: ["true text"], 456: ["true text"]} """ # Initialize the output - output = {"loss": None,"y_generated": "","y_true": ""} + output = {"loss": None,"y_generated": [""],"y_true": [""]} # Get list of patient_ids patient_ids = kwargs["patient_id"] @@ -286,21 +291,34 @@ def _prepare_batch_images(self,kwargs): return images def _prepare_batch_captions(self,captions): - """Prepare caption for input. + """Prepare caption idx for input. + + Args: + captions: list of captions. Each caption is a list of list, where + each list represents a sentence in the caption. + Following is an example of a caption + [ + ["","first","sentence","."], + [".", "second", "sentence",".", ""] + ] + Returns: + captions_idx: an int tensor of size [batch_size,max_caption_length] + masks: a bool tensor of size [batch_size,max_caption_length] """ samples = [] + # Combine all sentences in each caption to create a single sentence for caption in captions: tokens = [] tokens.extend(flatten_list(caption)) text = ' '.join(tokens).replace('. .','.') samples.append([text.split()]) - #print(caption) + x = self.caption_tokenizer.batch_encode_3d(samples) - captions = torch.tensor(x, dtype=torch.long, device=self.device) - masks = torch.sum(captions,dim=1) !=0 - captions = captions.squeeze(1) + captions_idx = torch.tensor(x, dtype=torch.long, device=self.device) + masks = torch.sum(captions_idx,dim=1) !=0 + captions_idx = captions_idx.squeeze(1) - return captions,masks + return captions_idx,masks def _forward_inference(self,patient_ids,cnn_features): """ From 00f5eb5c3891305a8d64086a4db0bca5890f851b Mon Sep 17 00:00:00 2001 From: samarthkeshari Date: Mon, 24 Apr 2023 15:39:23 -0500 Subject: [PATCH 07/11] added sentsat model --- pyhealth/models/__init__.py | 3 +- pyhealth/models/base_model.py | 5 +- pyhealth/models/sentsat.py | 508 +++++++++++++++++++++++ pyhealth/models/wordsat.py | 203 ++++++--- pyhealth/tasks/xray_report_generation.py | 3 +- pyhealth/trainer.py | 5 +- test.py | 23 +- 7 files changed, 675 insertions(+), 75 deletions(-) create mode 100644 pyhealth/models/sentsat.py diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 4ea4ec84..35662eb3 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -22,4 +22,5 @@ from .grasp import GRASP, GRASPLayer from .stagenet import StageNet, StageNetLayer from .tcn import TCN, TCNLayer -from .wordsat import WordSAT \ No newline at end of file +from .wordsat import WordSAT, WordSATEncoder,WordSATDecoder, WordSATAttention +from .sentsat import SentSAT \ No newline at end of file diff --git a/pyhealth/models/base_model.py b/pyhealth/models/base_model.py index 6ee6c25d..a905e5d8 100644 --- a/pyhealth/models/base_model.py +++ b/pyhealth/models/base_model.py @@ -22,7 +22,7 @@ class BaseModel(ABC, nn.Module): feature_keys: list of keys in samples to use as features, e.g. ["conditions", "procedures"]. label_key: key in samples to use as label (e.g., "drugs"). - mode: one of "binary", "multiclass", or "multilabel". + mode: one of "binary", "multiclass", "multilabel", or sequence. """ def __init__( @@ -31,6 +31,7 @@ def __init__( feature_keys: List[str], label_key: str, mode: str, + save_generated_caption: bool = False ): super(BaseModel, self).__init__() assert mode in VALID_MODE, f"mode must be one of {VALID_MODE}" @@ -38,6 +39,8 @@ def __init__( self.feature_keys = feature_keys self.label_key = label_key self.mode = mode + if mode == "sequence": + self.save_generated_caption = save_generated_caption # used to query the device of the model self._dummy_param = nn.Parameter(torch.empty(0)) return diff --git a/pyhealth/models/sentsat.py b/pyhealth/models/sentsat.py new file mode 100644 index 00000000..45f99ad8 --- /dev/null +++ b/pyhealth/models/sentsat.py @@ -0,0 +1,508 @@ +from typing import List, Tuple, Dict, Optional + +import torch +import torch.nn as nn +from torchvision import models +import torch.nn.functional as F + +from pyhealth.datasets import SampleImageCaptionDataset +from pyhealth.models import BaseModel +from pyhealth.tokenizer import Tokenizer +from pyhealth.datasets.utils import flatten_list + +class SentSATEncoder(nn.Module): + """ SAT CNN(Densenet121) Encoder model""" + def __init__(self): + super().__init__() + self.densenet121 = models.densenet121(weights='DEFAULT') + self.densenet121.classifier = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward propagation. + Extract fixed-length feature vectors from the input image. + + Args: + x: A tensor of tranfomed image of size + [batch_size,3,224,224] + Return: + x: A tensor of image feature vectors of size + [batch_size,1024,16,16] + """ + x = self.densenet121.features(x) + x = F.relu(x) + return x + +class SentSATAttention(nn.Module): + """SAT Attention Module + + Computes a set of attention weights based on the current hidden state of + the RNN and the feature vectors from the CNN, which are then used to + compute a weighted average of the feature vectors. + + Args: + k_size: key vector size + v_size: value vector size + affine_dim: affine dimension. Default is 512 + """ + def __init__( + self, + k_size: int, + v_size: int, + affine_dim: int =512): + super().__init__() + self.affine_k = nn.Linear(k_size, affine_dim, bias=False) + self.affine_v = nn.Linear(v_size, affine_dim, bias=False) + self.affine = nn.Linear(affine_dim, 1, bias=False) + + def forward( + self, + k: torch.Tensor, + v: torch.Tensor) -> (torch.Tensor,torch.Tensor): + """Forward propagation + + Args: + k: a tensor of size [batch_size, hidden_dim] + v: a tensor of size [batch_size, spatial_size, hidden_dim] + + Returns: + context: a tensor of size [batch_size, feature_dim] + alpha: a tensor of size [batch_size, spatial_size] + """ + content_v = self.affine_k(k).unsqueeze(1) + self.affine_v(v) + # z: batch size x spatial size + z = self.affine(torch.tanh(content_v)).squeeze(2) + alpha = torch.softmax(z, dim=1) + context = (v * alpha.unsqueeze(2)).sum(dim=1) + + return context, alpha + +class SentSATDecoder(nn.Module): + """ Word SAT decoder model for one sentence + + An LSTM based model that takes as input the attention-weighted feature + vector and generates a sequence of words, one at a time. + + Args: + attention: attention module instance + vocab_size: vocabulary size + n_encoder_inputs: number of image inputs given to the encoder + feature_dim: encoder output feature dimesion + hidden_dim: LSTM hidden dimension + dropout: dropout rate between [0,1] + """ + def __init__( + self, + attention: object, + vocab_size: int, + n_encoder_inputs: int, + feature_dim: int, + embedding_dim: int, + hidden_dim: int, + dropout: int = 0.5 + ): + super().__init__() + self.vocab_size = vocab_size + self.n_encoder_inputs = n_encoder_inputs + self.feature_dim = feature_dim + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + + self.attend = attention + self.embed = nn.Embedding(self.vocab_size, self.embedding_dim) + self.init_h = nn.Linear(self.n_encoder_inputs * self.feature_dim, + self.hidden_dim) + self.init_c = nn.Linear(self.n_encoder_inputs * self.feature_dim, + hidden_dim) + self.sent_lstm = nn.LSTMCell(self.n_encoder_inputs * self.feature_dim, + hidden_dim) + + self.word_lstm = nn.LSTMCell(self.embedding_dim + self.hidden_dim + + self.n_encoder_inputs * feature_dim, + hidden_dim) + self.fc = nn.Linear(self.hidden_dim, self.vocab_size) + self.dropout = nn.Dropout(dropout) + + def forward( + self, + cnn_features: List[torch.Tensor], + captions: List[torch.Tensor] = None, + update_masks: torch.Tensor = None, + max_sents: int = 10, + max_len: int = 30, + stop_id: int = None) -> torch.Tensor: + + """Forward propagation + + Args: + cnn_features: a list of tensors where each tensor is of + size [batch_size, feature_dim, spatial_size]. + captions: a list of tensors. + max_len: maximum length of training or generated caption + + Returns: + logits: a tensor + + """ + batch_size = cnn_features[0].size(0) + if captions is not None: + num_sents = captions.size(1) + seq_len = captions.size(2) + else: + num_sents = max_sents + seq_len = max_len + + cnn_feats_t = [ cnn_feat.view(batch_size, self.feature_dim, -1) \ + .permute(0, 2, 1) + for cnn_feat in cnn_features ] + global_feats = [cnn_feat.mean(dim=(2, 3)) for cnn_feat in cnn_features] + + sent_h = self.init_h(torch.cat(global_feats, dim=1)) + sent_c = self.init_c(torch.cat(global_feats, dim=1)) + + word_h = cnn_features[0].new_zeros((batch_size, + self.hidden_dim),dtype=torch.float) + + word_c = cnn_features[0].new_zeros((batch_size, + self.hidden_dim),dtype=torch.float) + + logits = cnn_features[0].new_zeros((batch_size, + num_sents, + seq_len, + self.vocab_size),dtype=torch.float) + # Training phase + if captions is not None: + embeddings = self.embed(captions) + + for k in range(num_sents): + contexts = [self.attend(sent_h, cnn_feat_t)[0] + for cnn_feat_t in cnn_feats_t] + context = torch.cat(contexts, dim=1) + sent_h, sent_c = self.sent_lstm(context, (sent_h, sent_c)) + seq_len_k = update_masks[:, k].sum(dim=1).max().item() + + for t in range(seq_len_k): + batch_mask = update_masks[:, k, t] + + word_h_, word_c_ = self.word_lstm( + torch.cat((embeddings[batch_mask, k, t], + sent_h[batch_mask], + context[batch_mask]), dim=1), + (word_h[batch_mask], word_c[batch_mask])) + + indices = [*batch_mask.unsqueeze(1). \ + repeat(1, self.hidden_dim).nonzero().t()] + word_h = word_h.index_put(indices, word_h_.view(-1)) + word_c = word_c.index_put(indices, word_c_.view(-1)) + logits[batch_mask, k, t] = self.fc( + self.dropout(word_h[batch_mask])) + + return logits + + # Evaluation/Inference phase + else: + x_t = cnn_features[0].new_full((batch_size,), 1, dtype=torch.long) + + for k in range(num_sents): + contexts = [self.attend(sent_h, cnn_feat_t)[0] + for cnn_feat_t in cnn_feats_t] + context = torch.cat(contexts, dim=1) + sent_h, sent_c = self.sent_lstm(context, (sent_h, sent_c)) + + for t in range(seq_len): + embedding = self.embed(x_t) + word_h, word_c = self.word_lstm( + torch.cat((embedding, sent_h, context), dim=1), + (word_h, word_c)) + logit = self.fc(word_h) + x_t = logit.argmax(dim=1) + logits[:, k, t] = logit + + if x_t[0] == stop_id: + break + + return logits.argmax(dim=3) + +class SentSAT(BaseModel): + """Show Attend & Tell model, treating entire caption as one sentence. + This model is based on Show, Attend, and Tell (SAT) paper. The model uses + convolutional neural networks (CNNs) to encode the image and a + recurrent neural network (RNN) with an attention mechanism to generate + the corresponding caption. + + The model consists of three main components: + - Encoder: The encoder is a CNN that extracts a fixed-length feature + vectors from the input image. + - Attention Mechanism: The attention mechanism is used to select + relevant parts of the image at each time step of the RNN, to + generate the next word in the caption. + - Decoder: The decoder is a language model implemented as an RNN that + takes as input the attention-weighted feature vector and generates + a sequence of words, one at a time. + + Args: + dataset: the dataset to train the model. + n_input_images: number of images passed as input to each sample. + label_key: key in the samples to use as label (e.g., "caption"). + tokenizer: pyhealth tokenizer instance created using sample texts. + encoder_pretrained_weights: pretrained state dictionary for encoder. + Default is None. + encoder_freeze_weights: freeze encoder weights so that they are not + updated during training. This is useful when the encoder is trained + separately as a classifier. Default is True. + decoder_embed_dim: decoder embedding dimesion. Default is 256. + decoder_hidden_dim: decoder hidden state dimension. Default is 512. + decoder_feaure_dim: decoder input cell state dimension. + Default is 1024 + decoder_dropout: decoder dropout rate between [0,1]. Default is 0.5 + attention_affine_dim: output dimension of affine layer in attention. + Default is 512. + save_generated_caption: save the generated caption during training. + This is used for evaluating the quality of generated captions. + Default is False. + """ + + def __init__( + self, + dataset: SampleImageCaptionDataset, + n_input_images: int, + label_key: str, + tokenizer: Tokenizer, + encoder_pretrained_weights: Dict[str,float] = None, + encoder_freeze_weights: bool = True, + decoder_embed_dim: int = 256, + decoder_hidden_dim: int = 512, + decoder_feature_dim: int = 1024, + decoder_dropout: float = 0.5, + attention_affine_dim: int = 512, + save_generated_caption: bool = False, + **kwargs + ): + super(SentSAT, self).__init__( + dataset=dataset, + feature_keys=[f'image_{i+1}' for i in range(n_input_images)], + label_key=label_key, + mode="sequence", + save_generated_caption = save_generated_caption + ) + self.n_input_images = n_input_images + + # Encoder component + self.encoder = SentSATEncoder() + if encoder_pretrained_weights: + print(f'Loading encoder pretrained model') + self.encoder.load_state_dict(encoder_pretrained_weights) + if encoder_freeze_weights: + self.encoder.eval() + + # Attention component + self.attention = SentSATAttention(decoder_hidden_dim, + decoder_feature_dim, + attention_affine_dim) + + # Decoder component + self.caption_tokenizer = tokenizer + vocab_size = self.caption_tokenizer.get_vocabulary_size() + self.decoder = SentSATDecoder( self.attention, + vocab_size, + n_input_images, + decoder_feature_dim, + decoder_embed_dim, + decoder_hidden_dim, + decoder_dropout + ) + + def forward( + self, + decoder_maxsents: int =10, + decoder_maxlen:int = 20, + decoder_stop_id: int = None, + **kwargs) -> Dict[str,str]: + """Forward propagation. + + The features `kwargs[self.feature_keys]` is a list of feature keys + associated to every input image. + + The label `kwargs[self.label_key]` is a key of the report caption + for each patient. + + Args: + **kwargs: keyword arguments for the model. The keys must contain + all the feature keys and the label key. + + Returns: + A dictionary with the following keys: + loss: a scalar tensor representing the loss. + y_generated: a dictionary with following key + - "patient_id": list of text representing the generated + text. + e.g. + {123: ["generated text"], 456: ["generated text"]} + y_true: a dictionary with following key + - "patient_id": list of text representing the true text. + e.g. + {123: ["true text"], 456: ["true text"]} + """ + # Initialize the output + output = {"loss": None,"y_generated": [""],"y_true": [""]} + + # Get list of patient_ids + patient_ids = kwargs["patient_id"] + + # Get CNN features + images = self._prepare_batch_images(kwargs) + cnn_features = [self.encoder(image.to(self.device)) + for image in images] + + if self.training: + # Get caption as indicies and corresponding masks + captions, loss_masks, update_masks=self._prepare_batch_captions( + kwargs[self.label_key]) + + # Perform predictions + logits = self.decoder(cnn_features, + captions[:, :, :-1], + update_masks, + decoder_maxsents, + decoder_maxlen, + decoder_stop_id) + + logits = logits.permute(0, 3, 1, 2).contiguous() + captions = captions[:, :, 1:].contiguous() + loss_masks = loss_masks[:, :, 1:].contiguous() + + # Compute loss + loss = self.get_loss_function()(logits, captions) + loss = loss.masked_select(loss_masks).mean() + output["loss"] = loss + + with torch.no_grad(): + output["y_generated"] = self._forward_inference(patient_ids, + decoder_maxsents, + decoder_maxlen, + cnn_features, + decoder_stop_id) + output["y_true"] = self._forward_ground_truths(patient_ids, + kwargs[self.label_key]) + return output + + def _prepare_batch_images(self,kwargs): + """Prepare images for input. + Args: + kwargs: keyword arguments for the model. + Returns: + images: a list of input images represented as tensors. Every tensor + in the list has shape [batch_size,3,image_size,image_size] + """ + images = [torch.stack(kwargs[image_feature], 0) + for image_feature in self.feature_keys] + + return images + + def _prepare_batch_captions( + self, + captions:List[List[str]] + ) -> (torch.Tensor,torch.Tensor): + """Prepare caption idx for input. + + Args: + captions: list of captions. Each caption is a list of list, where + each list represents a sentence in the caption. + Following is an example of a caption + [ + ["","first","sentence","."], + [".", "second", "sentence",".", ""] + ] + Returns: + captions_idx: an int tensor + masks: a bool tensor + """ + x = self.caption_tokenizer.batch_encode_3d(captions) + captions_idx = torch.tensor(x, dtype=torch.long, + device=self.device) + + loss_masks = torch.zeros_like(captions_idx,dtype=torch.bool) + update_masks = torch.zeros_like(captions_idx,dtype=torch.bool) + + for icap, cap in enumerate(captions_idx): + for isent, sent in enumerate(cap): + l = len(sent[sent !=0]) + if l==0: continue + loss_masks[icap, isent, 1:l].fill_(1) + update_masks[icap, isent, :l-1].fill_(1) + + return captions_idx, loss_masks, update_masks + + def _forward_inference( + self, + patient_ids: List[int], + decoder_maxsents: int, + decoder_maxlen: int, + cnn_features:List[torch.Tensor], + decoder_stop_id: int = None) -> Dict[int,str]: + """Forward propagation during inference + + Args: + patient_ids: a list of patient ids + cnn_features: a list of tensors + + Returns: + generated_results: a dict with following keys + - patient_id: int + - generated_results: List[str] + """ + generated_results = {} + for idx, patient_id in enumerate(patient_ids): + generated_results[patient_id] = [""] + cnn_feature = [cnn_feat[idx].unsqueeze(0) + for cnn_feat in cnn_features] + pred = self.decoder(cnn_feature, None, None, + decoder_maxsents,decoder_maxlen, + decoder_stop_id)[0] + pred = pred.detach().cpu() + generated_results[patient_id] = [""] + for isent in range(pred.size(0)): + pred_tokens = self.caption_tokenizer \ + .convert_indices_to_tokens(pred[isent].tolist()) + + words = [] + for token in pred_tokens: + if token == '' or token == '': + continue + if token == '': + break + words.append(token) + if len(words) < 2: + continue + generated_results[patient_id][0] += " ".join(words) + generated_results[patient_id][0] += " " + generated_results[patient_id][0] = generated_results[patient_id][0].replace(". .",'.') + + return generated_results + + def _forward_ground_truths( + self, + patient_ids: List[int], + captions:List[List[str]]) -> Dict[int,str]: + """Forward propagation for ground truth + + Args: + patient_ids: a list of patient ids + cnn_features: a list of tensors + + Returns: + ground_results: a dict with following keys + - patient_id: int + - generated_results: List[str] + """ + ground_truths = {} + for idx, caption in enumerate(captions): + ground_truths[patient_ids[idx]] = [""] + tokens = [] + tokens.extend(flatten_list(caption)) + ground_truths[patient_ids[idx]][0] = ' '.join(tokens) \ + .replace('. .','.') \ + .replace("","") \ + .replace("","") \ + .strip() + + return ground_truths \ No newline at end of file diff --git a/pyhealth/models/wordsat.py b/pyhealth/models/wordsat.py index 158fd29c..f6dd0a11 100644 --- a/pyhealth/models/wordsat.py +++ b/pyhealth/models/wordsat.py @@ -11,44 +11,88 @@ from pyhealth.datasets.utils import flatten_list class WordSATEncoder(nn.Module): - """ - """ + """ SAT CNN(Densenet121) Encoder model""" def __init__(self): super().__init__() self.densenet121 = models.densenet121(weights='DEFAULT') self.densenet121.classifier = nn.Identity() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward propagation. + Extract fixed-length feature vectors from the input image. + + Args: + x: A tensor of tranfomed image of size + [batch_size,3,224,224] + Return: + x: A tensor of image feature vectors of size + [batch_size,1024,16,16] + """ x = self.densenet121.features(x) x = F.relu(x) return x class WordSATAttention(nn.Module): + """SAT Attention Module + + Computes a set of attention weights based on the current hidden state of + the RNN and the feature vectors from the CNN, which are then used to + compute a weighted average of the feature vectors. + + Args: + k_size: key vector size + v_size: value vector size + affine_dim: affine dimension. Default is 512 """ - """ - def __init__(self, k_size, v_size, affine_dim=512): + def __init__( + self, + k_size: int, + v_size: int, + affine_dim: int =512): super().__init__() self.affine_k = nn.Linear(k_size, affine_dim, bias=False) self.affine_v = nn.Linear(v_size, affine_dim, bias=False) self.affine = nn.Linear(affine_dim, 1, bias=False) - def forward(self, k, v): - # k: batch size x hidden size - # v: batch size x spatial size x hidden size - # z: batch size x spatial size - # TODO other ways of attention? + def forward( + self, + k: torch.Tensor, + v: torch.Tensor) -> (torch.Tensor,torch.Tensor): + """Forward propagation + + Args: + k: a tensor of size [batch_size, hidden_dim] + v: a tensor of size [batch_size, spatial_size, hidden_dim] + + Returns: + context: a tensor of size [batch_size, feature_dim] + alpha: a tensor of size [batch_size, spatial_size] + """ content_v = self.affine_k(k).unsqueeze(1) + self.affine_v(v) + # z: batch size x spatial size z = self.affine(torch.tanh(content_v)).squeeze(2) alpha = torch.softmax(z, dim=1) context = (v * alpha.unsqueeze(2)).sum(dim=1) + return context, alpha class WordSATDecoder(nn.Module): - """ + """ Word SAT decoder model for one sentence + + An LSTM based model that takes as input the attention-weighted feature + vector and generates a sequence of words, one at a time. + + Args: + attention: attention module instance + vocab_size: vocabulary size + n_encoder_inputs: number of image inputs given to the encoder + feature_dim: encoder output feature dimesion + hidden_dim: LSTM hidden dimension + dropout: dropout rate between [0,1] """ def __init__( self, - attention: WordSATAttention, + attention: object, vocab_size: int, n_encoder_inputs: int, feature_dim: int, @@ -62,7 +106,7 @@ def __init__( self.feature_dim = feature_dim self.embedding_dim = embedding_dim self.hidden_dim = hidden_dim - self.dropout = dropout + self.attend = attention self.embed = nn.Embedding(self.vocab_size, self.embedding_dim) self.init_h = nn.Linear(self.n_encoder_inputs * self.feature_dim, @@ -73,9 +117,26 @@ def __init__( self.n_encoder_inputs * feature_dim, hidden_dim) self.fc = nn.Linear(self.hidden_dim, self.vocab_size) - self.dropout = nn.Dropout(self.dropout) + self.dropout = nn.Dropout(dropout) + + def forward( + self, + cnn_features: List[torch.Tensor], + captions: List[torch.Tensor] = None, + max_len: int = 100) -> torch.Tensor: + + """Forward propagation - def forward(self, cnn_features, captions=None, max_len=100): + Args: + cnn_features: a list of tensors where each tensor is of + size [batch_size, feature_dim, spatial_size]. + captions: a list of tensors. + max_len: maximum length of training or generated caption + + Returns: + logits: a tensor + + """ batch_size = cnn_features[0].size(0) if captions is not None: seq_len = captions.size(1) @@ -93,8 +154,8 @@ def forward(self, cnn_features, captions=None, max_len=100): logits = cnn_features[0].new_zeros((batch_size, seq_len, self.vocab_size),dtype=torch.float) - - if captions: + # Training phase + if captions is not None: embeddings = self.embed(captions) for t in range(seq_len): contexts = [self.attend(h, cnn_feat_t)[0] @@ -107,11 +168,12 @@ def forward(self, cnn_features, captions=None, max_len=100): return logits + # Evaluation/Inference phase else: x_t = cnn_features[0].new_full((batch_size,), 1, dtype=torch.long) for t in range(seq_len): embedding = self.embed(x_t) - contexts = [self.atten(h, cnn_feat_t)[0] + contexts = [self.attend(h, cnn_feat_t)[0] for cnn_feat_t in cnn_feats_t] context = torch.cat(contexts, dim=1) h, c = self.lstmcell(torch.cat((embedding, context), dim=1), @@ -123,7 +185,7 @@ def forward(self, cnn_features, captions=None, max_len=100): return logits.argmax(dim=2) class WordSAT(BaseModel): - """Word Show Attend & Tell model. + """Show Attend & Tell model, treating entire caption as one sentence. This model is based on Show, Attend, and Tell (SAT) paper. The model uses convolutional neural networks (CNNs) to encode the image and a recurrent neural network (RNN) with an attention mechanism to generate @@ -134,10 +196,7 @@ class WordSAT(BaseModel): vectors from the input image. - Attention Mechanism: The attention mechanism is used to select relevant parts of the image at each time step of the RNN, to - generate the next word in the caption. It computes a set of - attention weights based on the current hidden state of the RNN and - the feature vectors from the CNN, which are then used to compute a - weighted average of the feature vectors. + generate the next word in the caption. - Decoder: The decoder is a language model implemented as an RNN that takes as input the attention-weighted feature vector and generates a sequence of words, one at a time. @@ -152,13 +211,11 @@ class WordSAT(BaseModel): encoder_freeze_weights: freeze encoder weights so that they are not updated during training. This is useful when the encoder is trained separately as a classifier. Default is True. - decder_maxlen: maximum caption length used during training or generated - during inference. Default is 100. decoder_embed_dim: decoder embedding dimesion. Default is 256. decoder_hidden_dim: decoder hidden state dimension. Default is 512. decoder_feaure_dim: decoder input cell state dimension. Default is 1024 - decoder_dropout: decoder dropout rate [0,1]. Default is 0.5 + decoder_dropout: decoder dropout rate between [0,1]. Default is 0.5 attention_affine_dim: output dimension of affine layer in attention. Default is 512. save_generated_caption: save the generated caption during training. @@ -174,7 +231,6 @@ def __init__( tokenizer: Tokenizer, encoder_pretrained_weights: Dict[str,float] = None, encoder_freeze_weights: bool = True, - decoder_maxlen: int = 100, decoder_embed_dim: int = 256, decoder_hidden_dim: int = 512, decoder_feature_dim: int = 1024, @@ -206,19 +262,18 @@ def __init__( attention_affine_dim) # Decoder component - self.decoder_maxlen = decoder_maxlen self.caption_tokenizer = tokenizer vocab_size = self.caption_tokenizer.get_vocabulary_size() - self.decoder = WordSATDecoder(self.attention, - vocab_size, - n_input_images, - decoder_feature_dim, - decoder_embed_dim, - decoder_hidden_dim, - decoder_dropout - ) - - def forward(self, **kwargs): + self.decoder = WordSATDecoder( self.attention, + vocab_size, + n_input_images, + decoder_feature_dim, + decoder_embed_dim, + decoder_hidden_dim, + decoder_dropout + ) + + def forward(self, decoder_maxlen:int = 100, **kwargs) -> Dict[str,str]: """Forward propagation. The features `kwargs[self.feature_keys]` is a list of feature keys @@ -228,6 +283,8 @@ def forward(self, **kwargs): for each patient. Args: + decder_maxlen: maximum caption length used during training or + generated during inference. Default is 100. **kwargs: keyword arguments for the model. The keys must contain all the feature keys and the label key. @@ -251,17 +308,17 @@ def forward(self, **kwargs): patient_ids = kwargs["patient_id"] # Get CNN features - images = self._prepare_batch_image(kwargs) + images = self._prepare_batch_images(kwargs) cnn_features = [self.encoder(image.to(self.device)) for image in images] if self.training: - # Get caption indexes and masks + # Get caption as indicies and corresponding masks captions,masks=self._prepare_batch_captions(kwargs[self.label_key]) # Perform predictions logits = self.decoder(cnn_features, captions[:,:-1], - self.decoder_maxlen) + decoder_maxlen) logits = logits.permute(0, 2, 1).contiguous() captions = captions[:, 1:].contiguous() masks = masks[:, 1:].contiguous() @@ -270,11 +327,13 @@ def forward(self, **kwargs): loss = self.get_loss_function()(logits, captions) loss = loss.masked_select(masks).mean() output["loss"] = loss - else: + + with torch.no_grad(): output["y_generated"] = self._forward_inference(patient_ids, + decoder_maxlen, cnn_features) - output["y_true"] = self._forward_ground_truths(patient_ids, - kwargs[self.label_key]) + output["y_true"] = self._forward_ground_truths(patient_ids, + kwargs[self.label_key]) return output def _prepare_batch_images(self,kwargs): @@ -290,7 +349,10 @@ def _prepare_batch_images(self,kwargs): return images - def _prepare_batch_captions(self,captions): + def _prepare_batch_captions( + self, + captions:List[List[str]] + ) -> (torch.Tensor,torch.Tensor): """Prepare caption idx for input. Args: @@ -302,33 +364,48 @@ def _prepare_batch_captions(self,captions): [".", "second", "sentence",".", ""] ] Returns: - captions_idx: an int tensor of size [batch_size,max_caption_length] - masks: a bool tensor of size [batch_size,max_caption_length] + captions_idx: an int tensor + masks: a bool tensor """ - samples = [] + # Combine all sentences in each caption to create a single sentence + samples = [] for caption in captions: tokens = [] tokens.extend(flatten_list(caption)) text = ' '.join(tokens).replace('. .','.') samples.append([text.split()]) - + x = self.caption_tokenizer.batch_encode_3d(samples) - captions_idx = torch.tensor(x, dtype=torch.long, device=self.device) + captions_idx = torch.tensor(x, dtype=torch.long, + device=self.device) masks = torch.sum(captions_idx,dim=1) !=0 - captions_idx = captions_idx.squeeze(1) - + captions_idx = captions_idx.squeeze(1) + return captions_idx,masks - def _forward_inference(self,patient_ids,cnn_features): - """ + def _forward_inference( + self, + patient_ids: List[int], + decoder_maxlen: int, + cnn_features:List[torch.Tensor]) -> Dict[int,str]: + """Forward propagation during inference + + Args: + patient_ids: a list of patient ids + cnn_features: a list of tensors + + Returns: + generated_results: a dict with following keys + - patient_id: int + - generated_results: List[str] """ generated_results = {} for idx, patient_id in enumerate(patient_ids): generated_results[patient_id] = [""] cnn_feature = [cnn_feat[idx].unsqueeze(0) for cnn_feat in cnn_features] - pred = self.decoder(cnn_feature, None, self.decoder_maxlen)[0] + pred = self.decoder(cnn_feature, None, decoder_maxlen)[0] pred = pred.detach().cpu() pred_tokens = self.caption_tokenizer \ .convert_indices_to_tokens(pred.tolist()) @@ -345,8 +422,20 @@ def _forward_inference(self,patient_ids,cnn_features): return generated_results - def _forward_ground_truths(self,patient_ids,captions): - """ + def _forward_ground_truths( + self, + patient_ids: List[int], + captions:List[List[str]]) -> Dict[int,str]: + """Forward propagation for ground truth + + Args: + patient_ids: a list of patient ids + cnn_features: a list of tensors + + Returns: + ground_results: a dict with following keys + - patient_id: int + - generated_results: List[str] """ ground_truths = {} for idx, caption in enumerate(captions): diff --git a/pyhealth/tasks/xray_report_generation.py b/pyhealth/tasks/xray_report_generation.py index 753e7327..d1fcc6fa 100644 --- a/pyhealth/tasks/xray_report_generation.py +++ b/pyhealth/tasks/xray_report_generation.py @@ -15,7 +15,7 @@ def biview_multisent_fn(patient): - Impression, describing the overall interpretation of X-ray results Args: - patient: a dictionary of patient X-ray report with following keys + patient: a list of dictionary of patient X-ray report with below keys - patient_id: type(int) unique identifier for patient - frontal_img_path: type(str) @@ -42,6 +42,7 @@ def biview_multisent_fn(patient): first and end of last sentence. These are mandatory for the task. """ sample = {} + patient = patient[0] report = f"{patient['impression']} . {patient['findings']}" caption = [] diff --git a/pyhealth/trainer.py b/pyhealth/trainer.py index d901f07c..e43fcd80 100644 --- a/pyhealth/trainer.py +++ b/pyhealth/trainer.py @@ -335,10 +335,11 @@ def inference_sequence(self, dataloader) -> Dict[int, str]: y_pred_all[key] = y_generated[key] if self.model.save_generated_caption: + fname = datetime.now().strftime("%Y%m%d-%H%M%S") with open(os.path.join(self.exp_path, - f'val_e{self.current_epoch}.csv'),'w') as f1: + f"{fname}_gen{self.current_epoch}.csv"),"w") as f1: with open(os.path.join(self.exp_path, - 'val_gts.csv'), 'w') as f2: + f"{fname}_gts.csv"), "w") as f2: for patient_id in y_pred_all.keys(): f1.write(y_pred_all[patient_id][0] + '\n') f2.write(y_true_all[patient_id][0] + '\n') diff --git a/test.py b/test.py index 9262d93c..54904f13 100644 --- a/test.py +++ b/test.py @@ -4,7 +4,7 @@ from pyhealth.tasks.xray_report_generation import biview_multisent_fn from pyhealth.datasets import split_by_patient, get_dataloader from pyhealth.tokenizer import Tokenizer -from pyhealth.models import WordSAT +from pyhealth.models import WordSAT, SentSAT from pyhealth.trainer import Trainer from pyhealth.datasets.utils import list_nested_levels, flatten_list @@ -73,26 +73,24 @@ def seed_everything(seed: int): sample_dataset,[0.8,0.1,0.1] ) -train_dataloader = get_dataloader(train_dataset,batch_size=8,shuffle=True) -val_dataloader = get_dataloader(val_dataset,batch_size=1,shuffle=False) +train_dataloader = get_dataloader(train_dataset,batch_size=16,shuffle=True) +val_dataloader = get_dataloader(val_dataset,batch_size=2,shuffle=False) test_dataloader = get_dataloader(test_dataset,batch_size=1,shuffle=False) -print(len(train_dataset),len(val_dataset),len(test_dataset)) - -model=WordSAT(dataset=sample_dataset, - feature_keys=['image_1','image_2'], +model=SentSAT( + dataset=sample_dataset, + n_input_images = 2, label_key='caption', tokenizer=tokenizer, - mode='sequence', encoder_pretrained_weights=state_dict, save_generated_caption = True ) #model.eval() #data = next(iter(val_dataloader)) #print(model(**data)) -""" + output_path = '/home/keshari2/ChestXrayReporting/IU_XRay/src/output/pyhealth' -ckpt_path = '/home/keshari2/ChestXrayReporting/IU_XRay/src/output/pyhealth/20230422-005011/best.ckpt' +ckpt_path = '/home/keshari2/ChestXrayReporting/IU_XRay/src/output/pyhealth/20230424-114914/best.ckpt' trainer = Trainer( model=model, @@ -104,8 +102,7 @@ def seed_everything(seed: int): val_dataloader = val_dataloader, optimizer_params = {"lr": 1e-4}, weight_decay = 1e-5, - max_grad_norm = 1, + #max_grad_norm = 1, epochs = 5, monitor = 'Bleu_1' -) -""" \ No newline at end of file +) \ No newline at end of file From d4aad7fd1687447a83ac3716f37a454c9089c24d Mon Sep 17 00:00:00 2001 From: samarthkeshari Date: Mon, 24 Apr 2023 18:50:11 -0700 Subject: [PATCH 08/11] updated documetation --- pyhealth/metrics/sequence.py | 6 +-- pyhealth/models/base_model.py | 2 +- pyhealth/models/sentsat.py | 92 ++++++++++++++++++++--------------- pyhealth/models/wordsat.py | 53 ++++++++++---------- pyhealth/trainer.py | 10 ++-- 5 files changed, 89 insertions(+), 74 deletions(-) diff --git a/pyhealth/metrics/sequence.py b/pyhealth/metrics/sequence.py index d59ee693..571644ce 100644 --- a/pyhealth/metrics/sequence.py +++ b/pyhealth/metrics/sequence.py @@ -7,11 +7,11 @@ from pycocoevalcap.meteor.meteor import Meteor def sequence_metrics_fn( - y_true: Dict[int,str], - y_generated: Dict[int,str], + y_true: List[Dict[int,str]], + y_generated: List[Dict[int,str]], metrics: Optional[List[str]] = None ) -> Dict[str, float]: - """ + """Compute metrics relevant for evaluating sequences """ scorers = [ (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), diff --git a/pyhealth/models/base_model.py b/pyhealth/models/base_model.py index a905e5d8..1f87d244 100644 --- a/pyhealth/models/base_model.py +++ b/pyhealth/models/base_model.py @@ -10,7 +10,7 @@ from pyhealth.tokenizer import Tokenizer # TODO: add support for regression -VALID_MODE = ["binary", "multiclass", "multilabel","sequence"] +VALID_MODE = ["binary", "multiclass", "multilabel", "sequence"] class BaseModel(ABC, nn.Module): diff --git a/pyhealth/models/sentsat.py b/pyhealth/models/sentsat.py index 45f99ad8..61334a0f 100644 --- a/pyhealth/models/sentsat.py +++ b/pyhealth/models/sentsat.py @@ -18,14 +18,14 @@ def __init__(self): self.densenet121.classifier = nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward propagation. + """Forward propagation. Extract fixed-length feature vectors from the input image. - + Args: - x: A tensor of tranfomed image of size - [batch_size,3,224,224] + x: A tensor of transfomed image of size + [batch_size,3,512,512] Return: - x: A tensor of image feature vectors of size + x: A tensor of image feature vectors of size [batch_size,1024,16,16] """ x = self.densenet121.features(x) @@ -35,8 +35,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class SentSATAttention(nn.Module): """SAT Attention Module - Computes a set of attention weights based on the current hidden state of - the RNN and the feature vectors from the CNN, which are then used to + Computes a set of attention weights based on the current hidden state of + the RNN and the feature vectors from the CNN, which are then used to compute a weighted average of the feature vectors. Args: @@ -45,9 +45,9 @@ class SentSATAttention(nn.Module): affine_dim: affine dimension. Default is 512 """ def __init__( - self, - k_size: int, - v_size: int, + self, + k_size: int, + v_size: int, affine_dim: int =512): super().__init__() self.affine_k = nn.Linear(k_size, affine_dim, bias=False) @@ -55,8 +55,8 @@ def __init__( self.affine = nn.Linear(affine_dim, 1, bias=False) def forward( - self, - k: torch.Tensor, + self, + k: torch.Tensor, v: torch.Tensor) -> (torch.Tensor,torch.Tensor): """Forward propagation @@ -77,18 +77,20 @@ def forward( return context, alpha class SentSATDecoder(nn.Module): - """ Word SAT decoder model for one sentence + """ Sentence SAT decoder model that treats a caption as multiple sentences - An LSTM based model that takes as input the attention-weighted feature - vector and generates a sequence of words, one at a time. + An LSTM based model that takes as input the attention-weighted feature + vector and generates a sequence of sentences and followed by words, in each + sentence Args: attention: attention module instance vocab_size: vocabulary size n_encoder_inputs: number of image inputs given to the encoder - feature_dim: encoder output feature dimesion + feature_dim: encoder output feature dimension + embedding_dim: decoder embedding dimension hidden_dim: LSTM hidden dimension - dropout: dropout rate between [0,1] + dropout: dropout rate between [0,1] """ def __init__( self, @@ -115,7 +117,7 @@ def __init__( hidden_dim) self.sent_lstm = nn.LSTMCell(self.n_encoder_inputs * self.feature_dim, hidden_dim) - + self.word_lstm = nn.LSTMCell(self.embedding_dim + self.hidden_dim + self.n_encoder_inputs * feature_dim, hidden_dim) @@ -123,21 +125,25 @@ def __init__( self.dropout = nn.Dropout(dropout) def forward( - self, - cnn_features: List[torch.Tensor], + self, + cnn_features: List[torch.Tensor], captions: List[torch.Tensor] = None, update_masks: torch.Tensor = None, max_sents: int = 10, max_len: int = 30, stop_id: int = None) -> torch.Tensor: - + """Forward propagation Args: cnn_features: a list of tensors where each tensor is of size [batch_size, feature_dim, spatial_size]. captions: a list of tensors. + updat_masks: a boolean tensor to identify the actual tokens + max_sents: maximum number of sentences that can be generated max_len: maximum length of training or generated caption + stop_id: token id from vocabulary to stop word generation for a + sentence during inference Returns: logits: a tensor @@ -161,7 +167,7 @@ def forward( word_h = cnn_features[0].new_zeros((batch_size, self.hidden_dim),dtype=torch.float) - + word_c = cnn_features[0].new_zeros((batch_size, self.hidden_dim),dtype=torch.float) @@ -182,10 +188,10 @@ def forward( for t in range(seq_len_k): batch_mask = update_masks[:, k, t] - + word_h_, word_c_ = self.word_lstm( - torch.cat((embeddings[batch_mask, k, t], - sent_h[batch_mask], + torch.cat((embeddings[batch_mask, k, t], + sent_h[batch_mask], context[batch_mask]), dim=1), (word_h[batch_mask], word_c[batch_mask])) @@ -201,17 +207,17 @@ def forward( # Evaluation/Inference phase else: x_t = cnn_features[0].new_full((batch_size,), 1, dtype=torch.long) - + for k in range(num_sents): contexts = [self.attend(sent_h, cnn_feat_t)[0] for cnn_feat_t in cnn_feats_t] context = torch.cat(contexts, dim=1) sent_h, sent_c = self.sent_lstm(context, (sent_h, sent_c)) - + for t in range(seq_len): embedding = self.embed(x_t) word_h, word_c = self.word_lstm( - torch.cat((embedding, sent_h, context), dim=1), + torch.cat((embedding, sent_h, context), dim=1), (word_h, word_c)) logit = self.fc(word_h) x_t = logit.argmax(dim=1) @@ -285,7 +291,7 @@ def __init__( save_generated_caption = save_generated_caption ) self.n_input_images = n_input_images - + # Encoder component self.encoder = SentSATEncoder() if encoder_pretrained_weights: @@ -312,10 +318,10 @@ def __init__( ) def forward( - self, - decoder_maxsents: int =10, - decoder_maxlen:int = 20, - decoder_stop_id: int = None, + self, + decoder_maxsents: int =10, + decoder_maxlen:int = 20, + decoder_stop_id: int = None, **kwargs) -> Dict[str,str]: """Forward propagation. @@ -359,7 +365,7 @@ def forward( kwargs[self.label_key]) # Perform predictions - logits = self.decoder(cnn_features, + logits = self.decoder(cnn_features, captions[:, :, :-1], update_masks, decoder_maxsents, @@ -374,7 +380,7 @@ def forward( loss = self.get_loss_function()(logits, captions) loss = loss.masked_select(loss_masks).mean() output["loss"] = loss - + with torch.no_grad(): output["y_generated"] = self._forward_inference(patient_ids, decoder_maxsents, @@ -414,12 +420,13 @@ def _prepare_batch_captions( ] Returns: captions_idx: an int tensor - masks: a bool tensor + loss_masks: a bool tensor for each sentence in a caption + update_masks: a bool tensor for each sentence in a caption """ x = self.caption_tokenizer.batch_encode_3d(captions) - captions_idx = torch.tensor(x, dtype=torch.long, + captions_idx = torch.tensor(x, dtype=torch.long, device=self.device) - + loss_masks = torch.zeros_like(captions_idx,dtype=torch.bool) update_masks = torch.zeros_like(captions_idx,dtype=torch.bool) @@ -429,7 +436,7 @@ def _prepare_batch_captions( if l==0: continue loss_masks[icap, isent, 1:l].fill_(1) update_masks[icap, isent, :l-1].fill_(1) - + return captions_idx, loss_masks, update_masks def _forward_inference( @@ -443,7 +450,12 @@ def _forward_inference( Args: patient_ids: a list of patient ids + decoder_maxsents: maximum number of sentences that can be generated + decoder_maxlen: maximum length of words in a every sentence of a + caption cnn_features: a list of tensors + stop_id: token id from vocabulary to stop word generation for a + sentence Returns: generated_results: a dict with following keys @@ -463,7 +475,7 @@ def _forward_inference( for isent in range(pred.size(0)): pred_tokens = self.caption_tokenizer \ .convert_indices_to_tokens(pred[isent].tolist()) - + words = [] for token in pred_tokens: if token == '' or token == '': diff --git a/pyhealth/models/wordsat.py b/pyhealth/models/wordsat.py index f6dd0a11..3ac42b4f 100644 --- a/pyhealth/models/wordsat.py +++ b/pyhealth/models/wordsat.py @@ -18,14 +18,14 @@ def __init__(self): self.densenet121.classifier = nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward propagation. + """Forward propagation. Extract fixed-length feature vectors from the input image. - + Args: - x: A tensor of tranfomed image of size - [batch_size,3,224,224] + x: A tensor of transfomed image of size + [batch_size,3,512,512] Return: - x: A tensor of image feature vectors of size + x: A tensor of image feature vectors of size [batch_size,1024,16,16] """ x = self.densenet121.features(x) @@ -35,8 +35,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class WordSATAttention(nn.Module): """SAT Attention Module - Computes a set of attention weights based on the current hidden state of - the RNN and the feature vectors from the CNN, which are then used to + Computes a set of attention weights based on the current hidden state of + the RNN and the feature vectors from the CNN, which are then used to compute a weighted average of the feature vectors. Args: @@ -45,9 +45,9 @@ class WordSATAttention(nn.Module): affine_dim: affine dimension. Default is 512 """ def __init__( - self, - k_size: int, - v_size: int, + self, + k_size: int, + v_size: int, affine_dim: int =512): super().__init__() self.affine_k = nn.Linear(k_size, affine_dim, bias=False) @@ -55,8 +55,8 @@ def __init__( self.affine = nn.Linear(affine_dim, 1, bias=False) def forward( - self, - k: torch.Tensor, + self, + k: torch.Tensor, v: torch.Tensor) -> (torch.Tensor,torch.Tensor): """Forward propagation @@ -79,16 +79,17 @@ def forward( class WordSATDecoder(nn.Module): """ Word SAT decoder model for one sentence - An LSTM based model that takes as input the attention-weighted feature + An LSTM based model that takes as input the attention-weighted feature vector and generates a sequence of words, one at a time. Args: attention: attention module instance vocab_size: vocabulary size n_encoder_inputs: number of image inputs given to the encoder - feature_dim: encoder output feature dimesion + feature_dim: encoder output feature dimension + embedding_dim: decoder embedding dimension hidden_dim: LSTM hidden dimension - dropout: dropout rate between [0,1] + dropout: dropout rate between [0,1] """ def __init__( self, @@ -120,11 +121,11 @@ def __init__( self.dropout = nn.Dropout(dropout) def forward( - self, - cnn_features: List[torch.Tensor], - captions: List[torch.Tensor] = None, + self, + cnn_features: List[torch.Tensor], + captions: List[torch.Tensor] = None, max_len: int = 100) -> torch.Tensor: - + """Forward propagation Args: @@ -244,6 +245,7 @@ def __init__( feature_keys=[f'image_{i+1}' for i in range(n_input_images)], label_key=label_key, mode="sequence", + save_generated_caption = save_generated_caption ) self.n_input_images = n_input_images self.save_generated_caption = save_generated_caption @@ -327,7 +329,7 @@ def forward(self, decoder_maxlen:int = 100, **kwargs) -> Dict[str,str]: loss = self.get_loss_function()(logits, captions) loss = loss.masked_select(masks).mean() output["loss"] = loss - + with torch.no_grad(): output["y_generated"] = self._forward_inference(patient_ids, decoder_maxlen, @@ -367,7 +369,7 @@ def _prepare_batch_captions( captions_idx: an int tensor masks: a bool tensor """ - + # Combine all sentences in each caption to create a single sentence samples = [] for caption in captions: @@ -375,13 +377,13 @@ def _prepare_batch_captions( tokens.extend(flatten_list(caption)) text = ' '.join(tokens).replace('. .','.') samples.append([text.split()]) - + x = self.caption_tokenizer.batch_encode_3d(samples) - captions_idx = torch.tensor(x, dtype=torch.long, + captions_idx = torch.tensor(x, dtype=torch.long, device=self.device) masks = torch.sum(captions_idx,dim=1) !=0 - captions_idx = captions_idx.squeeze(1) - + captions_idx = captions_idx.squeeze(1) + return captions_idx,masks def _forward_inference( @@ -393,6 +395,7 @@ def _forward_inference( Args: patient_ids: a list of patient ids + decoder_maxlen: maximum length of generated caption cnn_features: a list of tensors Returns: diff --git a/pyhealth/trainer.py b/pyhealth/trainer.py index e43fcd80..84579bf4 100644 --- a/pyhealth/trainer.py +++ b/pyhealth/trainer.py @@ -320,7 +320,7 @@ def load_ckpt(self, ckpt_path: str) -> None: return def inference_sequence(self, dataloader) -> Dict[int, str]: - """Model inference + """Model inference for sequences """ y_true_all = {} y_pred_all = {} @@ -333,18 +333,18 @@ def inference_sequence(self, dataloader) -> Dict[int, str]: for key in y_generated.keys(): y_true_all[key] = y_true[key] y_pred_all[key] = y_generated[key] - + if self.model.save_generated_caption: fname = datetime.now().strftime("%Y%m%d-%H%M%S") - with open(os.path.join(self.exp_path, + with open(os.path.join(self.exp_path, f"{fname}_gen{self.current_epoch}.csv"),"w") as f1: - with open(os.path.join(self.exp_path, + with open(os.path.join(self.exp_path, f"{fname}_gts.csv"), "w") as f2: for patient_id in y_pred_all.keys(): f1.write(y_pred_all[patient_id][0] + '\n') f2.write(y_true_all[patient_id][0] + '\n') - + return y_true_all, y_pred_all, 0 From 6e1db61d2c7a28180c30805a180c8775e3cbef4e Mon Sep 17 00:00:00 2001 From: samarthkeshari Date: Wed, 26 Apr 2023 20:58:19 -0500 Subject: [PATCH 09/11] added exmple code for xray report generation --- examples/xray_report_generation_sat.py | 178 ++++++++++++++++++++++++ test.py | 182 ++++++++++++++++++++++++- 2 files changed, 359 insertions(+), 1 deletion(-) create mode 100644 examples/xray_report_generation_sat.py diff --git a/examples/xray_report_generation_sat.py b/examples/xray_report_generation_sat.py new file mode 100644 index 00000000..3447db3e --- /dev/null +++ b/examples/xray_report_generation_sat.py @@ -0,0 +1,178 @@ +import os +import argparse +import pickle +import torch +from collections import OrderedDict +from torchvision import transforms +from pyhealth.datasets import BaseImageCaptionDataset +from pyhealth.tasks.xray_report_generation import biview_multisent_fn +from pyhealth.datasets import split_by_patient, get_dataloader +from pyhealth.tokenizer import Tokenizer +from pyhealth.models import WordSAT, SentSAT +from pyhealth.trainer import Trainer + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument('--root', type=str, default=".") + parser.add_argument('--encoder-chkpt-fname', type=str, default=None) + parser.add_argument('--tokenizer-fname', type=str, default=None) + parser.add_argument('--num-epochs', type=int, default=1) + parser.add_argument('--model-type', type=str, default="wordsat") + + args = parser.parse_args() + return args + +def seed_everything(seed: int): + import random, os + import numpy as np + import torch + + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + + +# STEP 1: load data +def load_data(root): + base_dataset = BaseImageCaptionDataset(root=root,dataset_name='IU_XRay') + return base_dataset + +# STEP 2: set task +def set_task(base_dataset): + sample_dataset = base_dataset.set_task(biview_multisent_fn) + transform = transforms.Compose([ + transforms.RandomAffine(degrees=30), + transforms.Resize((512,512)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485,0.456,0.406], + std=[0.229,0.224,0.225]), + ]) + sample_dataset.set_transform(transform) + return sample_dataset + +# STEP 3: get dataloaders +def get_dataloaders(sample_dataset): + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset,[0.8,0.1,0.1]) + + train_dataloader = get_dataloader(train_dataset,batch_size=8,shuffle=True) + val_dataloader = get_dataloader(val_dataset,batch_size=1,shuffle=False) + test_dataloader = get_dataloader(test_dataset,batch_size=1,shuffle=False) + + return train_dataloader,val_dataloader,test_dataloader + +# STEP 4: get tokenizer +def get_tokenizer(root,sample_dataset=None,tokenizer_fname=None): + if tokenizer_fname: + with open(os.path.join(root,tokenizer_fname), 'wb') as f: + tokenizer = pickle.load(f) + else: + # should always be first element in the list of special tokens + special_tokens = ['','','',''] + tokenizer = Tokenizer( + sample_dataset.get_all_tokens(key='caption'), + special_tokens=special_tokens, + ) + return tokenizer + +# STEP 5: get encoder pretrained state dictionary +def extract_encoder_state_dict(root,chkpt_fname): + checkpoint = torch.load(os.path.join(root,chkpt_fname) ) + state_dict = OrderedDict() + for k,v in checkpoint['state_dict'].items(): + if 'classifier' in k: continue + if k[:7] == 'module.' : + name = k[7:] + else: + name = k + + name = name.replace('classifier.0.','classifier.') + state_dict[name] = v + return state_dict + +# STEP 6: define model +def define_model( + sample_dataset, + tokenizer, + encoder_weights, + model_type='wordsat'): + + if model_type == 'wordsat': + model=WordSAT( + dataset = sample_dataset, + n_input_images = 2, + label_key = 'caption', + tokenizer = tokenizer, + encoder_pretrained_weights = encoder_weights, + encoder_freeze_weights = True, + save_generated_caption = True + ) + else: + model=SentSAT( + dataset = sample_dataset, + n_input_images = 2, + label_key = 'caption', + tokenizer = tokenizer, + encoder_pretrained_weights = encoder_weights, + encoder_freeze_weights = True, + save_generated_caption = True + ) + + return model + +# STEP 7: run trainer +def run_trainer(output_path, train_dataloader, val_dataloader, model,n_epochs): + trainer = Trainer( + model=model, + output_path=output_path + ) + + trainer.train( + train_dataloader = train_dataloader, + val_dataloader = val_dataloader, + optimizer_params = {"lr": 1e-4}, + weight_decay = 1e-5, + epochs = n_epochs, + monitor = 'Bleu_1' + ) + return trainer + +# STEP 8: evaluate +def evaluate(trainer,test_dataloader): + print(trainer.evaluate(test_dataloader)) + return None + +if __name__ == '__main__': + args = get_args() + seed_everything(42) + base_dataset = load_data(args.root) + sample_dataset = set_task(base_dataset) + train_dataloader,val_dataloader,test_dataloader = get_dataloaders( + sample_dataset) + tokenizer = get_tokenizer(args.root,sample_dataset) + encoder_weights = extract_encoder_state_dict(args.root, + args.encoder_chkpt_fname) + + model = define_model( + sample_dataset, + tokenizer, + encoder_weights, + args.model_type) + + trainer = run_trainer( + args.root, + train_dataloader, + val_dataloader, + model, + args.num_epochs) + + print("\n===== Evaluating Test Data ======\n") + evaluate(trainer,test_dataloader) + + + diff --git a/test.py b/test.py index 54904f13..83f6d152 100644 --- a/test.py +++ b/test.py @@ -1,3 +1,4 @@ +""" import pickle from torchvision import transforms from pyhealth.datasets import BaseImageCaptionDataset @@ -46,6 +47,7 @@ def seed_everything(seed: int): root = '/home/keshari2/ChestXrayReporting/IU_XRay/src/data' sample_dataset = BaseImageCaptionDataset(root=root,dataset_name='IU_XRay') sample_dataset = sample_dataset.set_task(biview_multisent_fn) + transform = transforms.Compose([ transforms.RandomAffine(degrees=30), transforms.Resize((512,512)), @@ -55,6 +57,7 @@ def seed_everything(seed: int): ]) sample_dataset.set_transform(transform) +""" """ special_tokens = ['','','',''] tokenizer = Tokenizer( @@ -65,6 +68,7 @@ def seed_everything(seed: int): with open(root+'/pyhealth_tokenizer.pkl', 'wb') as f: pickle.dump(tokenizer, f) """ +""" with open(root+'/pyhealth_tokenizer.pkl', 'rb') as f: tokenizer = pickle.load(f) #print(sample_dataset[0]['caption']) @@ -105,4 +109,180 @@ def seed_everything(seed: int): #max_grad_norm = 1, epochs = 5, monitor = 'Bleu_1' -) \ No newline at end of file +) +""" +import os +import argparse +import pickle +import torch +from collections import OrderedDict +from torchvision import transforms +from pyhealth.datasets import BaseImageCaptionDataset +from pyhealth.tasks.xray_report_generation import biview_multisent_fn +from pyhealth.datasets import split_by_patient, get_dataloader +from pyhealth.tokenizer import Tokenizer +from pyhealth.models import WordSAT, SentSAT +from pyhealth.trainer import Trainer + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument('--root', type=str, default=".") + parser.add_argument('--encoder-chkpt-fname', type=str, default=None) + parser.add_argument('--tokenizer-fname', type=str, default=None) + parser.add_argument('--num-epochs', type=int, default=1) + parser.add_argument('--model-type', type=str, default="wordsat") + + args = parser.parse_args() + return args + +def seed_everything(seed: int): + import random, os + import numpy as np + import torch + + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + + +# STEP 1: load data +def load_data(root): + base_dataset = BaseImageCaptionDataset(root=root,dataset_name='IU_XRay') + return base_dataset + +# STEP 2: set task +def set_task(base_dataset): + sample_dataset = base_dataset.set_task(biview_multisent_fn) + transform = transforms.Compose([ + transforms.RandomAffine(degrees=30), + transforms.Resize((512,512)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485,0.456,0.406], + std=[0.229,0.224,0.225]), + ]) + sample_dataset.set_transform(transform) + return sample_dataset + +# STEP 3: get dataloaders +def get_dataloaders(sample_dataset): + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset,[0.8,0.1,0.1]) + + train_dataloader = get_dataloader(train_dataset,batch_size=8,shuffle=True) + val_dataloader = get_dataloader(val_dataset,batch_size=1,shuffle=False) + test_dataloader = get_dataloader(test_dataset,batch_size=1,shuffle=False) + + return train_dataloader,val_dataloader,test_dataloader + +# STEP 4: get tokenizer +def get_tokenizer(root,sample_dataset=None,tokenizer_fname=None): + if tokenizer_fname: + with open(os.path.join(root,tokenizer_fname), 'wb') as f: + tokenizer = pickle.load(f) + else: + # should always be first element in the list of special tokens + special_tokens = ['','','',''] + tokenizer = Tokenizer( + sample_dataset.get_all_tokens(key='caption'), + special_tokens=special_tokens, + ) + return tokenizer + +# STEP 5: get encoder pretrained state dictionary +def extract_encoder_state_dict(root,chkpt_fname): + checkpoint = torch.load(os.path.join(root,chkpt_fname) ) + state_dict = OrderedDict() + for k,v in checkpoint['state_dict'].items(): + if 'classifier' in k: continue + if k[:7] == 'module.' : + name = k[7:] + else: + name = k + + name = name.replace('classifier.0.','classifier.') + state_dict[name] = v + return state_dict + +# STEP 6: define model +def define_model( + sample_dataset, + tokenizer, + encoder_weights, + model_type='wordsat'): + + if model_type == 'wordsat': + model=WordSAT( + dataset = sample_dataset, + n_input_images = 2, + label_key = 'caption', + tokenizer = tokenizer, + encoder_pretrained_weights = encoder_weights, + encoder_freeze_weights = True, + save_generated_caption = True + ) + else: + model=SentSAT( + dataset = sample_dataset, + n_input_images = 2, + label_key = 'caption', + tokenizer = tokenizer, + encoder_pretrained_weights = encoder_weights, + encoder_freeze_weights = True, + save_generated_caption = True + ) + + return model + +# STEP 7: run trainer +def run_trainer(output_path, train_dataloader, val_dataloader, model,n_epochs): + trainer = Trainer( + model=model, + output_path=output_path + ) + + trainer.train( + train_dataloader = train_dataloader, + val_dataloader = val_dataloader, + optimizer_params = {"lr": 1e-4}, + weight_decay = 1e-5, + epochs = n_epochs, + monitor = 'Bleu_1' + ) + return trainer + +# STEP 8: evaluate +def evaluate(trainer,test_dataloader): + print(trainer.evaluate(test_dataloader)) + return None + +if __name__ == '__main__': + args = get_args() + seed_everything(42) + base_dataset = load_data(args.root) + sample_dataset = set_task(base_dataset) + train_dataloader,val_dataloader,test_dataloader = get_dataloaders( + sample_dataset) + tokenizer = get_tokenizer(args.root,sample_dataset) + encoder_weights = extract_encoder_state_dict(args.root, + args.encoder_chkpt_fname) + + model = define_model( + sample_dataset, + tokenizer, + encoder_weights, + args.model_type) + + trainer = run_trainer( + args.root, + train_dataloader, + val_dataloader, + model, + args.num_epochs) + + print("\n===== Evaluating Test Data ======\n") + evaluate(trainer,test_dataloader) \ No newline at end of file From df4c3944ed095d437c7c43e1b49f99d0fa6bad63 Mon Sep 17 00:00:00 2001 From: samarthkeshari Date: Thu, 27 Apr 2023 13:16:21 -0500 Subject: [PATCH 10/11] updated metrics --- pyhealth/metrics/sequence.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/pyhealth/metrics/sequence.py b/pyhealth/metrics/sequence.py index 571644ce..7a9d8417 100644 --- a/pyhealth/metrics/sequence.py +++ b/pyhealth/metrics/sequence.py @@ -11,21 +11,43 @@ def sequence_metrics_fn( y_generated: List[Dict[int,str]], metrics: Optional[List[str]] = None ) -> Dict[str, float]: - """Compute metrics relevant for evaluating sequences + """Compute metrics relevant for evaluating sequences. + + User can specify which metrics to compute by passing a list of metric names + The accepted metric names are: + - Bleu_{n_grams}: BiLingual Evaluation Understudy. + Allowed n_grams = [1,2,3,4] + - METEOR: Metric for Evaluation of Translation with Explicit ORdering + - ROUGE: Recall-Oriented Understudy for Gisting Evaluation + - CIDEr: Consensus-based Image Description Evaluation + + All metrics compute a score for comparing a candidate text to one or more + reference text. """ scorers = [ (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), (Meteor(), "METEOR"), - (Rouge(), "ROUGE_L"), - (Cider(), "CIDEr"), + (Rouge(), "ROUGE"), + (Cider(), "CIDER"), ] + + allowed_metrics = ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4", + "METEOR","ROUGE","CIDER"] + if metrics: + for metric in metrics: + if metric not in allowed_metrics: + raise ValueError(f"Unknown metric for \ + sequence evaluation: {metric}") + output = {} for scorer, method in scorers: score, scores = scorer.compute_score(y_true, y_generated) if type(score) == list: for m, s in zip(method, score): - output[m] = s + if m in allowed_metrics: + output[m] = s else: - output[method] = score + if m in allowed_metrics: + output[method] = score return output \ No newline at end of file From 8ed1dd5d667142048e68bca66bf3d3b617a233a2 Mon Sep 17 00:00:00 2001 From: samarthkeshari Date: Thu, 27 Apr 2023 13:37:20 -0500 Subject: [PATCH 11/11] updated sequence.py --- pyhealth/metrics/sequence.py | 11 +- test.py | 288 ----------------------------------- 2 files changed, 6 insertions(+), 293 deletions(-) delete mode 100644 test.py diff --git a/pyhealth/metrics/sequence.py b/pyhealth/metrics/sequence.py index 7a9d8417..c98559e5 100644 --- a/pyhealth/metrics/sequence.py +++ b/pyhealth/metrics/sequence.py @@ -36,18 +36,19 @@ def sequence_metrics_fn( if metrics: for metric in metrics: if metric not in allowed_metrics: - raise ValueError(f"Unknown metric for \ - sequence evaluation: {metric}") - + raise ValueError(f"Unknown metric for evaluation: {metric}") + else: + metrics = allowed_metrics + output = {} for scorer, method in scorers: score, scores = scorer.compute_score(y_true, y_generated) if type(score) == list: for m, s in zip(method, score): - if m in allowed_metrics: + if m in metrics: output[m] = s else: - if m in allowed_metrics: + if method in metrics: output[method] = score return output \ No newline at end of file diff --git a/test.py b/test.py deleted file mode 100644 index 83f6d152..00000000 --- a/test.py +++ /dev/null @@ -1,288 +0,0 @@ -""" -import pickle -from torchvision import transforms -from pyhealth.datasets import BaseImageCaptionDataset -from pyhealth.tasks.xray_report_generation import biview_multisent_fn -from pyhealth.datasets import split_by_patient, get_dataloader -from pyhealth.tokenizer import Tokenizer -from pyhealth.models import WordSAT, SentSAT -from pyhealth.trainer import Trainer -from pyhealth.datasets.utils import list_nested_levels, flatten_list - -import torch -from collections import OrderedDict -def extract_state_dict(chkpt_pth): - checkpoint = torch.load(chkpt_pth + 'model_ones_3epoch_densenet.tar') - new_state_dict = OrderedDict() - for k,v in checkpoint['state_dict'].items(): - if 'classifier' in k: continue - if k[:7] == 'module.' : - name = k[7:] - else: - name = k - - name = name.replace('classifier.0.','classifier.') - new_state_dict[name] = v - return new_state_dict - -chkpt_pth = '/home/keshari2/ChestXrayReporting/IU_XRay/src/models/pretrained/' -state_dict = extract_state_dict(chkpt_pth) - -##### -def seed_everything(seed: int): - import random, os - import numpy as np - import torch - - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = True - -seed_everything(42) - -root = '/home/keshari2/ChestXrayReporting/IU_XRay/src/data' -sample_dataset = BaseImageCaptionDataset(root=root,dataset_name='IU_XRay') -sample_dataset = sample_dataset.set_task(biview_multisent_fn) - -transform = transforms.Compose([ - transforms.RandomAffine(degrees=30), - transforms.Resize((512,512)), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485,0.456,0.406], - std=[0.229,0.224,0.225]), - ]) -sample_dataset.set_transform(transform) - -""" -""" -special_tokens = ['','','',''] -tokenizer = Tokenizer( - sample_dataset.get_all_tokens(key='caption'), - special_tokens=special_tokens, - ) - -with open(root+'/pyhealth_tokenizer.pkl', 'wb') as f: - pickle.dump(tokenizer, f) -""" -""" -with open(root+'/pyhealth_tokenizer.pkl', 'rb') as f: - tokenizer = pickle.load(f) -#print(sample_dataset[0]['caption']) - -train_dataset, val_dataset, test_dataset = split_by_patient( - sample_dataset,[0.8,0.1,0.1] -) - -train_dataloader = get_dataloader(train_dataset,batch_size=16,shuffle=True) -val_dataloader = get_dataloader(val_dataset,batch_size=2,shuffle=False) -test_dataloader = get_dataloader(test_dataset,batch_size=1,shuffle=False) - -model=SentSAT( - dataset=sample_dataset, - n_input_images = 2, - label_key='caption', - tokenizer=tokenizer, - encoder_pretrained_weights=state_dict, - save_generated_caption = True - ) -#model.eval() -#data = next(iter(val_dataloader)) -#print(model(**data)) - -output_path = '/home/keshari2/ChestXrayReporting/IU_XRay/src/output/pyhealth' -ckpt_path = '/home/keshari2/ChestXrayReporting/IU_XRay/src/output/pyhealth/20230424-114914/best.ckpt' - -trainer = Trainer( - model=model, - output_path=output_path, - checkpoint_path = ckpt_path - ) -trainer.train( - train_dataloader = train_dataloader, - val_dataloader = val_dataloader, - optimizer_params = {"lr": 1e-4}, - weight_decay = 1e-5, - #max_grad_norm = 1, - epochs = 5, - monitor = 'Bleu_1' -) -""" -import os -import argparse -import pickle -import torch -from collections import OrderedDict -from torchvision import transforms -from pyhealth.datasets import BaseImageCaptionDataset -from pyhealth.tasks.xray_report_generation import biview_multisent_fn -from pyhealth.datasets import split_by_patient, get_dataloader -from pyhealth.tokenizer import Tokenizer -from pyhealth.models import WordSAT, SentSAT -from pyhealth.trainer import Trainer - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument('--root', type=str, default=".") - parser.add_argument('--encoder-chkpt-fname', type=str, default=None) - parser.add_argument('--tokenizer-fname', type=str, default=None) - parser.add_argument('--num-epochs', type=int, default=1) - parser.add_argument('--model-type', type=str, default="wordsat") - - args = parser.parse_args() - return args - -def seed_everything(seed: int): - import random, os - import numpy as np - import torch - - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = True - - -# STEP 1: load data -def load_data(root): - base_dataset = BaseImageCaptionDataset(root=root,dataset_name='IU_XRay') - return base_dataset - -# STEP 2: set task -def set_task(base_dataset): - sample_dataset = base_dataset.set_task(biview_multisent_fn) - transform = transforms.Compose([ - transforms.RandomAffine(degrees=30), - transforms.Resize((512,512)), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485,0.456,0.406], - std=[0.229,0.224,0.225]), - ]) - sample_dataset.set_transform(transform) - return sample_dataset - -# STEP 3: get dataloaders -def get_dataloaders(sample_dataset): - train_dataset, val_dataset, test_dataset = split_by_patient( - sample_dataset,[0.8,0.1,0.1]) - - train_dataloader = get_dataloader(train_dataset,batch_size=8,shuffle=True) - val_dataloader = get_dataloader(val_dataset,batch_size=1,shuffle=False) - test_dataloader = get_dataloader(test_dataset,batch_size=1,shuffle=False) - - return train_dataloader,val_dataloader,test_dataloader - -# STEP 4: get tokenizer -def get_tokenizer(root,sample_dataset=None,tokenizer_fname=None): - if tokenizer_fname: - with open(os.path.join(root,tokenizer_fname), 'wb') as f: - tokenizer = pickle.load(f) - else: - # should always be first element in the list of special tokens - special_tokens = ['','','',''] - tokenizer = Tokenizer( - sample_dataset.get_all_tokens(key='caption'), - special_tokens=special_tokens, - ) - return tokenizer - -# STEP 5: get encoder pretrained state dictionary -def extract_encoder_state_dict(root,chkpt_fname): - checkpoint = torch.load(os.path.join(root,chkpt_fname) ) - state_dict = OrderedDict() - for k,v in checkpoint['state_dict'].items(): - if 'classifier' in k: continue - if k[:7] == 'module.' : - name = k[7:] - else: - name = k - - name = name.replace('classifier.0.','classifier.') - state_dict[name] = v - return state_dict - -# STEP 6: define model -def define_model( - sample_dataset, - tokenizer, - encoder_weights, - model_type='wordsat'): - - if model_type == 'wordsat': - model=WordSAT( - dataset = sample_dataset, - n_input_images = 2, - label_key = 'caption', - tokenizer = tokenizer, - encoder_pretrained_weights = encoder_weights, - encoder_freeze_weights = True, - save_generated_caption = True - ) - else: - model=SentSAT( - dataset = sample_dataset, - n_input_images = 2, - label_key = 'caption', - tokenizer = tokenizer, - encoder_pretrained_weights = encoder_weights, - encoder_freeze_weights = True, - save_generated_caption = True - ) - - return model - -# STEP 7: run trainer -def run_trainer(output_path, train_dataloader, val_dataloader, model,n_epochs): - trainer = Trainer( - model=model, - output_path=output_path - ) - - trainer.train( - train_dataloader = train_dataloader, - val_dataloader = val_dataloader, - optimizer_params = {"lr": 1e-4}, - weight_decay = 1e-5, - epochs = n_epochs, - monitor = 'Bleu_1' - ) - return trainer - -# STEP 8: evaluate -def evaluate(trainer,test_dataloader): - print(trainer.evaluate(test_dataloader)) - return None - -if __name__ == '__main__': - args = get_args() - seed_everything(42) - base_dataset = load_data(args.root) - sample_dataset = set_task(base_dataset) - train_dataloader,val_dataloader,test_dataloader = get_dataloaders( - sample_dataset) - tokenizer = get_tokenizer(args.root,sample_dataset) - encoder_weights = extract_encoder_state_dict(args.root, - args.encoder_chkpt_fname) - - model = define_model( - sample_dataset, - tokenizer, - encoder_weights, - args.model_type) - - trainer = run_trainer( - args.root, - train_dataloader, - val_dataloader, - model, - args.num_epochs) - - print("\n===== Evaluating Test Data ======\n") - evaluate(trainer,test_dataloader) \ No newline at end of file