-
Notifications
You must be signed in to change notification settings - Fork 8
/
utils.py
29 lines (24 loc) · 744 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
##################################################
# Imports
##################################################
import pytorch_lightning as pl
import os
def get_callbacks(args):
    """Build the list of Lightning callbacks used for training.

    Args:
        args: Run configuration. It is not read here; presumably it is kept
            so the signature mirrors ``get_logger(args)``. TODO confirm.

    Returns:
        list: A single-element list containing a configured
        ``ModelCheckpoint`` callback.
    """
    # Keep the checkpoint with the highest 'validation_acc' (mode='max') under
    # the fixed file name 'best'. Also keep a separate most-recent checkpoint
    # (save_last=True). Assumes the LightningModule logs a metric named
    # 'validation_acc'; verify against the model code.
    # dirpath=None leaves the output directory for the Trainer to choose.
    checkpoint = pl.callbacks.model_checkpoint.ModelCheckpoint(
        dirpath=None,
        filename='best',
        monitor='validation_acc',
        mode='max',
        save_last=True,
    )
    # Override the class-level default name of the "last" checkpoint on this
    # instance, so the file name embeds the epoch and global step.
    checkpoint.CHECKPOINT_NAME_LAST = '{epoch}-{step}'
    return [checkpoint]
def get_logger(args):
    """Create a TensorBoard logger rooted at ``<cwd>/tmp``.

    Args:
        args: Run configuration; its ``dataset`` attribute is used as the
            experiment name (the sub-directory under the save dir).

    Returns:
        A ``TensorBoardLogger`` whose save_dir is ``<cwd>/tmp`` and whose
        name is ``args.dataset``.
    """
    # Resolved against the process's current working directory at call time,
    # not against this file's location.
    log_root = os.path.join(os.getcwd(), 'tmp')
    experiment = args.dataset
    return pl.loggers.tensorboard.TensorBoardLogger(
        save_dir=log_root,
        name=experiment,
    )