-
Notifications
You must be signed in to change notification settings - Fork 4
/
mirdnn_fit.py
executable file
·97 lines (84 loc) · 3.21 KB
/
mirdnn_fit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/usr/bin/python3
import os.path
import sys
import time
import torch as tr
import numpy as np
import random as rn
import torch.utils.data as dt
from sklearn.metrics import precision_recall_curve, auc
from src.fold_dataset import FoldDataset
from src.model import mirDNN
from src.parameters import ParameterParser
from src.sampler import ImbalancedDatasetSampler
from src.logger import Logger
def _seed_everything(pp):
    """Seed the Python, NumPy and torch RNGs and configure cuDNN from *pp*.

    When ``pp.random_seed`` is set, all three RNGs are seeded and, on CUDA,
    cuDNN is put in deterministic mode with autotuning disabled. Without a
    seed on CUDA, cuDNN autotuning (``benchmark``) is enabled instead.
    """
    seeded = pp.random_seed is not None
    if seeded:
        rn.seed(pp.random_seed)
        np.random.seed(pp.random_seed)
        tr.manual_seed(pp.random_seed)
    if pp.device.type == 'cuda':
        # Reproducibility and cuDNN autotuning pull in opposite directions:
        # pick determinism only when the user asked for a fixed seed.
        tr.backends.cudnn.deterministic = seeded
        tr.backends.cudnn.benchmark = not seeded


def _build_loaders(pp, dataset):
    """Randomly split *dataset* and return ``(train_loader, valid_loader)``.

    ``int(pp.valid_prop * len(dataset))`` samples go to validation and the
    rest go to training. With ``pp.upsample`` the training loader draws from
    an ``ImbalancedDatasetSampler`` built with ``num_samples=8 * batch_size``
    (presumably the number of samples per pass; see that class to confirm).
    Otherwise the whole training split is shuffled each pass.
    """
    valid_size = int(pp.valid_prop * len(dataset))
    train, valid = dt.random_split(dataset,
                                   (len(dataset) - valid_size, valid_size))
    if pp.upsample:
        train_sampler = ImbalancedDatasetSampler(train,
                                                 max_imbalance=1.0,
                                                 num_samples=8 * pp.batch_size)
        # DataLoader rejects shuffle=True combined with a custom sampler
        # (they are mutually exclusive), so the sampler alone decides order.
        train_loader = dt.DataLoader(train,
                                     batch_size=pp.batch_size,
                                     sampler=train_sampler,
                                     pin_memory=True)
    else:
        train_loader = dt.DataLoader(train,
                                     batch_size=pp.batch_size,
                                     shuffle=True,
                                     pin_memory=True)
    valid_loader = dt.DataLoader(valid,
                                 batch_size=pp.batch_size,
                                 pin_memory=True)
    return train_loader, valid_loader


def _validation_pr_auc(model, valid_loader):
    """Return the area under the precision-recall curve over *valid_loader*.

    Model outputs and labels are flattened per batch and concatenated once
    at the end. No gradients are tracked during the pass.

    Raises:
        ValueError: if the validation loader yields no batches.
    """
    preds, labels = [], []
    with tr.no_grad():
        for x, v, y in valid_loader:
            # reshape(-1) rather than squeeze(): squeeze() turns a final
            # batch of size 1 into a 0-d tensor, which torch.cat rejects.
            preds.append(model(x, v).cpu().reshape(-1))
            labels.append(y.reshape(-1))
    if not preds:
        raise ValueError('validation split is empty; increase valid_prop')
    pr, rc, _ = precision_recall_curve(tr.cat(labels).numpy(),
                                       tr.cat(preds).numpy())
    return auc(rc, pr)


def main(argv):
    """Train a mirDNN model with early stopping on smoothed validation PR-AUC.

    Args:
        argv: command-line arguments (without the program name), parsed by
            ``ParameterParser``.

    The function runs epochs until ``pp.early_stop`` consecutive epochs pass
    without the smoothed validation PR-AUC improving. Each epoch trains for
    at most ``max_batches`` batches. After each epoch it writes one
    tab-separated line (epoch, smoothed train loss, smoothed validation AUC,
    epochs since last improvement) to ``pp.logfile``. It saves the model to
    ``pp.model_file`` whenever the smoothed AUC reaches a new best. If
    ``pp.model_file`` already exists, training resumes from it.
    """
    max_batches = 1000  # cap on training batches per epoch

    pp = ParameterParser(argv)
    _seed_everything(pp)

    dataset = FoldDataset(pp.input_files, pp.seq_len)
    train_loader, valid_loader = _build_loaders(pp, dataset)

    model = mirDNN(pp)
    model.train()

    log = Logger(pp.logfile)
    try:
        if pp.model_file is not None and os.path.isfile(pp.model_file):
            model.load(pp.model_file)  # resume from an existing checkpoint

        log.write('epoch\ttrainLoss\tvalidAUC\tlast_imp\n')

        epoch = 0
        train_loss = 100  # arbitrary high starting point for the moving average
        valid_auc = 0
        best_valid_auc = 0
        last_improvement = 0
        while last_improvement < pp.early_stop:
            for nbatch, (x, v, y) in enumerate(train_loader, start=1):
                new_loss = model.train_step(x, v, y)
                train_loss = 0.99 * train_loss + 0.01 * new_loss
                if nbatch >= max_batches:
                    break  # stop this epoch early; a no-op `continue` never did

            # NOTE(review): the model stays in training mode during
            # validation; if mirDNN uses dropout/batch-norm, consider
            # model.eval()/model.train() around this call — TODO confirm.
            # Exponential moving average: the weights must sum to 1, otherwise
            # the "average" drifts upward and fakes improvements.
            valid_auc = (0.1 * _validation_pr_auc(model, valid_loader)
                         + 0.9 * valid_auc)

            last_improvement += 1
            if valid_auc > best_valid_auc:
                best_valid_auc = valid_auc
                last_improvement = 0
                # NOTE(review): pp.model_file may be None (see the check
                # above); confirm model.save handles that case.
                model.save(pp.model_file)

            log.write('%d\t%.4f\t%.4f\t%d\n' %
                      (epoch, train_loss, valid_auc, last_improvement))
            epoch += 1
    finally:
        log.close()
if __name__ == "__main__":
    # Script entry point: hand every argument after the program name to main().
    cli_args = sys.argv[1:]
    main(cli_args)