# Training and validating medgaze
import argparse
import json
import os
import random  # fix: required by seed_everything() below, was never imported
import warnings
from datetime import datetime
from os.path import join
from timeit import default_timer as timer
from typing import List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

warnings.filterwarnings("ignore")

# Run configuration. NOTE(review): kept as a plain dict; the script tail wraps
# it in a pandas Series so main() can use attribute-style access.
args = {
    'head_lr': 1e-6,
    'tail_lr': 1e-4,
    'belly_lr': 2e-6,
    'dataset_dir': "/dataset",
    'train_file': 'train_ref_128_dur.json',   # Train file name
    'valid_file': 'val_ref_128_dur.json',     # test file name
    'img_ftrs_dir': "/image_features_text_qformer_llm_exp_128_full_eg_reflacx/",  # Download the image features from the link provided
    'im_h': 8,
    'im_w': 8,
    'patch_size': 16,
    'seed': 42,
    'batch_size': 32,
    'epochs': 1000,
    'max_len': 50,
    'num_encoder': 6,
    'num_decoder': 6,
    'hidden_dim': 1408,
    'nhead': 8,
    'img_hidden_dim': 2048,
    'lm_hidden_dim': 768,
    'encoder_dropout': 0.1,
    'decoder_dropout': 0.2,
    'cls_dropout': 0.4,
    'retraining': False,
    'model_root': '/medgaze_qformer_llm_using_rest_feaex_8x8.py_128_128/',  # directory to save the model
    'cuda': 3,
    'num_workers': 6
}

# BLIP-2 pretrained-model config paths (not read by this script; kept for reference).
PRETRAINED_MODEL_CONFIG_DICT = {
    "pretrain_opt2.7b": "configs/models/blip2/blip2_pretrain_opt2.7b.yaml",
    "pretrain_opt6.7b": "configs/models/blip2/blip2_pretrain_opt6.7b.yaml",
    "caption_coco_opt2.7b": "configs/models/blip2/blip2_caption_opt2.7b.yaml",
    "caption_coco_opt6.7b": "configs/models/blip2/blip2_caption_opt6.7b.yaml",
}

# Creating the reports: a 0-d object ndarray wrapping a dict keyed by image
# name (without the '.jpg' suffix) — presumably text embeddings; see loop below.
x = np.load(open('/embeddings_text_egd_ref.npy', mode='rb'), allow_pickle=True)

uy = []
path = '/full_egd_ref_128_dur.json'


def js_r(filename: str):
    """Read *filename* as JSON and return the parsed object."""
    with open(filename) as f_in:
        return json.load(f_in)


fulld = js_r(path)
# Index the reference annotations by image file name for O(1) lookup.
new_dict = {item['name']: item for item in fulld}

# Overwrite each entry of the loaded dict with the task string from the
# reference annotations. NOTE(review): x.item() returns the same underlying
# dict object on every call, so this in-place mutation is persistent and is
# what main() later reads via x.item().
tasks_by_name = x.item()
for img_name in list(tasks_by_name):
    tasks_by_name[img_name] = new_dict[img_name + '.jpg']['task']
# Anomaly detection helps localise NaN/Inf gradients (was enabled twice in the
# original; enabling once is sufficient).
torch.autograd.set_detect_anomaly(True)

import random  # used by seed_everything(); imported here so this block is self-contained


def seed_everything(seed):
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def fixations2seq(fixations, max_len):
    """Convert raw fixation records into per-scanpath tensor dicts.

    Each record must carry 'X'/'Y'/'T' sequences plus 'task' and 'name';
    sequences are truncated to *max_len* and 'name' is renamed 'img_name'.
    """
    processed_fixs = []
    for fix in fixations:
        processed_fixs.append({
            'tgt_seq_y': torch.tensor(np.array(fix['Y'])[:max_len]),
            'tgt_seq_x': torch.tensor(np.array(fix['X'])[:max_len]),
            'tgt_seq_t': torch.tensor(np.array(fix['T'])[:max_len]),
            'task': fix['task'],
            'img_name': fix['name'],
        })
    return processed_fixs


def save_model_train(epoch, args, model, SlowOpt, MidOpt, FastOpt, model_dir, model_name):
    """Checkpoint model weights and all three optimizer states to model_dir."""
    state = {
        "epoch": epoch,
        "args": str(args),
        # Unwrap DataParallel/DDP so the checkpoint loads into a bare model.
        "model": model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
        "optim_slow": SlowOpt.state_dict(),
        "optim_mid": MidOpt.state_dict(),
        "optim_fast": FastOpt.state_dict(),
    }
    torch.save(state, join(model_dir, model_name + '_' + str(epoch) + '.pkg'))


