""" import pickle from torchvision import transforms from pyhealth.datasets import BaseImageCaptionDataset from pyhealth.tasks.xray_report_generation import biview_multisent_fn from pyhealth.datasets import split_by_patient, get_dataloader from pyhealth.tokenizer import Tokenizer from pyhealth.models import WordSAT, SentSAT from pyhealth.trainer import Trainer from pyhealth.datasets.utils import list_nested_levels, flatten_list import torch from collections import OrderedDict def extract_state_dict(chkpt_pth): checkpoint = torch.load(chkpt_pth + 'model_ones_3epoch_densenet.tar') new_state_dict = OrderedDict() for k,v in checkpoint['state_dict'].items(): if 'classifier' in k: continue if k[:7] == 'module.' : name = k[7:] else: name = k name = name.replace('classifier.0.','classifier.') new_state_dict[name] = v return new_state_dict chkpt_pth = '/home/keshari2/ChestXrayReporting/IU_XRay/src/models/pretrained/' state_dict = extract_state_dict(chkpt_pth) ##### def seed_everything(seed: int): import random, os import numpy as np import torch random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = True seed_everything(42) root = '/home/keshari2/ChestXrayReporting/IU_XRay/src/data' sample_dataset = BaseImageCaptionDataset(root=root,dataset_name='IU_XRay') sample_dataset = sample_dataset.set_task(biview_multisent_fn) transform = transforms.Compose([ transforms.RandomAffine(degrees=30), transforms.Resize((512,512)), transforms.ToTensor(), transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]), ]) sample_dataset.set_transform(transform) """ """ special_tokens = ['','','',''] tokenizer = Tokenizer( sample_dataset.get_all_tokens(key='caption'), special_tokens=special_tokens, ) with open(root+'/pyhealth_tokenizer.pkl', 'wb') as f: pickle.dump(tokenizer, f) """ """ with open(root+'/pyhealth_tokenizer.pkl', 'rb') as f: tokenizer = pickle.load(f) #print(sample_dataset[0]['caption']) train_dataset, val_dataset, test_dataset = split_by_patient( sample_dataset,[0.8,0.1,0.1] ) train_dataloader = get_dataloader(train_dataset,batch_size=16,shuffle=True) val_dataloader = get_dataloader(val_dataset,batch_size=2,shuffle=False) test_dataloader = get_dataloader(test_dataset,batch_size=1,shuffle=False) model=SentSAT( dataset=sample_dataset, n_input_images = 2, label_key='caption', tokenizer=tokenizer, encoder_pretrained_weights=state_dict, save_generated_caption = True ) #model.eval() #data = next(iter(val_dataloader)) #print(model(**data)) output_path = '/home/keshari2/ChestXrayReporting/IU_XRay/src/output/pyhealth' ckpt_path = '/home/keshari2/ChestXrayReporting/IU_XRay/src/output/pyhealth/20230424-114914/best.ckpt' trainer = Trainer( model=model, output_path=output_path, checkpoint_path = ckpt_path ) trainer.train( train_dataloader = train_dataloader, val_dataloader = val_dataloader, optimizer_params = {"lr": 1e-4}, weight_decay = 1e-5, #max_grad_norm = 1, epochs = 5, monitor = 'Bleu_1' ) """ import os import argparse import pickle import torch from collections import OrderedDict from torchvision import transforms from pyhealth.datasets import BaseImageCaptionDataset from pyhealth.tasks.xray_report_generation import biview_multisent_fn from pyhealth.datasets import split_by_patient, get_dataloader from pyhealth.tokenizer import Tokenizer from pyhealth.models import WordSAT, SentSAT from pyhealth.trainer import Trainer def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--root', type=str, default=".") parser.add_argument('--encoder-chkpt-fname', type=str, default=None) parser.add_argument('--tokenizer-fname', type=str, default=None) parser.add_argument('--num-epochs', type=int, default=1) parser.add_argument('--model-type', type=str, default="wordsat") args = parser.parse_args() return args def seed_everything(seed: int): import random, os import numpy as np import torch random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = True # STEP 1: load data def load_data(root): base_dataset = BaseImageCaptionDataset(root=root,dataset_name='IU_XRay') return base_dataset # STEP 2: set task def set_task(base_dataset): sample_dataset = base_dataset.set_task(biview_multisent_fn) transform = transforms.Compose([ transforms.RandomAffine(degrees=30), transforms.Resize((512,512)), transforms.ToTensor(), transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]), ]) sample_dataset.set_transform(transform) return sample_dataset # STEP 3: get dataloaders def get_dataloaders(sample_dataset): train_dataset, val_dataset, test_dataset = split_by_patient( sample_dataset,[0.8,0.1,0.1]) train_dataloader = get_dataloader(train_dataset,batch_size=8,shuffle=True) val_dataloader = get_dataloader(val_dataset,batch_size=1,shuffle=False) test_dataloader = get_dataloader(test_dataset,batch_size=1,shuffle=False) return train_dataloader,val_dataloader,test_dataloader # STEP 4: get tokenizer def get_tokenizer(root,sample_dataset=None,tokenizer_fname=None): if tokenizer_fname: with open(os.path.join(root,tokenizer_fname), 'wb') as f: tokenizer = pickle.load(f) else: # should always be first element in the list of special tokens special_tokens = ['','','',''] tokenizer = Tokenizer( sample_dataset.get_all_tokens(key='caption'), special_tokens=special_tokens, ) return tokenizer # STEP 5: get encoder pretrained state dictionary def extract_encoder_state_dict(root,chkpt_fname): checkpoint = torch.load(os.path.join(root,chkpt_fname) ) state_dict = OrderedDict() for k,v in checkpoint['state_dict'].items(): if 'classifier' in k: continue if k[:7] == 'module.' : name = k[7:] else: name = k name = name.replace('classifier.0.','classifier.') state_dict[name] = v return state_dict # STEP 6: define model def define_model( sample_dataset, tokenizer, encoder_weights, model_type='wordsat'): if model_type == 'wordsat': model=WordSAT( dataset = sample_dataset, n_input_images = 2, label_key = 'caption', tokenizer = tokenizer, encoder_pretrained_weights = encoder_weights, encoder_freeze_weights = True, save_generated_caption = True ) else: model=SentSAT( dataset = sample_dataset, n_input_images = 2, label_key = 'caption', tokenizer = tokenizer, encoder_pretrained_weights = encoder_weights, encoder_freeze_weights = True, save_generated_caption = True ) return model # STEP 7: run trainer def run_trainer(output_path, train_dataloader, val_dataloader, model,n_epochs): trainer = Trainer( model=model, output_path=output_path ) trainer.train( train_dataloader = train_dataloader, val_dataloader = val_dataloader, optimizer_params = {"lr": 1e-4}, weight_decay = 1e-5, epochs = n_epochs, monitor = 'Bleu_1' ) return trainer # STEP 8: evaluate def evaluate(trainer,test_dataloader): print(trainer.evaluate(test_dataloader)) return None if __name__ == '__main__': args = get_args() seed_everything(42) base_dataset = load_data(args.root) sample_dataset = set_task(base_dataset) train_dataloader,val_dataloader,test_dataloader = get_dataloaders( sample_dataset) tokenizer = get_tokenizer(args.root,sample_dataset) encoder_weights = extract_encoder_state_dict(args.root, args.encoder_chkpt_fname) model = define_model( sample_dataset, tokenizer, encoder_weights, args.model_type) trainer = run_trainer( args.root, train_dataloader, val_dataloader, model, args.num_epochs) print("\n===== Evaluating Test Data ======\n") evaluate(trainer,test_dataloader)