diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 00000000..10434d4c
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,5 @@
+{
+    "editor.rulers": [
+        79
+    ]
+}
\ No newline at end of file
diff --git a/examples/xray_report_generation_sat.py b/examples/xray_report_generation_sat.py
new file mode 100644
index 00000000..3447db3e
--- /dev/null
+++ b/examples/xray_report_generation_sat.py
@@ -0,0 +1,178 @@
+import os
+import argparse
+import pickle
+import torch
+from collections import OrderedDict
+from torchvision import transforms
+from pyhealth.datasets import BaseImageCaptionDataset
+from pyhealth.tasks.xray_report_generation import biview_multisent_fn
+from pyhealth.datasets import split_by_patient, get_dataloader
+from pyhealth.tokenizer import Tokenizer
+from pyhealth.models import WordSAT, SentSAT
+from pyhealth.trainer import Trainer
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--root', type=str, default=".")
+    parser.add_argument('--encoder-chkpt-fname', type=str, default=None)
+    parser.add_argument('--tokenizer-fname', type=str, default=None)
+    parser.add_argument('--num-epochs', type=int, default=1)
+    parser.add_argument('--model-type', type=str, default="wordsat")
+
+    args = parser.parse_args()
+    return args
+
+def seed_everything(seed: int):
+    import random, os
+    import numpy as np
+    import torch
+
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+# STEP 1: load data
+def load_data(root):
+    base_dataset = BaseImageCaptionDataset(root=root, dataset_name='IU_XRay')
+    return base_dataset
+
+# STEP 2: set task
+def set_task(base_dataset):
+    sample_dataset = base_dataset.set_task(biview_multisent_fn)
+    transform = transforms.Compose([
+        transforms.RandomAffine(degrees=30),
+        transforms.Resize((512, 512)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
+    ])
+    sample_dataset.set_transform(transform)
+    return sample_dataset
+
+# STEP 3: get dataloaders
+def get_dataloaders(sample_dataset):
+    train_dataset, val_dataset, test_dataset = split_by_patient(
+        sample_dataset, [0.8, 0.1, 0.1])
+
+    train_dataloader = get_dataloader(train_dataset, batch_size=8, shuffle=True)
+    val_dataloader = get_dataloader(val_dataset, batch_size=1, shuffle=False)
+    test_dataloader = get_dataloader(test_dataset, batch_size=1, shuffle=False)
+
+    return train_dataloader, val_dataloader, test_dataloader
+
+# STEP 4: get tokenizer
+def get_tokenizer(root, sample_dataset=None, tokenizer_fname=None):
+    if tokenizer_fname:
+        with open(os.path.join(root, tokenizer_fname), 'rb') as f:
+            tokenizer = pickle.load(f)
+    else:
+        # should always be first element in the list of special tokens
+        special_tokens = ['','','','']
+        tokenizer = Tokenizer(
+            sample_dataset.get_all_tokens(key='caption'),
+            special_tokens=special_tokens,
+        )
+    return tokenizer
+
+# STEP 5: get encoder pretrained state dictionary
+def extract_encoder_state_dict(root, chkpt_fname):
+    checkpoint = torch.load(os.path.join(root, chkpt_fname))
+    state_dict = OrderedDict()
+    for k, v in checkpoint['state_dict'].items():
+        if 'classifier' in k: continue
+        if k[:7] == 'module.':
+            name = k[7:]
+        else:
+            name = k
+
+        name = name.replace('classifier.0.', 'classifier.')
+        state_dict[name] = v
+    return state_dict
+
+# STEP 6: define model
+def define_model(
+        sample_dataset,
+        tokenizer,
+        encoder_weights,
+        model_type='wordsat'):
+
+    if model_type == 'wordsat':
+        model = WordSAT(
+            dataset = sample_dataset,
+            n_input_images = 2,
+            label_key = 'caption',
+            tokenizer = tokenizer,
+            encoder_pretrained_weights = encoder_weights,
+            encoder_freeze_weights = True,
+            save_generated_caption = True
+        )
+    else:
+        model = SentSAT(
+            dataset = sample_dataset,
+            n_input_images = 2,
+            label_key = 'caption',
+            tokenizer = tokenizer,
+            encoder_pretrained_weights = encoder_weights,
+            encoder_freeze_weights = True,
+            save_generated_caption = True
+        )
+
+    return model
+
+# STEP 7: run trainer
+def run_trainer(output_path, train_dataloader, val_dataloader, model, n_epochs):
+    trainer = Trainer(
+        model=model,
+        output_path=output_path
+    )
+
+    trainer.train(
+        train_dataloader = train_dataloader,
+        val_dataloader = val_dataloader,
+        optimizer_params = {"lr": 1e-4},
+        weight_decay = 1e-5,
+        epochs = n_epochs,
+        monitor = 'Bleu_1'
+    )
+    return trainer
+
+# STEP 8: evaluate
+def evaluate(trainer, test_dataloader):
+    print(trainer.evaluate(test_dataloader))
+    return None
+
+if __name__ == '__main__':
+    args = get_args()
+    seed_everything(42)
+    base_dataset = load_data(args.root)
+    sample_dataset = set_task(base_dataset)
+    train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
+        sample_dataset)
+    tokenizer = get_tokenizer(args.root, sample_dataset, args.tokenizer_fname)
+    encoder_weights = (extract_encoder_state_dict(args.root,
+                                                  args.encoder_chkpt_fname)
+                       if args.encoder_chkpt_fname else None)
+
+    model = define_model(
+        sample_dataset,
+        tokenizer,
+        encoder_weights,
+        args.model_type)
+
+    trainer = run_trainer(
+        args.root,
+        train_dataloader,
+        val_dataloader,
+        model,
+        args.num_epochs)
+
+    print("\n===== Evaluating Test Data ======\n")
+    evaluate(trainer, test_dataloader)
+
+
diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py
index 650178e0..28ef3a71 100644
--- a/pyhealth/datasets/__init__.py
+++ b/pyhealth/datasets/__init__.py
@@ -1,4 +1,5 @@
 from .base_ehr_dataset import BaseEHRDataset
+from .base_image_caption_dataset import BaseImageCaptionDataset
 from .base_signal_dataset import BaseSignalDataset
 from .eicu import eICUDataset
 from .mimic3 import MIMIC3Dataset
@@ -7,6 +8,7 @@ from .sleepedf import SleepEDFDataset
 from .isruc import ISRUCDataset
 from .shhs import SHHSDataset
-from .sample_dataset import SampleBaseDataset, SampleSignalDataset, SampleEHRDataset
+from .sample_dataset import SampleBaseDataset, SampleSignalDataset, SampleEHRDataset,\
+    SampleImageCaptionDataset
 from .splitter import split_by_patient, split_by_visit
 from .utils import collate_fn_dict, get_dataloader, strptime
diff --git a/pyhealth/datasets/base_image_caption_dataset.py b/pyhealth/datasets/base_image_caption_dataset.py
new file mode 100644
index 00000000..508c11d4
--- /dev/null
+++ b/pyhealth/datasets/base_image_caption_dataset.py
@@ -0,0 +1,152 @@
+import logging
+import os
+from abc import ABC
+from typing import Optional, Callable
+
+import pandas as pd
+from tqdm import tqdm
+
+from pyhealth.datasets.sample_dataset import SampleImageCaptionDataset
+
+logger = logging.getLogger(__name__)
+
+INFO_MSG = """
+dataset.patients:
+    - key: patient id
+    - value: a dict of image paths, caption, and other information
+"""
+
+
+class BaseImageCaptionDataset(ABC):
+    """Abstract base Image Caption Generation dataset class.
+ + This abstract class defines a uniform interface for all + image caption generation datasets. + + Each specific dataset will be a subclass of this abstract class, which can + then be converted to samples dataset for different tasks by calling + `self.set_task()`. + + Args: + root: root directory of the raw data (should contain many csv files). + dataset_name: name of the dataset. Default is the name of the class. + dev: whether to enable dev mode (only use a small subset of the data). + Default is False. + refresh_cache: whether to refresh the cache; if true, the dataset will + be processed from scratch and the cache will be updated. + Default is False. + """ + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + dev: bool = False, + refresh_cache: bool = False, + ): + # base attributes + self.dataset_name = ( + self.__class__.__name__ if dataset_name is None else dataset_name + ) + self.root = root + # TODO: dev seems unnecessary for image and signal? + self.dev = dev + if dev: + logger.warning("WARNING: dev has no effect \ + for image caption generation datasets.") + # TODO: refresh_cache seems unnecessary for image and signal? + self.refresh_cache = refresh_cache + if refresh_cache: + logger.warning("WARNING: refresh_cache has no effect \ + for image caption generation datasets.") + + self.metadata = pd.read_json(os.path.join(root, + "metadata.jsonl"), lines=True) + if "patient_id" not in self.metadata.columns: + # no patient_id in metadata, sequentially assign patient_id + self.metadata["patient_id"] = self.metadata.index + + # group by patient_id + self.patients = dict() + for patient_id, group in self.metadata.groupby("patient_id"): + self.patients[patient_id] = group.to_dict(orient="records") + + return + + def __len__(self): + return len(self.patients) + + def __str__(self): + """Prints some information of the dataset.""" + return f"Base dataset {self.dataset_name}" + + def stat(self) -> str: + """Returns some statistics of the base dataset.""" + lines = list() + lines.append("") + lines.append(f"Statistics of base dataset (dev={self.dev}):") + lines.append(f"\t- Dataset: {self.dataset_name}") + lines.append(f"\t- Number of images: {len(self)}") + lines.append("") + print("\n".join(lines)) + return "\n".join(lines) + + @staticmethod + def info(): + """Prints the output format.""" + print(INFO_MSG) + + def set_task( + self, + task_fn: Callable, + task_name: Optional[str] = None, + ) -> SampleImageCaptionDataset: + """Processes the base dataset to generate the task-specific + sample dataset. + + This function should be called by the user after the base dataset is + initialized. It will iterate through all patients in the base dataset + and call `task_fn` which should be implemented by the specific task. + + Args: + task_fn: a function that takes a single patient and returns a + list of samples (each sample is a dict with patient_id, + image_path_list, caption and other task-specific attributes + as key). The samples will be concatenated to form the + sample dataset. + task_name: the name of the task. If None, the name of the task + function will be used. + + Returns: + sample_dataset: the task-specific sample (Base) dataset. + + Note: + In `task_fn`, a patient may have one or multiple images associated + to a caption, for e.g. 
a patient can have single report + for xrays taken from diffrent views that may be combined to + have a single sample such as + ( + {'patient_id': 1, + 'image_path_list': [frontal_img_path,lateral_img_path], + 'caption': 'report_text'} + ) + Patients can also be excluded from the task dataset by + returning an empty list. + """ + if task_name is None: + task_name = task_fn.__name__ + + # load from raw data + logger.debug(f"Processing {self.dataset_name} base dataset...") + + samples = [] + for patient_id, patient in tqdm( + self.patients.items(), desc=f"Generating samples for {task_name}"): + samples.extend(task_fn(patient)) + + sample_dataset = SampleImageCaptionDataset( + samples, + dataset_name=self.dataset_name, + task_name=task_name, + ) + return sample_dataset \ No newline at end of file diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 5c83c0da..79ae5401 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -2,7 +2,9 @@ from typing import Dict, List import pickle +from PIL import Image from torch.utils.data import Dataset +from torchvision import transforms from pyhealth.datasets.utils import list_nested_levels, flatten_list @@ -307,7 +309,7 @@ def _validate(self) -> Dict: - a list of vectors - a list of list of codes - a list of list of vectors - Note that a value is either float, int, or str; a vector is a list of float + Note that a value is either float, int, or str; a vector is a list of float or int; and a code is str. """ # record input information for each key @@ -347,7 +349,7 @@ def _validate(self) -> Dict: ] """ - 4.2. Check type: the basic type of each element should be float, + 4.2. Check type: the basic type of each element should be float, int, or str. """ types = set([type(v) for v in flattened_values]) @@ -498,6 +500,95 @@ def stat(self) -> str: print("\n".join(lines)) return "\n".join(lines) +class SampleImageCaptionDataset(SampleBaseDataset): + """Sample image caption dataset class. + + This class the takes a list of samples as input (either from + `BaseDataset.set_task()` or user-provided input), and provides + a uniform interface for accessing the samples. + + Args: + samples: a list of samples + dataset_name: the name of the dataset. Default is None. + task_name: the name of the task. Default is None. + """ + + def __init__( + self, + samples: List[Dict], + dataset_name: str = "", + task_name: str = "" + ): + super().__init__(samples, dataset_name, task_name) + self.patient_to_index: Dict[str, List[int]] = self._index_patient() + self.type_ = "image_text" + self.input_info: Dict = self._validate() + self.img_transforms = transforms.Compose([ + transforms.ToTensor(), + ]) + + def set_transform(self, img_transforms): + self.img_transforms = img_transforms + + def _index_patient(self) -> Dict[str, List[int]]: + """Helper function which indexes the samples by patient_id. + + Will be called in `self.__init__()`. + Returns: + patient_to_index: Dict[str, int], a dict mapping patient_id to a + list of sample indices. + """ + patient_to_index = {} + for idx, sample in enumerate(self.samples): + patient_to_index.setdefault(sample["patient_id"], []).append(idx) + return patient_to_index + + def _validate(self) -> Dict: + """Helper function which gets the input information of each attribute. + + Will be called in `self.__init__()`. 
+ + Returns: + input_info: Dict, a dict whose keys are the same as the keys in the + samples, and values are the corresponding input information + """ + input_info = {} + # get info + #input_info["image_path"] = {"type": str, "dim": 2} + input_info["caption"] = {"type": str, "dim": 3} + return input_info + + def __getitem__(self, index) -> Dict: + """Returns a sample by index. + + Returns: + Dict, a dict with patient_id, image_{number}, caption, and other + task-specific attributes as key. Conversion of caption to + index/tensor will be done in the model. + """ + sample = self.samples[index] + for i in range(len(sample['image_path_list'])): + image_key = f'image_{i+1}' + image = Image.open(sample["image_path_list"][i]).convert("RGB") + image = self.img_transforms(image) + sample[image_key] = image + return sample + + def stat(self) -> str: + """Returns some statistics of the task-specific dataset.""" + lines = list() + lines.append(f"Statistics of sample dataset:") + lines.append(f"\t- Dataset: {self.dataset_name}") + lines.append(f"\t- Task: {self.task_name}") + lines.append(f"\t- Number of samples: {len(self)}") + num_patients = len(set([sample["patient_id"] for sample in self.samples])) + lines.append(f"\t- Number of patients: {num_patients}") + lines.append(f"\t- Number of records: {len(self)}") + lines.append( + f"\t- Number of samples per patient: {len(self) / num_patients:.4f}" + ) + print("\n".join(lines)) + return "\n".join(lines) if __name__ == "__main__": samples = [ diff --git a/pyhealth/metrics/__init__.py b/pyhealth/metrics/__init__.py index c07c63e4..15e5cb69 100644 --- a/pyhealth/metrics/__init__.py +++ b/pyhealth/metrics/__init__.py @@ -2,3 +2,4 @@ from .drug_recommendation import ddi_rate_score from .multiclass import multiclass_metrics_fn from .multilabel import multilabel_metrics_fn +from .sequence import sequence_metrics_fn diff --git a/pyhealth/metrics/sequence.py b/pyhealth/metrics/sequence.py new file mode 100644 index 00000000..c98559e5 --- /dev/null +++ b/pyhealth/metrics/sequence.py @@ -0,0 +1,54 @@ +from typing import Dict, List, Optional + +from collections import OrderedDict +from pycocoevalcap.bleu.bleu import Bleu +from pycocoevalcap.rouge.rouge import Rouge +from pycocoevalcap.cider.cider import Cider +from pycocoevalcap.meteor.meteor import Meteor + +def sequence_metrics_fn( + y_true: List[Dict[int,str]], + y_generated: List[Dict[int,str]], + metrics: Optional[List[str]] = None +) -> Dict[str, float]: + """Compute metrics relevant for evaluating sequences. + + User can specify which metrics to compute by passing a list of metric names + The accepted metric names are: + - Bleu_{n_grams}: BiLingual Evaluation Understudy. + Allowed n_grams = [1,2,3,4] + - METEOR: Metric for Evaluation of Translation with Explicit ORdering + - ROUGE: Recall-Oriented Understudy for Gisting Evaluation + - CIDEr: Consensus-based Image Description Evaluation + + All metrics compute a score for comparing a candidate text to one or more + reference text. 
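    An illustrative usage sketch (the inputs mirror the per-patient dicts
    produced by WordSAT/SentSAT; the caption strings below are placeholders):

    Examples:
        >>> y_true = {1: ["no acute cardiopulmonary abnormality ."]}
        >>> y_generated = {1: ["no acute abnormality ."]}
        >>> scores = sequence_metrics_fn(y_true, y_generated,
        ...                              metrics=["Bleu_1", "ROUGE"])
        >>> # scores maps each requested metric name to a float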
+ """ + scorers = [ + (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), + (Meteor(), "METEOR"), + (Rouge(), "ROUGE"), + (Cider(), "CIDER"), + ] + + allowed_metrics = ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4", + "METEOR","ROUGE","CIDER"] + if metrics: + for metric in metrics: + if metric not in allowed_metrics: + raise ValueError(f"Unknown metric for evaluation: {metric}") + else: + metrics = allowed_metrics + + output = {} + for scorer, method in scorers: + score, scores = scorer.compute_score(y_true, y_generated) + if type(score) == list: + for m, s in zip(method, score): + if m in metrics: + output[m] = s + else: + if method in metrics: + output[method] = score + + return output \ No newline at end of file diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index af8c8192..35662eb3 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -22,3 +22,5 @@ from .grasp import GRASP, GRASPLayer from .stagenet import StageNet, StageNetLayer from .tcn import TCN, TCNLayer +from .wordsat import WordSAT, WordSATEncoder,WordSATDecoder, WordSATAttention +from .sentsat import SentSAT \ No newline at end of file diff --git a/pyhealth/models/base_model.py b/pyhealth/models/base_model.py index a6d844ff..1f87d244 100644 --- a/pyhealth/models/base_model.py +++ b/pyhealth/models/base_model.py @@ -10,7 +10,7 @@ from pyhealth.tokenizer import Tokenizer # TODO: add support for regression -VALID_MODE = ["binary", "multiclass", "multilabel"] +VALID_MODE = ["binary", "multiclass", "multilabel", "sequence"] class BaseModel(ABC, nn.Module): @@ -22,7 +22,7 @@ class BaseModel(ABC, nn.Module): feature_keys: list of keys in samples to use as features, e.g. ["conditions", "procedures"]. label_key: key in samples to use as label (e.g., "drugs"). - mode: one of "binary", "multiclass", or "multilabel". + mode: one of "binary", "multiclass", "multilabel", or sequence. """ def __init__( @@ -31,6 +31,7 @@ def __init__( feature_keys: List[str], label_key: str, mode: str, + save_generated_caption: bool = False ): super(BaseModel, self).__init__() assert mode in VALID_MODE, f"mode must be one of {VALID_MODE}" @@ -38,6 +39,8 @@ def __init__( self.feature_keys = feature_keys self.label_key = label_key self.mode = mode + if mode == "sequence": + self.save_generated_caption = save_generated_caption # used to query the device of the model self._dummy_param = nn.Parameter(torch.empty(0)) return @@ -232,6 +235,7 @@ def get_loss_function(self) -> Callable: - binary: `F.binary_cross_entropy_with_logits` - multiclass: `F.cross_entropy` - multilabel: `F.binary_cross_entropy_with_logits` + - sequence: `F.cross_entropy` Returns: The default loss function. 
@@ -242,6 +246,8 @@ def get_loss_function(self) -> Callable: return F.cross_entropy elif self.mode == "multilabel": return F.binary_cross_entropy_with_logits + elif self.mode == "sequence": + return F.cross_entropy else: raise ValueError("Invalid mode: {}".format(self.mode)) diff --git a/pyhealth/models/sentsat.py b/pyhealth/models/sentsat.py new file mode 100644 index 00000000..61334a0f --- /dev/null +++ b/pyhealth/models/sentsat.py @@ -0,0 +1,520 @@ +from typing import List, Tuple, Dict, Optional + +import torch +import torch.nn as nn +from torchvision import models +import torch.nn.functional as F + +from pyhealth.datasets import SampleImageCaptionDataset +from pyhealth.models import BaseModel +from pyhealth.tokenizer import Tokenizer +from pyhealth.datasets.utils import flatten_list + +class SentSATEncoder(nn.Module): + """ SAT CNN(Densenet121) Encoder model""" + def __init__(self): + super().__init__() + self.densenet121 = models.densenet121(weights='DEFAULT') + self.densenet121.classifier = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward propagation. + Extract fixed-length feature vectors from the input image. + + Args: + x: A tensor of transfomed image of size + [batch_size,3,512,512] + Return: + x: A tensor of image feature vectors of size + [batch_size,1024,16,16] + """ + x = self.densenet121.features(x) + x = F.relu(x) + return x + +class SentSATAttention(nn.Module): + """SAT Attention Module + + Computes a set of attention weights based on the current hidden state of + the RNN and the feature vectors from the CNN, which are then used to + compute a weighted average of the feature vectors. + + Args: + k_size: key vector size + v_size: value vector size + affine_dim: affine dimension. Default is 512 + """ + def __init__( + self, + k_size: int, + v_size: int, + affine_dim: int =512): + super().__init__() + self.affine_k = nn.Linear(k_size, affine_dim, bias=False) + self.affine_v = nn.Linear(v_size, affine_dim, bias=False) + self.affine = nn.Linear(affine_dim, 1, bias=False) + + def forward( + self, + k: torch.Tensor, + v: torch.Tensor) -> (torch.Tensor,torch.Tensor): + """Forward propagation + + Args: + k: a tensor of size [batch_size, hidden_dim] + v: a tensor of size [batch_size, spatial_size, hidden_dim] + + Returns: + context: a tensor of size [batch_size, feature_dim] + alpha: a tensor of size [batch_size, spatial_size] + """ + content_v = self.affine_k(k).unsqueeze(1) + self.affine_v(v) + # z: batch size x spatial size + z = self.affine(torch.tanh(content_v)).squeeze(2) + alpha = torch.softmax(z, dim=1) + context = (v * alpha.unsqueeze(2)).sum(dim=1) + + return context, alpha + +class SentSATDecoder(nn.Module): + """ Sentence SAT decoder model that treats a caption as multiple sentences + + An LSTM based model that takes as input the attention-weighted feature + vector and generates a sequence of sentences and followed by words, in each + sentence + + Args: + attention: attention module instance + vocab_size: vocabulary size + n_encoder_inputs: number of image inputs given to the encoder + feature_dim: encoder output feature dimension + embedding_dim: decoder embedding dimension + hidden_dim: LSTM hidden dimension + dropout: dropout rate between [0,1] + """ + def __init__( + self, + attention: object, + vocab_size: int, + n_encoder_inputs: int, + feature_dim: int, + embedding_dim: int, + hidden_dim: int, + dropout: int = 0.5 + ): + super().__init__() + self.vocab_size = vocab_size + self.n_encoder_inputs = n_encoder_inputs + self.feature_dim 
= feature_dim + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + + self.attend = attention + self.embed = nn.Embedding(self.vocab_size, self.embedding_dim) + self.init_h = nn.Linear(self.n_encoder_inputs * self.feature_dim, + self.hidden_dim) + self.init_c = nn.Linear(self.n_encoder_inputs * self.feature_dim, + hidden_dim) + self.sent_lstm = nn.LSTMCell(self.n_encoder_inputs * self.feature_dim, + hidden_dim) + + self.word_lstm = nn.LSTMCell(self.embedding_dim + self.hidden_dim + + self.n_encoder_inputs * feature_dim, + hidden_dim) + self.fc = nn.Linear(self.hidden_dim, self.vocab_size) + self.dropout = nn.Dropout(dropout) + + def forward( + self, + cnn_features: List[torch.Tensor], + captions: List[torch.Tensor] = None, + update_masks: torch.Tensor = None, + max_sents: int = 10, + max_len: int = 30, + stop_id: int = None) -> torch.Tensor: + + """Forward propagation + + Args: + cnn_features: a list of tensors where each tensor is of + size [batch_size, feature_dim, spatial_size]. + captions: a list of tensors. + updat_masks: a boolean tensor to identify the actual tokens + max_sents: maximum number of sentences that can be generated + max_len: maximum length of training or generated caption + stop_id: token id from vocabulary to stop word generation for a + sentence during inference + + Returns: + logits: a tensor + + """ + batch_size = cnn_features[0].size(0) + if captions is not None: + num_sents = captions.size(1) + seq_len = captions.size(2) + else: + num_sents = max_sents + seq_len = max_len + + cnn_feats_t = [ cnn_feat.view(batch_size, self.feature_dim, -1) \ + .permute(0, 2, 1) + for cnn_feat in cnn_features ] + global_feats = [cnn_feat.mean(dim=(2, 3)) for cnn_feat in cnn_features] + + sent_h = self.init_h(torch.cat(global_feats, dim=1)) + sent_c = self.init_c(torch.cat(global_feats, dim=1)) + + word_h = cnn_features[0].new_zeros((batch_size, + self.hidden_dim),dtype=torch.float) + + word_c = cnn_features[0].new_zeros((batch_size, + self.hidden_dim),dtype=torch.float) + + logits = cnn_features[0].new_zeros((batch_size, + num_sents, + seq_len, + self.vocab_size),dtype=torch.float) + # Training phase + if captions is not None: + embeddings = self.embed(captions) + + for k in range(num_sents): + contexts = [self.attend(sent_h, cnn_feat_t)[0] + for cnn_feat_t in cnn_feats_t] + context = torch.cat(contexts, dim=1) + sent_h, sent_c = self.sent_lstm(context, (sent_h, sent_c)) + seq_len_k = update_masks[:, k].sum(dim=1).max().item() + + for t in range(seq_len_k): + batch_mask = update_masks[:, k, t] + + word_h_, word_c_ = self.word_lstm( + torch.cat((embeddings[batch_mask, k, t], + sent_h[batch_mask], + context[batch_mask]), dim=1), + (word_h[batch_mask], word_c[batch_mask])) + + indices = [*batch_mask.unsqueeze(1). 
\ + repeat(1, self.hidden_dim).nonzero().t()] + word_h = word_h.index_put(indices, word_h_.view(-1)) + word_c = word_c.index_put(indices, word_c_.view(-1)) + logits[batch_mask, k, t] = self.fc( + self.dropout(word_h[batch_mask])) + + return logits + + # Evaluation/Inference phase + else: + x_t = cnn_features[0].new_full((batch_size,), 1, dtype=torch.long) + + for k in range(num_sents): + contexts = [self.attend(sent_h, cnn_feat_t)[0] + for cnn_feat_t in cnn_feats_t] + context = torch.cat(contexts, dim=1) + sent_h, sent_c = self.sent_lstm(context, (sent_h, sent_c)) + + for t in range(seq_len): + embedding = self.embed(x_t) + word_h, word_c = self.word_lstm( + torch.cat((embedding, sent_h, context), dim=1), + (word_h, word_c)) + logit = self.fc(word_h) + x_t = logit.argmax(dim=1) + logits[:, k, t] = logit + + if x_t[0] == stop_id: + break + + return logits.argmax(dim=3) + +class SentSAT(BaseModel): + """Show Attend & Tell model, treating entire caption as one sentence. + This model is based on Show, Attend, and Tell (SAT) paper. The model uses + convolutional neural networks (CNNs) to encode the image and a + recurrent neural network (RNN) with an attention mechanism to generate + the corresponding caption. + + The model consists of three main components: + - Encoder: The encoder is a CNN that extracts a fixed-length feature + vectors from the input image. + - Attention Mechanism: The attention mechanism is used to select + relevant parts of the image at each time step of the RNN, to + generate the next word in the caption. + - Decoder: The decoder is a language model implemented as an RNN that + takes as input the attention-weighted feature vector and generates + a sequence of words, one at a time. + + Args: + dataset: the dataset to train the model. + n_input_images: number of images passed as input to each sample. + label_key: key in the samples to use as label (e.g., "caption"). + tokenizer: pyhealth tokenizer instance created using sample texts. + encoder_pretrained_weights: pretrained state dictionary for encoder. + Default is None. + encoder_freeze_weights: freeze encoder weights so that they are not + updated during training. This is useful when the encoder is trained + separately as a classifier. Default is True. + decoder_embed_dim: decoder embedding dimesion. Default is 256. + decoder_hidden_dim: decoder hidden state dimension. Default is 512. + decoder_feaure_dim: decoder input cell state dimension. + Default is 1024 + decoder_dropout: decoder dropout rate between [0,1]. Default is 0.5 + attention_affine_dim: output dimension of affine layer in attention. + Default is 512. + save_generated_caption: save the generated caption during training. + This is used for evaluating the quality of generated captions. + Default is False. 
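    Examples:
        An illustrative construction sketch; `sample_dataset` and `tokenizer`
        are assumed to come from the pipeline shown in
        examples/xray_report_generation_sat.py:

        >>> model = SentSAT(
        ...     dataset=sample_dataset,
        ...     n_input_images=2,              # frontal + lateral view
        ...     label_key="caption",
        ...     tokenizer=tokenizer,
        ...     encoder_pretrained_weights=None,
        ...     encoder_freeze_weights=True,
        ...     save_generated_caption=True,
        ... )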
+ """ + + def __init__( + self, + dataset: SampleImageCaptionDataset, + n_input_images: int, + label_key: str, + tokenizer: Tokenizer, + encoder_pretrained_weights: Dict[str,float] = None, + encoder_freeze_weights: bool = True, + decoder_embed_dim: int = 256, + decoder_hidden_dim: int = 512, + decoder_feature_dim: int = 1024, + decoder_dropout: float = 0.5, + attention_affine_dim: int = 512, + save_generated_caption: bool = False, + **kwargs + ): + super(SentSAT, self).__init__( + dataset=dataset, + feature_keys=[f'image_{i+1}' for i in range(n_input_images)], + label_key=label_key, + mode="sequence", + save_generated_caption = save_generated_caption + ) + self.n_input_images = n_input_images + + # Encoder component + self.encoder = SentSATEncoder() + if encoder_pretrained_weights: + print(f'Loading encoder pretrained model') + self.encoder.load_state_dict(encoder_pretrained_weights) + if encoder_freeze_weights: + self.encoder.eval() + + # Attention component + self.attention = SentSATAttention(decoder_hidden_dim, + decoder_feature_dim, + attention_affine_dim) + + # Decoder component + self.caption_tokenizer = tokenizer + vocab_size = self.caption_tokenizer.get_vocabulary_size() + self.decoder = SentSATDecoder( self.attention, + vocab_size, + n_input_images, + decoder_feature_dim, + decoder_embed_dim, + decoder_hidden_dim, + decoder_dropout + ) + + def forward( + self, + decoder_maxsents: int =10, + decoder_maxlen:int = 20, + decoder_stop_id: int = None, + **kwargs) -> Dict[str,str]: + """Forward propagation. + + The features `kwargs[self.feature_keys]` is a list of feature keys + associated to every input image. + + The label `kwargs[self.label_key]` is a key of the report caption + for each patient. + + Args: + **kwargs: keyword arguments for the model. The keys must contain + all the feature keys and the label key. + + Returns: + A dictionary with the following keys: + loss: a scalar tensor representing the loss. + y_generated: a dictionary with following key + - "patient_id": list of text representing the generated + text. + e.g. + {123: ["generated text"], 456: ["generated text"]} + y_true: a dictionary with following key + - "patient_id": list of text representing the true text. + e.g. + {123: ["true text"], 456: ["true text"]} + """ + # Initialize the output + output = {"loss": None,"y_generated": [""],"y_true": [""]} + + # Get list of patient_ids + patient_ids = kwargs["patient_id"] + + # Get CNN features + images = self._prepare_batch_images(kwargs) + cnn_features = [self.encoder(image.to(self.device)) + for image in images] + + if self.training: + # Get caption as indicies and corresponding masks + captions, loss_masks, update_masks=self._prepare_batch_captions( + kwargs[self.label_key]) + + # Perform predictions + logits = self.decoder(cnn_features, + captions[:, :, :-1], + update_masks, + decoder_maxsents, + decoder_maxlen, + decoder_stop_id) + + logits = logits.permute(0, 3, 1, 2).contiguous() + captions = captions[:, :, 1:].contiguous() + loss_masks = loss_masks[:, :, 1:].contiguous() + + # Compute loss + loss = self.get_loss_function()(logits, captions) + loss = loss.masked_select(loss_masks).mean() + output["loss"] = loss + + with torch.no_grad(): + output["y_generated"] = self._forward_inference(patient_ids, + decoder_maxsents, + decoder_maxlen, + cnn_features, + decoder_stop_id) + output["y_true"] = self._forward_ground_truths(patient_ids, + kwargs[self.label_key]) + return output + + def _prepare_batch_images(self,kwargs): + """Prepare images for input. 
+ Args: + kwargs: keyword arguments for the model. + Returns: + images: a list of input images represented as tensors. Every tensor + in the list has shape [batch_size,3,image_size,image_size] + """ + images = [torch.stack(kwargs[image_feature], 0) + for image_feature in self.feature_keys] + + return images + + def _prepare_batch_captions( + self, + captions:List[List[str]] + ) -> (torch.Tensor,torch.Tensor): + """Prepare caption idx for input. + + Args: + captions: list of captions. Each caption is a list of list, where + each list represents a sentence in the caption. + Following is an example of a caption + [ + ["","first","sentence","."], + [".", "second", "sentence",".", ""] + ] + Returns: + captions_idx: an int tensor + loss_masks: a bool tensor for each sentence in a caption + update_masks: a bool tensor for each sentence in a caption + """ + x = self.caption_tokenizer.batch_encode_3d(captions) + captions_idx = torch.tensor(x, dtype=torch.long, + device=self.device) + + loss_masks = torch.zeros_like(captions_idx,dtype=torch.bool) + update_masks = torch.zeros_like(captions_idx,dtype=torch.bool) + + for icap, cap in enumerate(captions_idx): + for isent, sent in enumerate(cap): + l = len(sent[sent !=0]) + if l==0: continue + loss_masks[icap, isent, 1:l].fill_(1) + update_masks[icap, isent, :l-1].fill_(1) + + return captions_idx, loss_masks, update_masks + + def _forward_inference( + self, + patient_ids: List[int], + decoder_maxsents: int, + decoder_maxlen: int, + cnn_features:List[torch.Tensor], + decoder_stop_id: int = None) -> Dict[int,str]: + """Forward propagation during inference + + Args: + patient_ids: a list of patient ids + decoder_maxsents: maximum number of sentences that can be generated + decoder_maxlen: maximum length of words in a every sentence of a + caption + cnn_features: a list of tensors + stop_id: token id from vocabulary to stop word generation for a + sentence + + Returns: + generated_results: a dict with following keys + - patient_id: int + - generated_results: List[str] + """ + generated_results = {} + for idx, patient_id in enumerate(patient_ids): + generated_results[patient_id] = [""] + cnn_feature = [cnn_feat[idx].unsqueeze(0) + for cnn_feat in cnn_features] + pred = self.decoder(cnn_feature, None, None, + decoder_maxsents,decoder_maxlen, + decoder_stop_id)[0] + pred = pred.detach().cpu() + generated_results[patient_id] = [""] + for isent in range(pred.size(0)): + pred_tokens = self.caption_tokenizer \ + .convert_indices_to_tokens(pred[isent].tolist()) + + words = [] + for token in pred_tokens: + if token == '' or token == '': + continue + if token == '': + break + words.append(token) + if len(words) < 2: + continue + generated_results[patient_id][0] += " ".join(words) + generated_results[patient_id][0] += " " + generated_results[patient_id][0] = generated_results[patient_id][0].replace(". .",'.') + + return generated_results + + def _forward_ground_truths( + self, + patient_ids: List[int], + captions:List[List[str]]) -> Dict[int,str]: + """Forward propagation for ground truth + + Args: + patient_ids: a list of patient ids + cnn_features: a list of tensors + + Returns: + ground_results: a dict with following keys + - patient_id: int + - generated_results: List[str] + """ + ground_truths = {} + for idx, caption in enumerate(captions): + ground_truths[patient_ids[idx]] = [""] + tokens = [] + tokens.extend(flatten_list(caption)) + ground_truths[patient_ids[idx]][0] = ' '.join(tokens) \ + .replace('. 
.','.') \ + .replace("","") \ + .replace("","") \ + .strip() + + return ground_truths \ No newline at end of file diff --git a/pyhealth/models/wordsat.py b/pyhealth/models/wordsat.py new file mode 100644 index 00000000..3ac42b4f --- /dev/null +++ b/pyhealth/models/wordsat.py @@ -0,0 +1,457 @@ +from typing import List, Tuple, Dict, Optional + +import torch +import torch.nn as nn +from torchvision import models +import torch.nn.functional as F + +from pyhealth.datasets import SampleImageCaptionDataset +from pyhealth.models import BaseModel +from pyhealth.tokenizer import Tokenizer +from pyhealth.datasets.utils import flatten_list + +class WordSATEncoder(nn.Module): + """ SAT CNN(Densenet121) Encoder model""" + def __init__(self): + super().__init__() + self.densenet121 = models.densenet121(weights='DEFAULT') + self.densenet121.classifier = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward propagation. + Extract fixed-length feature vectors from the input image. + + Args: + x: A tensor of transfomed image of size + [batch_size,3,512,512] + Return: + x: A tensor of image feature vectors of size + [batch_size,1024,16,16] + """ + x = self.densenet121.features(x) + x = F.relu(x) + return x + +class WordSATAttention(nn.Module): + """SAT Attention Module + + Computes a set of attention weights based on the current hidden state of + the RNN and the feature vectors from the CNN, which are then used to + compute a weighted average of the feature vectors. + + Args: + k_size: key vector size + v_size: value vector size + affine_dim: affine dimension. Default is 512 + """ + def __init__( + self, + k_size: int, + v_size: int, + affine_dim: int =512): + super().__init__() + self.affine_k = nn.Linear(k_size, affine_dim, bias=False) + self.affine_v = nn.Linear(v_size, affine_dim, bias=False) + self.affine = nn.Linear(affine_dim, 1, bias=False) + + def forward( + self, + k: torch.Tensor, + v: torch.Tensor) -> (torch.Tensor,torch.Tensor): + """Forward propagation + + Args: + k: a tensor of size [batch_size, hidden_dim] + v: a tensor of size [batch_size, spatial_size, hidden_dim] + + Returns: + context: a tensor of size [batch_size, feature_dim] + alpha: a tensor of size [batch_size, spatial_size] + """ + content_v = self.affine_k(k).unsqueeze(1) + self.affine_v(v) + # z: batch size x spatial size + z = self.affine(torch.tanh(content_v)).squeeze(2) + alpha = torch.softmax(z, dim=1) + context = (v * alpha.unsqueeze(2)).sum(dim=1) + + return context, alpha + +class WordSATDecoder(nn.Module): + """ Word SAT decoder model for one sentence + + An LSTM based model that takes as input the attention-weighted feature + vector and generates a sequence of words, one at a time. 
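    An illustrative shape walk-through (assuming the 512x512 inputs and
    DenseNet-121 encoder used elsewhere in this diff):
        cnn_features: list of n_encoder_inputs tensors, each
            [batch_size, 1024, 16, 16]
        captions (training only): [batch_size, seq_len] token indices
        returns: [batch_size, seq_len, vocab_size] logits during training;
            [batch_size, max_len] argmax token indices at inference.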
+ + Args: + attention: attention module instance + vocab_size: vocabulary size + n_encoder_inputs: number of image inputs given to the encoder + feature_dim: encoder output feature dimension + embedding_dim: decoder embedding dimension + hidden_dim: LSTM hidden dimension + dropout: dropout rate between [0,1] + """ + def __init__( + self, + attention: object, + vocab_size: int, + n_encoder_inputs: int, + feature_dim: int, + embedding_dim: int, + hidden_dim: int, + dropout: int = 0.5 + ): + super().__init__() + self.vocab_size = vocab_size + self.n_encoder_inputs = n_encoder_inputs + self.feature_dim = feature_dim + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + + self.attend = attention + self.embed = nn.Embedding(self.vocab_size, self.embedding_dim) + self.init_h = nn.Linear(self.n_encoder_inputs * self.feature_dim, + self.hidden_dim) + self.init_c = nn.Linear(self.n_encoder_inputs * feature_dim, + hidden_dim) + self.lstmcell = nn.LSTMCell(self.embedding_dim + + self.n_encoder_inputs * feature_dim, + hidden_dim) + self.fc = nn.Linear(self.hidden_dim, self.vocab_size) + self.dropout = nn.Dropout(dropout) + + def forward( + self, + cnn_features: List[torch.Tensor], + captions: List[torch.Tensor] = None, + max_len: int = 100) -> torch.Tensor: + + """Forward propagation + + Args: + cnn_features: a list of tensors where each tensor is of + size [batch_size, feature_dim, spatial_size]. + captions: a list of tensors. + max_len: maximum length of training or generated caption + + Returns: + logits: a tensor + + """ + batch_size = cnn_features[0].size(0) + if captions is not None: + seq_len = captions.size(1) + else: + seq_len = max_len + + cnn_feats_t = [ cnn_feat.view(batch_size, self.feature_dim, -1) \ + .permute(0, 2, 1) + for cnn_feat in cnn_features ] + global_feats = [cnn_feat.mean(dim=(2, 3)) for cnn_feat in cnn_features] + + h = self.init_h(torch.cat(global_feats, dim=1)) + c = self.init_c(torch.cat(global_feats, dim=1)) + + logits = cnn_features[0].new_zeros((batch_size, + seq_len, + self.vocab_size),dtype=torch.float) + # Training phase + if captions is not None: + embeddings = self.embed(captions) + for t in range(seq_len): + contexts = [self.attend(h, cnn_feat_t)[0] + for cnn_feat_t in cnn_feats_t] + context = torch.cat(contexts, dim=1) + h, c = self.lstmcell(torch.cat((embeddings[:, t], context), + dim=1), + (h, c)) + logits[:, t] = self.fc(self.dropout(h)) + + return logits + + # Evaluation/Inference phase + else: + x_t = cnn_features[0].new_full((batch_size,), 1, dtype=torch.long) + for t in range(seq_len): + embedding = self.embed(x_t) + contexts = [self.attend(h, cnn_feat_t)[0] + for cnn_feat_t in cnn_feats_t] + context = torch.cat(contexts, dim=1) + h, c = self.lstmcell(torch.cat((embedding, context), dim=1), + (h, c)) + logit = self.fc(h) + x_t = logit.argmax(dim=1) + logits[:, t] = logit + + return logits.argmax(dim=2) + +class WordSAT(BaseModel): + """Show Attend & Tell model, treating entire caption as one sentence. + This model is based on Show, Attend, and Tell (SAT) paper. The model uses + convolutional neural networks (CNNs) to encode the image and a + recurrent neural network (RNN) with an attention mechanism to generate + the corresponding caption. + + The model consists of three main components: + - Encoder: The encoder is a CNN that extracts a fixed-length feature + vectors from the input image. 
+ - Attention Mechanism: The attention mechanism is used to select + relevant parts of the image at each time step of the RNN, to + generate the next word in the caption. + - Decoder: The decoder is a language model implemented as an RNN that + takes as input the attention-weighted feature vector and generates + a sequence of words, one at a time. + + Args: + dataset: the dataset to train the model. + n_input_images: number of images passed as input to each sample. + label_key: key in the samples to use as label (e.g., "caption"). + tokenizer: pyhealth tokenizer instance created using sample texts. + encoder_pretrained_weights: pretrained state dictionary for encoder. + Default is None. + encoder_freeze_weights: freeze encoder weights so that they are not + updated during training. This is useful when the encoder is trained + separately as a classifier. Default is True. + decoder_embed_dim: decoder embedding dimesion. Default is 256. + decoder_hidden_dim: decoder hidden state dimension. Default is 512. + decoder_feaure_dim: decoder input cell state dimension. + Default is 1024 + decoder_dropout: decoder dropout rate between [0,1]. Default is 0.5 + attention_affine_dim: output dimension of affine layer in attention. + Default is 512. + save_generated_caption: save the generated caption during training. + This is used for evaluating the quality of generated captions. + Default is False. + """ + + def __init__( + self, + dataset: SampleImageCaptionDataset, + n_input_images: int, + label_key: str, + tokenizer: Tokenizer, + encoder_pretrained_weights: Dict[str,float] = None, + encoder_freeze_weights: bool = True, + decoder_embed_dim: int = 256, + decoder_hidden_dim: int = 512, + decoder_feature_dim: int = 1024, + decoder_dropout: float = 0.5, + attention_affine_dim: int = 512, + save_generated_caption: bool = False, + **kwargs + ): + super(WordSAT, self).__init__( + dataset=dataset, + feature_keys=[f'image_{i+1}' for i in range(n_input_images)], + label_key=label_key, + mode="sequence", + save_generated_caption = save_generated_caption + ) + self.n_input_images = n_input_images + self.save_generated_caption = save_generated_caption + + # Encoder component + self.encoder = WordSATEncoder() + if encoder_pretrained_weights: + print(f'Loading encoder pretrained model') + self.encoder.load_state_dict(encoder_pretrained_weights) + if encoder_freeze_weights: + self.encoder.eval() + + # Attention component + self.attention = WordSATAttention(decoder_hidden_dim, + decoder_feature_dim, + attention_affine_dim) + + # Decoder component + self.caption_tokenizer = tokenizer + vocab_size = self.caption_tokenizer.get_vocabulary_size() + self.decoder = WordSATDecoder( self.attention, + vocab_size, + n_input_images, + decoder_feature_dim, + decoder_embed_dim, + decoder_hidden_dim, + decoder_dropout + ) + + def forward(self, decoder_maxlen:int = 100, **kwargs) -> Dict[str,str]: + """Forward propagation. + + The features `kwargs[self.feature_keys]` is a list of feature keys + associated to every input image. + + The label `kwargs[self.label_key]` is a key of the report caption + for each patient. + + Args: + decder_maxlen: maximum caption length used during training or + generated during inference. Default is 100. + **kwargs: keyword arguments for the model. The keys must contain + all the feature keys and the label key. + + Returns: + A dictionary with the following keys: + loss: a scalar tensor representing the loss. 
+ y_generated: a dictionary with following key + - "patient_id": list of text representing the generated + text. + e.g. + {123: ["generated text"], 456: ["generated text"]} + y_true: a dictionary with following key + - "patient_id": list of text representing the true text. + e.g. + {123: ["true text"], 456: ["true text"]} + """ + # Initialize the output + output = {"loss": None,"y_generated": [""],"y_true": [""]} + + # Get list of patient_ids + patient_ids = kwargs["patient_id"] + + # Get CNN features + images = self._prepare_batch_images(kwargs) + cnn_features = [self.encoder(image.to(self.device)) + for image in images] + + if self.training: + # Get caption as indicies and corresponding masks + captions,masks=self._prepare_batch_captions(kwargs[self.label_key]) + + # Perform predictions + logits = self.decoder(cnn_features, captions[:,:-1], + decoder_maxlen) + logits = logits.permute(0, 2, 1).contiguous() + captions = captions[:, 1:].contiguous() + masks = masks[:, 1:].contiguous() + + # Compute loss + loss = self.get_loss_function()(logits, captions) + loss = loss.masked_select(masks).mean() + output["loss"] = loss + + with torch.no_grad(): + output["y_generated"] = self._forward_inference(patient_ids, + decoder_maxlen, + cnn_features) + output["y_true"] = self._forward_ground_truths(patient_ids, + kwargs[self.label_key]) + return output + + def _prepare_batch_images(self,kwargs): + """Prepare images for input. + Args: + kwargs: keyword arguments for the model. + Returns: + images: a list of input images represented as tensors. Every tensor + in the list has shape [batch_size,3,image_size,image_size] + """ + images = [torch.stack(kwargs[image_feature], 0) + for image_feature in self.feature_keys] + + return images + + def _prepare_batch_captions( + self, + captions:List[List[str]] + ) -> (torch.Tensor,torch.Tensor): + """Prepare caption idx for input. + + Args: + captions: list of captions. Each caption is a list of list, where + each list represents a sentence in the caption. + Following is an example of a caption + [ + ["","first","sentence","."], + [".", "second", "sentence",".", ""] + ] + Returns: + captions_idx: an int tensor + masks: a bool tensor + """ + + # Combine all sentences in each caption to create a single sentence + samples = [] + for caption in captions: + tokens = [] + tokens.extend(flatten_list(caption)) + text = ' '.join(tokens).replace('. 
.','.') + samples.append([text.split()]) + + x = self.caption_tokenizer.batch_encode_3d(samples) + captions_idx = torch.tensor(x, dtype=torch.long, + device=self.device) + masks = torch.sum(captions_idx,dim=1) !=0 + captions_idx = captions_idx.squeeze(1) + + return captions_idx,masks + + def _forward_inference( + self, + patient_ids: List[int], + decoder_maxlen: int, + cnn_features:List[torch.Tensor]) -> Dict[int,str]: + """Forward propagation during inference + + Args: + patient_ids: a list of patient ids + decoder_maxlen: maximum length of generated caption + cnn_features: a list of tensors + + Returns: + generated_results: a dict with following keys + - patient_id: int + - generated_results: List[str] + """ + generated_results = {} + for idx, patient_id in enumerate(patient_ids): + generated_results[patient_id] = [""] + cnn_feature = [cnn_feat[idx].unsqueeze(0) + for cnn_feat in cnn_features] + pred = self.decoder(cnn_feature, None, decoder_maxlen)[0] + pred = pred.detach().cpu() + pred_tokens = self.caption_tokenizer \ + .convert_indices_to_tokens(pred.tolist()) + generated_results[patient_id] = [""] + words = [] + for token in pred_tokens: + if token == '' or token == '': + continue + if token == '': + break + words.append(token) + + generated_results[patient_id][0] = " ".join(words) + + return generated_results + + def _forward_ground_truths( + self, + patient_ids: List[int], + captions:List[List[str]]) -> Dict[int,str]: + """Forward propagation for ground truth + + Args: + patient_ids: a list of patient ids + cnn_features: a list of tensors + + Returns: + ground_results: a dict with following keys + - patient_id: int + - generated_results: List[str] + """ + ground_truths = {} + for idx, caption in enumerate(captions): + ground_truths[patient_ids[idx]] = [""] + tokens = [] + tokens.extend(flatten_list(caption)) + ground_truths[patient_ids[idx]][0] = ' '.join(tokens) \ + .replace('. .','.') \ + .replace("","") \ + .replace("","") \ + .strip() + + return ground_truths + + + diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 6d7ffb02..e98057c4 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -26,3 +26,6 @@ sleep_staging_sleepedf_fn, sleep_staging_isruc_fn ) +from .xray_report_generation import ( + biview_multisent_fn +) diff --git a/pyhealth/tasks/xray_report_generation.py b/pyhealth/tasks/xray_report_generation.py new file mode 100644 index 00000000..d1fcc6fa --- /dev/null +++ b/pyhealth/tasks/xray_report_generation.py @@ -0,0 +1,69 @@ +import os +import string + +def biview_multisent_fn(patient): + """ Processes single patient for X-ray report generation task + + Xray report generation aims a automatically generating the diagnosis report + from X-ray images taken from two viewpoints namely - frontal and lateral. 
+ + An X-ray report generally consists of following elements + - History, describing the reason for the X-ray + - Technique, describing the type of X-ray performed - frontal,lateral + - Findings, describing the observations made by the radiologist who + interpreted the X-ray images + - Impression, describing the overall interpretation of X-ray results + + Args: + patient: a list of dictionary of patient X-ray report with below keys + - patient_id: type(int) + unique identifier for patient + - frontal_img_path: type(str) + path to frontal X-ray image + - lateral_img_path: type(str) + path to lateral X-ray image + - findings: type(str) + text of X-ray report findings + - impression: type(str) + text of X-ray report impression + + Returns: + sample: a list of one sample, each sample is a dict with following keys + - patient_id: type(int) + unique identifier for patient + - image_path_list: type(List) + list of frontal and lateral image paths + - caption: type(List[List]) + nested list of sentences,where each inner list represents a + single sentence in the X-ray report text(formed by + concatenating impression and findings). + + Note: special tokens "" and "", are added to the begining of + first and end of last sentence. These are mandatory for the task. + """ + sample = {} + patient = patient[0] + + report = f"{patient['impression']} . {patient['findings']}" + caption = [] + sents = report.lower().split(".") + sents = [sent for sent in sents if len(sent.strip()) > 1] + + for isent, sent in enumerate(sents): + tokens = sent.translate(str.maketrans("", "", string.punctuation)) \ + .strip() \ + .split() + caption.append([".", *[token for token in tokens],"."]) + + if caption == []: + caption = [["",""]] + else: + caption[0][0] = "" + caption[-1].append("") + + sample["patient_id"] = int(patient["patient_id"]) + sample["image_path_list"] = [ patient["frontal_img_path"], + patient["lateral_img_path"], + ] + sample["caption"] = caption + return [sample] \ No newline at end of file diff --git a/pyhealth/trainer.py b/pyhealth/trainer.py index fd290fc0..84579bf4 100644 --- a/pyhealth/trainer.py +++ b/pyhealth/trainer.py @@ -12,7 +12,7 @@ from tqdm.autonotebook import trange from pyhealth.metrics import (binary_metrics_fn, multiclass_metrics_fn, - multilabel_metrics_fn) + multilabel_metrics_fn,sequence_metrics_fn) from pyhealth.utils import create_directory logger = logging.getLogger(__name__) @@ -44,6 +44,8 @@ def get_metrics_fn(mode: str) -> Callable: return multiclass_metrics_fn elif mode == "multilabel": return multilabel_metrics_fn + elif mode == "sequence": + return sequence_metrics_fn else: raise ValueError(f"Mode {mode} is not supported") @@ -176,6 +178,7 @@ def train( # epoch training loop for epoch in range(epochs): + self.current_epoch = epoch training_loss = [] self.model.zero_grad() self.model.train() @@ -257,6 +260,8 @@ def inference(self, dataloader, additional_outputs=None) -> Dict[str, float]: loss_mean: Mean loss over batches. additional_outputs (only if requested): Dict of additional results. 
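            For models with mode == "sequence", this method delegates to
            `inference_sequence` (added below) and returns a tuple
            (y_true_all, y_pred_all, 0), where the first two items are dicts
            keyed by patient_id holding the reference and generated captions.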
""" + if self.model.mode == "sequence": + return self.inference_sequence(dataloader) loss_all = [] y_true_all = [] y_prob_all = [] @@ -298,7 +303,8 @@ def evaluate(self, dataloader) -> Dict[str, float]: mode = self.model.mode metrics_fn = get_metrics_fn(mode) scores = metrics_fn(y_true_all, y_prob_all, metrics=self.metrics) - scores["loss"] = loss_mean + if mode != "sequence": + scores["loss"] = loss_mean return scores def save_ckpt(self, ckpt_path: str) -> None: @@ -313,6 +319,34 @@ def load_ckpt(self, ckpt_path: str) -> None: self.model.load_state_dict(state_dict) return + def inference_sequence(self, dataloader) -> Dict[int, str]: + """Model inference for sequences + """ + y_true_all = {} + y_pred_all = {} + for data in tqdm(dataloader, desc="Evaluation"): + self.model.eval() + with torch.no_grad(): + output = self.model(**data) + y_true = output["y_true"] + y_generated = output["y_generated"] + for key in y_generated.keys(): + y_true_all[key] = y_true[key] + y_pred_all[key] = y_generated[key] + + if self.model.save_generated_caption: + fname = datetime.now().strftime("%Y%m%d-%H%M%S") + with open(os.path.join(self.exp_path, + f"{fname}_gen{self.current_epoch}.csv"),"w") as f1: + with open(os.path.join(self.exp_path, + f"{fname}_gts.csv"), "w") as f2: + for patient_id in y_pred_all.keys(): + f1.write(y_pred_all[patient_id][0] + '\n') + f2.write(y_true_all[patient_id][0] + '\n') + + + return y_true_all, y_pred_all, 0 + if __name__ == "__main__": import torch