def train(epoch, args, model, SlowOpt, MidOpt, FastOpt, loss_fn_token, loss_fn_y, loss_fn_x,
          loss_fn_t, train_dataloader, model_dir, model_name, device='cuda:1', im_h=20,
          im_w=32, patch_size=16):
    """Run one training epoch and return mean (token, reg, t) losses.

    NOTE(review): main() never overrides im_h/im_w, so the coordinate clamp
    uses the 20x32 defaults rather than args.im_h/args.im_w (8x8) — confirm
    this is intended.
    """
    model.train()
    token_losses = 0
    reg_losses = 0
    t_losses = 0
    with tqdm(train_dataloader, unit="batch") as tepoch:
        minibatch = 0
        for batch_imgs, batch_tgt, batch_tgt_padding_mask, batch_tasks, batch_firstfix in tepoch:
            print('starting with some ')
            print(batch_imgs.shape, batch_tgt.shape, batch_firstfix.shape, len(batch_tasks))
            out_token, out_y, out_x, out_t = model(src=batch_imgs, tgt=batch_firstfix, task=batch_tasks)
            # Keep predicted coordinates inside the image bounds.
            out_y = torch.clamp(out_y, min=0, max=im_h * patch_size - 2)
            out_x = torch.clamp(out_x, min=0, max=im_w * patch_size - 2)
            SlowOpt.zero_grad()
            MidOpt.zero_grad()
            FastOpt.zero_grad()
            tgt_out = batch_tgt.to(device)
            batch_tgt_padding_mask = batch_tgt_padding_mask.to(device)
            token_gt = batch_tgt_padding_mask.long()
            # 1.0 for valid fixations, 0.0 for padding positions.
            fixation_mask = torch.logical_not(batch_tgt_padding_mask).float()
            # predict padding or valid fixation
            # (NLLLoss: assumes out_token carries log-probabilities — confirm in model)
            token_loss = loss_fn_token(out_token.permute(1, 2, 0), token_gt)
            out_y = out_y.squeeze(-1).permute(1, 0) * fixation_mask
            out_x = out_x.squeeze(-1).permute(1, 0) * fixation_mask
            out_t = out_t.squeeze(-1).permute(1, 0) * fixation_mask
            # L1 regression losses averaged over valid ground-truth fixations only.
            # NOTE(review): a fully-padded row would make fixation_mask.sum(-1)=0
            # and divide by zero — confirm the collator never emits one.
            reg_loss = (loss_fn_y(out_y.float(), tgt_out[:, :, 0] * fixation_mask).sum(-1) / fixation_mask.sum(-1)
                        + loss_fn_x(out_x.float(), tgt_out[:, :, 1] * fixation_mask).sum(-1) / fixation_mask.sum(-1)).mean()
            t_loss = (loss_fn_t(out_t.float(), tgt_out[:, :, 2] * fixation_mask).sum(-1) / fixation_mask.sum(-1)).mean()
            loss = token_loss + reg_loss + t_loss
            loss.backward()
            token_losses += token_loss.item()
            reg_losses += reg_loss.item()
            t_losses += t_loss.item()
            SlowOpt.step()
            MidOpt.step()
            FastOpt.step()
            minibatch += 1.
            tepoch.set_postfix(token_loss=token_losses / minibatch,
                               reg_loss=reg_losses / minibatch,
                               t_loss=t_losses / minibatch)
    # Periodic checkpoint every 30 epochs (same epoch%30==0 condition as before).
    if epoch % 30 == 0:
        save_model_train(epoch, args, model, SlowOpt, MidOpt, FastOpt, model_dir, model_name)
    return (token_losses / len(train_dataloader),
            reg_losses / len(train_dataloader),
            t_losses / len(train_dataloader))


