From bf28e670e011cb174a32167e6b282890591bbc61 Mon Sep 17 00:00:00 2001
From: jmisilo
Date: Tue, 15 Nov 2022 21:11:16 +0100
Subject: [PATCH] Format code and delete unnecessary function in training.py

---
 src/model/loops.py | 14 +++++++++++++-
 src/training.py    | 30 ++++++++++++++----------------
 2 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/src/model/loops.py b/src/model/loops.py
index 0c6d4b0..d9c3e3a 100644
--- a/src/model/loops.py
+++ b/src/model/loops.py
@@ -13,7 +13,19 @@ from tqdm import tqdm
 
 
 class Trainer:
-    def __init__(self, model, optimizer, scaler, scheduler, train_loader, valid_loader, test_dataset, test_path, ckp_path, device):
+    def __init__(
+        self,
+        model,
+        optimizer,
+        scaler,
+        scheduler,
+        train_loader,
+        valid_loader,
+        test_dataset='./data',
+        test_path='',
+        ckp_path='',
+        device='cpu'
+    ):
         self.model = model
         self.optimizer = optimizer
         self.scaler = scaler
diff --git a/src/training.py b/src/training.py
index b452b30..9ae33ce 100644
--- a/src/training.py
+++ b/src/training.py
@@ -36,19 +36,10 @@
 torch.cuda.manual_seed(config.seed)
 torch.backends.cudnn.deterministic = True
 
 
-def train(config, ckp_name=''):
+if __name__ == '__main__':
+    is_cuda = torch.cuda.is_available()
     device = torch.device('cuda' if is_cuda else 'cpu')
-
-    model = Net(
-        ep_len=config.ep_len,
-        num_layers=config.num_layers,
-        n_heads=config.n_heads,
-        forward_expansion=config.forward_expansion,
-        dropout=config.dropout,
-        max_len=config.max_len,
-        device=device
-    )
 
     dataset = MiniFlickrDataset(os.path.join('data', 'processed', 'dataset.pkl'))
 
@@ -74,6 +65,16 @@ def train(config, ckp_name=''):
         pin_memory=is_cuda
     )
 
+    model = Net(
+        ep_len=config.ep_len,
+        num_layers=config.num_layers,
+        n_heads=config.n_heads,
+        forward_expansion=config.forward_expansion,
+        dropout=config.dropout,
+        max_len=config.max_len,
+        device=device
+    )
+
     optimizer = optim.Adam(model.parameters(), lr=config.lr)
 
     warmup = LRWarmup(epochs=config.epochs, max_lr=config.lr, k=config.k)
@@ -81,7 +82,7 @@ def train(config, ckp_name=''):
     scheduler = optim.lr_scheduler.LambdaLR(optimizer, warmup.lr_warmup)
     scaler = torch.cuda.amp.GradScaler()
 
-    ckp_path = os.path.join(config.weights_dir, ckp_name)
+    ckp_path = os.path.join(config.weights_dir, args.checkpoint_name)
 
     trainer = Trainer(
         model=model,
@@ -118,7 +119,4 @@ def train(config, ckp_name=''):
             os.makedirs(config.weights_dir)
 
         if (epoch + 1) % 50 == 0:
-            trainer.save_ckp(os.path.join(config.weights_dir, f'epoch_{epoch + 1}.pt'))
-
-if __name__ == '__main__':
-    train(config, args.checkpoint_name)
\ No newline at end of file
+            trainer.save_ckp(os.path.join(config.weights_dir, f'epoch_{epoch + 1}.pt'))
\ No newline at end of file
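
Note on args.checkpoint_name: the training.py hunks read args.checkpoint_name, but the parser that defines args lies outside the patch context. The sketch below shows the kind of argparse setup the script presumably relies on; the flag name --checkpoint-name, its short form -C, and the empty-string default are assumptions for illustration, not taken from this patch.

import argparse

# Hypothetical parser for illustration only; the real training.py defines its own arguments.
parser = argparse.ArgumentParser(description='Train the image captioning model')
parser.add_argument(
    '-C', '--checkpoint-name',
    type=str,
    default='',
    help='checkpoint file name inside config.weights_dir (empty string starts training from scratch)'
)
args = parser.parse_args()

# argparse converts dashes to underscores, so the value is read back as args.checkpoint_name,
# matching the ckp_path = os.path.join(config.weights_dir, args.checkpoint_name) line in the patch.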