Commit
added example code for xray report generation
samarthkeshari committed Apr 27, 2023
1 parent d4aad7f commit 6e1db61
Showing 2 changed files with 359 additions and 1 deletion.
178 changes: 178 additions & 0 deletions examples/xray_report_generation_sat.py
@@ -0,0 +1,178 @@
import os
import argparse
import pickle
import torch
from collections import OrderedDict
from torchvision import transforms
from pyhealth.datasets import BaseImageCaptionDataset
from pyhealth.tasks.xray_report_generation import biview_multisent_fn
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.tokenizer import Tokenizer
from pyhealth.models import WordSAT, SentSAT
from pyhealth.trainer import Trainer

def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--root', type=str, default=".")
    parser.add_argument('--encoder-chkpt-fname', type=str, default=None)
    parser.add_argument('--tokenizer-fname', type=str, default=None)
    parser.add_argument('--num-epochs', type=int, default=1)
    parser.add_argument('--model-type', type=str, default="wordsat")

    args = parser.parse_args()
    return args

def seed_everything(seed: int):
    """Seed every RNG in use so runs are reproducible."""
    import random
    import numpy as np

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark must be disabled for cuDNN to behave deterministically
    torch.backends.cudnn.benchmark = False


# STEP 1: load data
def load_data(root):
    base_dataset = BaseImageCaptionDataset(root=root, dataset_name='IU_XRay')
    return base_dataset

# STEP 2: set task
def set_task(base_dataset):
    sample_dataset = base_dataset.set_task(biview_multisent_fn)
    transform = transforms.Compose([
        transforms.RandomAffine(degrees=30),
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    sample_dataset.set_transform(transform)
    return sample_dataset
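
# NOTE: a single transform is attached to the dataset before it is split, so
# the RandomAffine augmentation also runs at validation and test time; a
# separate deterministic transform for evaluation may be preferable.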

# STEP 3: get dataloaders
def get_dataloaders(sample_dataset):
    train_dataset, val_dataset, test_dataset = split_by_patient(
        sample_dataset, [0.8, 0.1, 0.1])

    train_dataloader = get_dataloader(train_dataset, batch_size=8, shuffle=True)
    val_dataloader = get_dataloader(val_dataset, batch_size=1, shuffle=False)
    test_dataloader = get_dataloader(test_dataset, batch_size=1, shuffle=False)

    return train_dataloader, val_dataloader, test_dataloader
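
# NOTE: split_by_patient assigns all samples from a given patient to a single
# split, so reports from one patient never leak between train, val and test.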

# STEP 4: get tokenizer
def get_tokenizer(root, sample_dataset=None, tokenizer_fname=None):
    if tokenizer_fname:
        # Load a previously pickled tokenizer ('rb', not 'wb': we are reading).
        with open(os.path.join(root, tokenizer_fname), 'rb') as f:
            tokenizer = pickle.load(f)
    else:
        # <pad> should always be the first element in the list of special tokens
        special_tokens = ['<pad>', '<start>', '<end>', '<unk>']
        tokenizer = Tokenizer(
            sample_dataset.get_all_tokens(key='caption'),
            special_tokens=special_tokens,
        )
    return tokenizer
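
# Sketch (an assumption, mirroring the pickle pattern used in test.py below):
# to reuse a fitted tokenizer across runs via --tokenizer-fname, pickle it once:
#     with open(os.path.join(root, 'pyhealth_tokenizer.pkl'), 'wb') as f:
#         pickle.dump(tokenizer, f)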

# STEP 5: get encoder pretrained state dictionary
def extract_encoder_state_dict(root, chkpt_fname):
    checkpoint = torch.load(os.path.join(root, chkpt_fname))
    state_dict = OrderedDict()
    for k, v in checkpoint['state_dict'].items():
        # Drop the classification head; only the feature encoder is reused.
        if 'classifier' in k:
            continue
        # Strip the 'module.' prefix that nn.DataParallel prepends to keys.
        name = k[7:] if k[:7] == 'module.' else k
        state_dict[name] = v
    return state_dict
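
# Example of the remapping above on hypothetical checkpoint keys:
#     'module.features.conv0.weight' -> 'features.conv0.weight'  (kept)
#     'module.classifier.0.weight'   -> dropped (classification head)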

# STEP 6: define model
def define_model(
        sample_dataset,
        tokenizer,
        encoder_weights,
        model_type='wordsat'):

    # Any model_type other than 'wordsat' falls back to SentSAT.
    if model_type == 'wordsat':
        model = WordSAT(
            dataset=sample_dataset,
            n_input_images=2,
            label_key='caption',
            tokenizer=tokenizer,
            encoder_pretrained_weights=encoder_weights,
            encoder_freeze_weights=True,
            save_generated_caption=True,
        )
    else:
        model = SentSAT(
            dataset=sample_dataset,
            n_input_images=2,
            label_key='caption',
            tokenizer=tokenizer,
            encoder_pretrained_weights=encoder_weights,
            encoder_freeze_weights=True,
            save_generated_caption=True,
        )

    return model

# STEP 7: run trainer
def run_trainer(output_path, train_dataloader, val_dataloader, model, n_epochs):
    trainer = Trainer(
        model=model,
        output_path=output_path,
    )

    trainer.train(
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        optimizer_params={"lr": 1e-4},
        weight_decay=1e-5,
        epochs=n_epochs,
        monitor='Bleu_1',
    )
    return trainer

# STEP 8: evaluate
def evaluate(trainer, test_dataloader):
    print(trainer.evaluate(test_dataloader))

if __name__ == '__main__':
    args = get_args()
    seed_everything(42)
    base_dataset = load_data(args.root)
    sample_dataset = set_task(base_dataset)
    train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
        sample_dataset)
    tokenizer = get_tokenizer(args.root, sample_dataset, args.tokenizer_fname)

    # Only load pretrained weights when a checkpoint is supplied; this assumes
    # WordSAT/SentSAT accept encoder_pretrained_weights=None and then fall
    # back to their default initialization.
    encoder_weights = None
    if args.encoder_chkpt_fname:
        encoder_weights = extract_encoder_state_dict(args.root,
                                                     args.encoder_chkpt_fname)

    model = define_model(
        sample_dataset,
        tokenizer,
        encoder_weights,
        args.model_type)

    trainer = run_trainer(
        args.root,
        train_dataloader,
        val_dataloader,
        model,
        args.num_epochs)

    print("\n===== Evaluating Test Data ======\n")
    evaluate(trainer, test_dataloader)
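
For reference, the script can be launched roughly as follows; the checkpoint
and tokenizer filenames are placeholders, not files shipped with this commit:

    python examples/xray_report_generation_sat.py \
        --root /path/to/IU_XRay \
        --encoder-chkpt-fname encoder_checkpoint.pth.tar \
        --tokenizer-fname pyhealth_tokenizer.pkl \
        --num-epochs 5 \
        --model-type sentsat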



182 changes: 181 additions & 1 deletion test.py
@@ -1,3 +1,4 @@
"""
import pickle
from torchvision import transforms
from pyhealth.datasets import BaseImageCaptionDataset
@@ -46,6 +47,7 @@ def seed_everything(seed: int):
root = '/home/keshari2/ChestXrayReporting/IU_XRay/src/data'
sample_dataset = BaseImageCaptionDataset(root=root,dataset_name='IU_XRay')
sample_dataset = sample_dataset.set_task(biview_multisent_fn)
transform = transforms.Compose([
transforms.RandomAffine(degrees=30),
transforms.Resize((512,512)),
@@ -55,6 +57,7 @@ def seed_everything(seed: int):
])
sample_dataset.set_transform(transform)
"""
"""
special_tokens = ['<pad>','<start>','<end>','<unk>']
tokenizer = Tokenizer(
@@ -65,6 +68,7 @@ def seed_everything(seed: int):
with open(root+'/pyhealth_tokenizer.pkl', 'wb') as f:
pickle.dump(tokenizer, f)
"""
"""
with open(root+'/pyhealth_tokenizer.pkl', 'rb') as f:
tokenizer = pickle.load(f)
#print(sample_dataset[0]['caption'])
@@ -105,4 +109,180 @@ def seed_everything(seed: int):
#max_grad_norm = 1,
epochs = 5,
monitor = 'Bleu_1'
)
)
"""
[The 180 lines added here append the same example script shown above
(examples/xray_report_generation_sat.py) to test.py verbatim.]