def evaluate(model, loss_fn_token, loss_fn_y, loss_fn_x, loss_fn_t, valid_dataloader,
             device='cuda:0', im_h=20, im_w=32, patch_size=16):
    """Validation pass: same losses as train() but with no gradient updates."""
    model.eval()
    token_losses = 0
    reg_losses = 0
    t_losses = 0
    with tqdm(valid_dataloader, unit="batch") as tepoch:
        minibatch = 0
        for batch_imgs, batch_tgt, batch_tgt_padding_mask, batch_tasks, batch_firstfix in tepoch:
            print(batch_imgs.shape, batch_firstfix.shape)
            with torch.no_grad():
                out_token, out_y, out_x, out_t = model(src=batch_imgs, tgt=batch_firstfix, task=batch_tasks)
            out_y = torch.clamp(out_y, min=0, max=im_h * patch_size - 2)
            out_x = torch.clamp(out_x, min=0, max=im_w * patch_size - 2)
            tgt_out = batch_tgt.to(device)
            batch_tgt_padding_mask = batch_tgt_padding_mask.to(device)
            token_gt = batch_tgt_padding_mask.long()
            fixation_mask = torch.logical_not(batch_tgt_padding_mask).float()
            token_loss = loss_fn_token(out_token.permute(1, 2, 0), token_gt)
            out_y = out_y.squeeze(-1).permute(1, 0) * fixation_mask
            out_x = out_x.squeeze(-1).permute(1, 0) * fixation_mask
            out_t = out_t.squeeze(-1).permute(1, 0) * fixation_mask
            reg_loss = (loss_fn_y(out_y.float(), tgt_out[:, :, 0] * fixation_mask).sum(-1) / fixation_mask.sum(-1)
                        + loss_fn_x(out_x.float(), tgt_out[:, :, 1] * fixation_mask).sum(-1) / fixation_mask.sum(-1)).mean()
            t_loss = (loss_fn_t(out_t.float(), tgt_out[:, :, 2] * fixation_mask).sum(-1) / fixation_mask.sum(-1)).mean()
            token_losses += token_loss.item()
            reg_losses += reg_loss.item()
            t_losses += t_loss.item()
            minibatch += 1.
            tepoch.set_postfix(token_loss=token_losses / minibatch,
                               reg_loss=reg_losses / minibatch,
                               t_loss=t_losses / minibatch)
    return (token_losses / len(valid_dataloader),
            reg_losses / len(valid_dataloader),
            t_losses / len(valid_dataloader))


