
Commit

Format code and delete unnecessary function in training.py
jmisilo committed Nov 15, 2022
1 parent ec71006 · commit bf28e67
Showing 2 changed files with 27 additions and 17 deletions.
src/model/loops.py: 14 changes (13 additions, 1 deletion)
@@ -13,7 +13,19 @@
 from tqdm import tqdm
 
 class Trainer:
-    def __init__(self, model, optimizer, scaler, scheduler, train_loader, valid_loader, test_dataset, test_path, ckp_path, device):
+    def __init__(
+        self,
+        model,
+        optimizer,
+        scaler,
+        scheduler,
+        train_loader,
+        valid_loader,
+        test_dataset='./data',
+        test_path='',
+        ckp_path='',
+        device='cpu'
+    ):
         self.model = model
         self.optimizer = optimizer
         self.scaler = scaler
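To make the effect of the reformatted signature concrete, here is a small self-contained sketch. TrainerStub is a placeholder, not the real Trainer from src/model/loops.py; it only mirrors the new parameter list to show that the last four arguments can now be omitted and fall back to './data', '', '' and 'cpu'.

# Stub mirroring only the new signature; the real Trainer does much more.
class TrainerStub:
    def __init__(
        self,
        model,
        optimizer,
        scaler,
        scheduler,
        train_loader,
        valid_loader,
        test_dataset='./data',
        test_path='',
        ckp_path='',
        device='cpu'
    ):
        # keep just the defaulted values to demonstrate the fallbacks
        self.test_dataset = test_dataset
        self.device = device

# the six required objects are passed as None placeholders here
stub = TrainerStub(None, None, None, None, None, None)
print(stub.test_dataset, stub.device)  # prints: ./data cpu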
src/training.py: 30 changes (14 additions, 16 deletions)
@@ -36,19 +36,10 @@
 torch.cuda.manual_seed(config.seed)
 torch.backends.cudnn.deterministic = True
 
-def train(config, ckp_name=''):
-
+if __name__ == '__main__':
+
     is_cuda = torch.cuda.is_available()
     device = torch.device('cuda' if is_cuda else 'cpu')
 
-    model = Net(
-        ep_len=config.ep_len,
-        num_layers=config.num_layers,
-        n_heads=config.n_heads,
-        forward_expansion=config.forward_expansion,
-        dropout=config.dropout,
-        max_len=config.max_len,
-        device=device
-    )
 
     dataset = MiniFlickrDataset(os.path.join('data', 'processed', 'dataset.pkl'))
@@ -74,14 +65,24 @@ def train(config, ckp_name=''):
         pin_memory=is_cuda
     )
 
+    model = Net(
+        ep_len=config.ep_len,
+        num_layers=config.num_layers,
+        n_heads=config.n_heads,
+        forward_expansion=config.forward_expansion,
+        dropout=config.dropout,
+        max_len=config.max_len,
+        device=device
+    )
+
     optimizer = optim.Adam(model.parameters(), lr=config.lr)
 
     warmup = LRWarmup(epochs=config.epochs, max_lr=config.lr, k=config.k)
 
     scheduler = optim.lr_scheduler.LambdaLR(optimizer, warmup.lr_warmup)
     scaler = torch.cuda.amp.GradScaler()
 
-    ckp_path = os.path.join(config.weights_dir, ckp_name)
+    ckp_path = os.path.join(config.weights_dir, args.checkpoint_name)
 
     trainer = Trainer(
         model=model,
@@ -118,7 +119,4 @@ def train(config, ckp_name=''):
             os.makedirs(config.weights_dir)
 
         if (epoch + 1) % 50 == 0:
-            trainer.save_ckp(os.path.join(config.weights_dir, f'epoch_{epoch + 1}.pt'))
-
-if __name__ == '__main__':
-    train(config, args.checkpoint_name)
+            trainer.save_ckp(os.path.join(config.weights_dir, f'epoch_{epoch + 1}.pt'))
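Pieced together from the three hunks above, the rewritten script now runs its whole body under the __main__ guard instead of wrapping it in train(). The outline below is a sketch, not a verbatim copy: the imports, the argparse setup that produces args.checkpoint_name, and the DataLoader construction sit outside the shown hunks and are only assumed here.

# Sketch of src/training.py after this commit; parts outside the diff are assumed.
if __name__ == '__main__':
    is_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if is_cuda else 'cpu')

    dataset = MiniFlickrDataset(os.path.join('data', 'processed', 'dataset.pkl'))
    # train_loader / valid_loader are built here (unchanged, not shown in the diff)

    # the model is now created after the data pipeline instead of before it
    model = Net(
        ep_len=config.ep_len,
        num_layers=config.num_layers,
        n_heads=config.n_heads,
        forward_expansion=config.forward_expansion,
        dropout=config.dropout,
        max_len=config.max_len,
        device=device
    )

    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    warmup = LRWarmup(epochs=config.epochs, max_lr=config.lr, k=config.k)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, warmup.lr_warmup)
    scaler = torch.cuda.amp.GradScaler()

    # the checkpoint name now comes straight from the parsed CLI arguments
    ckp_path = os.path.join(config.weights_dir, args.checkpoint_name)

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        # remaining keyword arguments are truncated in the diff and omitted here
    )

    # the training loop follows, saving a checkpoint every 50 epochs via
    # trainer.save_ckp(os.path.join(config.weights_dir, f'epoch_{epoch + 1}.pt'))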
