Feature/xray report generation #144

Open
wants to merge 11 commits into base: develop
added example code for xray report generation
samarthkeshari committed Apr 27, 2023
commit 6e1db61d2c7a28180c30805a180c8775e3cbef4e
178 changes: 178 additions & 0 deletions examples/xray_report_generation_sat.py
@@ -0,0 +1,178 @@
import os
import argparse
import pickle
import torch
from collections import OrderedDict
from torchvision import transforms
from pyhealth.datasets import BaseImageCaptionDataset
from pyhealth.tasks.xray_report_generation import biview_multisent_fn
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.tokenizer import Tokenizer
from pyhealth.models import WordSAT, SentSAT
from pyhealth.trainer import Trainer

def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--root', type=str, default=".")
    parser.add_argument('--encoder-chkpt-fname', type=str, default=None)
    parser.add_argument('--tokenizer-fname', type=str, default=None)
    parser.add_argument('--num-epochs', type=int, default=1)
    parser.add_argument('--model-type', type=str, default="wordsat")

    args = parser.parse_args()
    return args

def seed_everything(seed: int):
    import random
    import numpy as np

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark must be disabled, otherwise cuDNN may select
    # non-deterministic algorithms and defeat the seeding above
    torch.backends.cudnn.benchmark = False


# STEP 1: load data
def load_data(root):
    base_dataset = BaseImageCaptionDataset(root=root, dataset_name='IU_XRay')
    return base_dataset

# STEP 2: set task
def set_task(base_dataset):
    sample_dataset = base_dataset.set_task(biview_multisent_fn)
    transform = transforms.Compose([
        transforms.RandomAffine(degrees=30),
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        # ImageNet channel statistics, matching the pretrained encoder
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    sample_dataset.set_transform(transform)
    return sample_dataset

# STEP 3: get dataloaders
def get_dataloaders(sample_dataset):
    # split at the patient level so no patient's studies leak across splits
    train_dataset, val_dataset, test_dataset = split_by_patient(
        sample_dataset, [0.8, 0.1, 0.1])

    train_dataloader = get_dataloader(train_dataset, batch_size=8, shuffle=True)
    val_dataloader = get_dataloader(val_dataset, batch_size=1, shuffle=False)
    test_dataloader = get_dataloader(test_dataset, batch_size=1, shuffle=False)

    return train_dataloader, val_dataloader, test_dataloader

# STEP 4: get tokenizer
def get_tokenizer(root, sample_dataset=None, tokenizer_fname=None):
    if tokenizer_fname:
        # read back a previously pickled tokenizer ('rb', not 'wb')
        with open(os.path.join(root, tokenizer_fname), 'rb') as f:
            tokenizer = pickle.load(f)
    else:
        # <pad> should always be the first element in the list of special tokens
        special_tokens = ['<pad>', '<start>', '<end>', '<unk>']
        tokenizer = Tokenizer(
            sample_dataset.get_all_tokens(key='caption'),
            special_tokens=special_tokens,
        )
    return tokenizer
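
# Usage sketch (an assumption, not part of the original example): pyhealth's
# Tokenizer is expected to report its vocabulary size, which the decoder's
# embedding layer is sized against, e.g.:
#   vocab_size = tokenizer.get_vocabulary_size()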

# STEP 5: get encoder pretrained state dictionary
def extract_encoder_state_dict(root, chkpt_fname):
    checkpoint = torch.load(os.path.join(root, chkpt_fname))
    state_dict = OrderedDict()
    for k, v in checkpoint['state_dict'].items():
        # drop the classification head; only encoder weights are reused
        if 'classifier' in k:
            continue
        # strip the 'module.' prefix added by nn.DataParallel
        if k.startswith('module.'):
            name = k[7:]
        else:
            name = k

        name = name.replace('classifier.0.', 'classifier.')
        state_dict[name] = v
    return state_dict
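
# Illustration with hypothetical keys: a checkpoint saved from an
# nn.DataParallel-wrapped encoder stores an entry such as
#   'module.features.denseblock1.conv1.weight'
# which the loop above re-keys as
#   'features.denseblock1.conv1.weight'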

# STEP 6: define model
def define_model(
        sample_dataset,
        tokenizer,
        encoder_weights,
        model_type='wordsat'):

    if model_type == 'wordsat':
        model = WordSAT(
            dataset=sample_dataset,
            n_input_images=2,
            label_key='caption',
            tokenizer=tokenizer,
            encoder_pretrained_weights=encoder_weights,
            encoder_freeze_weights=True,
            save_generated_caption=True,
        )
    else:
        model = SentSAT(
            dataset=sample_dataset,
            n_input_images=2,
            label_key='caption',
            tokenizer=tokenizer,
            encoder_pretrained_weights=encoder_weights,
            encoder_freeze_weights=True,
            save_generated_caption=True,
        )

    return model

# STEP 7: run trainer
def run_trainer(output_path, train_dataloader, val_dataloader, model, n_epochs):
    trainer = Trainer(
        model=model,
        output_path=output_path,
    )

    trainer.train(
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        optimizer_params={"lr": 1e-4},
        weight_decay=1e-5,
        epochs=n_epochs,
        monitor='Bleu_1',
    )
    return trainer

# STEP 8: evaluate
def evaluate(trainer, test_dataloader):
    # prints the caption metrics (e.g. BLEU scores) computed on the test set
    print(trainer.evaluate(test_dataloader))
    return None

if __name__ == '__main__':
    args = get_args()
    seed_everything(42)
    base_dataset = load_data(args.root)
    sample_dataset = set_task(base_dataset)
    train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
        sample_dataset)
    # forward the optional CLI filename so a pickled tokenizer can be reused
    tokenizer = get_tokenizer(args.root, sample_dataset, args.tokenizer_fname)
    # the encoder checkpoint is optional; skip extraction when no file is given
    if args.encoder_chkpt_fname:
        encoder_weights = extract_encoder_state_dict(args.root,
                                                     args.encoder_chkpt_fname)
    else:
        encoder_weights = None

    model = define_model(
        sample_dataset,
        tokenizer,
        encoder_weights,
        args.model_type)

    trainer = run_trainer(
        args.root,
        train_dataloader,
        val_dataloader,
        model,
        args.num_epochs)

    print("\n===== Evaluating Test Data ======\n")
    evaluate(trainer, test_dataloader)
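
A minimal usage sketch, assuming the IU X-Ray data sits under --root; the checkpoint filename below is hypothetical:

    python examples/xray_report_generation_sat.py \
        --root /path/to/IU_XRay \
        --encoder-chkpt-fname densenet_encoder.pth \
        --model-type wordsat \
        --num-epochs 5

If --tokenizer-fname is omitted, a fresh Tokenizer is built from the task captions; if --encoder-chkpt-fname is omitted, the encoder starts without pretrained weights.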



182 changes: 181 additions & 1 deletion test.py
@@ -1,3 +1,4 @@
"""
import pickle
from torchvision import transforms
from pyhealth.datasets import BaseImageCaptionDataset
@@ -46,6 +47,7 @@ def seed_everything(seed: int):
root = '/home/keshari2/ChestXrayReporting/IU_XRay/src/data'
sample_dataset = BaseImageCaptionDataset(root=root,dataset_name='IU_XRay')
sample_dataset = sample_dataset.set_task(biview_multisent_fn)

transform = transforms.Compose([
transforms.RandomAffine(degrees=30),
transforms.Resize((512,512)),
@@ -55,6 +57,7 @@ def seed_everything(seed: int):
])
sample_dataset.set_transform(transform)

"""
"""
special_tokens = ['<pad>','<start>','<end>','<unk>']
tokenizer = Tokenizer(
@@ -65,6 +68,7 @@ def seed_everything(seed: int):
with open(root+'/pyhealth_tokenizer.pkl', 'wb') as f:
pickle.dump(tokenizer, f)
"""
"""
with open(root+'/pyhealth_tokenizer.pkl', 'rb') as f:
tokenizer = pickle.load(f)
#print(sample_dataset[0]['caption'])
@@ -105,4 +109,180 @@ def seed_everything(seed: int):
#max_grad_norm = 1,
epochs = 5,
monitor = 'Bleu_1'
)
)
"""
[The code appended to test.py is identical to examples/xray_report_generation_sat.py shown above.]