def main(args):
    """Build datasets, model and optimizers, then run the train/eval loop.

    args exposes the module-level config keys via attribute access (a pandas
    Series or an argparse Namespace).
    """
    seed_everything(args.seed)
    device = torch.device('cuda:{}'.format(args.cuda))
    device_id = args.cuda
    retraining = args.retraining
    # Bug fix: the config dict defines no 'last_checkpoint' key, so the
    # original unconditional attribute access raised AttributeError even when
    # retraining was False. Default to None instead.
    last_checkpoint = getattr(args, 'last_checkpoint', None)
    if retraining:
        model_dir = '/'.join(last_checkpoint.split('/')[:-1])
        logfile = 'full128_medgaze_qformer_llm_using_rest_feaex_8x8_128_128_' + 'retraining' + '.txt'
        args.cuda = device_id
    else:
        timenow = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
        logfile = 'full128_medgaze_qformer_llm_using_rest_feaex_8x8_128_128_' + timenow + '.txt'
        model_dir = join(args.model_root, 'train_full' + timenow)
        # makedirs also creates a missing model_root (os.mkdir would fail).
        os.makedirs(model_dir, exist_ok=True)
    open(logfile, 'w').close()
    with open(logfile, "a") as myfile:
        myfile.write(str(vars(args)) + '\n\n')
    print(str(args) + '\n\n')
    with open(join(model_dir, 'config.json'), "w") as outfile:
        json.dump(str(args), outfile)

    model_name = ('medgaze_' + str(args.num_encoder) + 'E_' + str(args.num_decoder) + 'D_'
                  + str(args.batch_size) + '_' + str(args.hidden_dim) + 'd')

    dataset_root = args.dataset_dir
    train_file = args.train_file
    valid_file = args.valid_file
    with open(join(dataset_root, train_file)) as json_file:
        fixations_train = json.load(json_file)
    with open(join(dataset_root, valid_file)) as json_file:
        fixations_valid = json.load(json_file)
    seq_train = fixations2seq(fixations=fixations_train, max_len=args.max_len)
    seq_valid = fixations2seq(fixations=fixations_valid, max_len=args.max_len)
    train_dataset = fixation_dataset(seq_train, img_ftrs_dir=args.img_ftrs_dir)
    valid_dataset = fixation_dataset(seq_valid, img_ftrs_dir=args.img_ftrs_dir)

    # target embeddings: the module-level dict loaded (and mutated) at import time.
    embedding_dict = x.item()
    collate_fn = COCOSearch18Collator(embedding_dict, args.max_len, args.im_h, args.im_w, args.patch_size)
    # Consistency fix: honour args.num_workers (value 6) instead of a hard-coded 6.
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.num_workers, collate_fn=collate_fn)
    valid_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
                                  num_workers=args.num_workers, collate_fn=collate_fn)

    transformer = Transformer(num_encoder_layers=args.num_encoder, nhead=args.nhead,
                              d_model=args.hidden_dim, num_decoder_layers=args.num_decoder,
                              encoder_dropout=args.encoder_dropout,
                              decoder_dropout=args.decoder_dropout,
                              dim_feedforward=args.hidden_dim,
                              img_hidden_dim=args.img_hidden_dim,
                              lm_dmodel=args.lm_hidden_dim, device=device).to(device)
    model = medgaze(transformer, spatial_dim=(args.im_h, args.im_w), dropout=args.cls_dropout,
                    max_len=args.max_len, device=device).to(device)

    loss_fn_token = torch.nn.NLLLoss()
    loss_fn_y = nn.L1Loss(reduction='none')
    loss_fn_x = nn.L1Loss(reduction='none')
    loss_fn_t = nn.L1Loss(reduction='none')

    # Disjoint optimization: three AdamW optimizers with separate learning
    # rates for encoder+token head (slow), duration heads (mid) and
    # decoder+coordinate heads (fast).
    head_params = list(model.transformer.encoder.parameters()) + list(model.token_predictor.parameters())
    SlowOpt = torch.optim.AdamW(head_params, lr=args.head_lr, betas=(0.9, 0.98),
                                eps=1e-9, weight_decay=1e-4)
    belly_params = list(model.generator_t_mu.parameters()) + list(model.generator_t_logvar.parameters())
    MidOpt = torch.optim.AdamW(belly_params, lr=args.belly_lr, betas=(0.9, 0.98),
                               eps=1e-9, weight_decay=1e-4)
    tail_params = (list(model.transformer.decoder.parameters())
                   + list(model.generator_y_mu.parameters())
                   + list(model.generator_x_mu.parameters())
                   + list(model.generator_y_logvar.parameters())
                   + list(model.generator_x_logvar.parameters())
                   + list(model.querypos_embed.parameters())
                   + list(model.firstfix_linear.parameters()))
    FastOpt = torch.optim.AdamW(tail_params, lr=args.tail_lr, betas=(0.9, 0.98),
                                eps=1e-9, weight_decay=1e-4)

    start_epoch = 1
    # NOTE(review): the original hard-codes retraining=0 here, which disables
    # the checkpoint-resume branch below regardless of args.retraining.
    # Behaviour preserved; flip this flag to re-enable resuming.
    retraining = 0
    if retraining:
        checkpoint = torch.load('/medgaze_qformer_llm_using_rest_feaex_8x8.py_128_128/train_27-02-2024-23-11-14/medgaze_6E_6D_128_1408d_85.pkg', map_location=device)
        model.load_state_dict(checkpoint['model'])
        # Optimizer states are intentionally not restored (matches original).
        start_epoch = checkpoint['epoch'] + 1
        print("Retraining from", start_epoch)

    for epoch in range(start_epoch, args.epochs + 1):
        start_time = timer()
        train_token_loss, train_reg_loss, train_t_loss = train(
            epoch=epoch, args=args, model=model, SlowOpt=SlowOpt, FastOpt=FastOpt, MidOpt=MidOpt,
            loss_fn_token=loss_fn_token, loss_fn_y=loss_fn_y, loss_fn_x=loss_fn_x,
            loss_fn_t=loss_fn_t, train_dataloader=train_dataloader, model_dir=model_dir,
            model_name=model_name, device=device)
        end_time = timer()
        valid_token_loss, valid_reg_loss, valid_t_loss = evaluate(
            model=model, loss_fn_token=loss_fn_token, loss_fn_y=loss_fn_y, loss_fn_x=loss_fn_x,
            loss_fn_t=loss_fn_t, valid_dataloader=valid_dataloader, device=device)
        output_str = f"Epoch: {epoch}, Train token loss: {train_token_loss:.3f}, Train reg loss: {train_reg_loss:.3f}, Train T loss: {train_t_loss:.3f}, Val token loss: {valid_token_loss:.3f}, Val reg loss: {valid_reg_loss:.3f}, Valid T loss: {valid_t_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s, Saved to {model_dir+'/'+model_name}\n"
        print(output_str)
        with open(logfile, "a") as myfile:
            myfile.write(output_str)


if __name__ == "__main__":
    # Guarded entry point (the original called main() at import time).
    # Wrapping the config dict in a Series gives attribute-style access.
    import pandas as pd
    args = pd.Series(args)
    main(args)