-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
259 lines (218 loc) · 8.73 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import yaml
import os
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from rdkit import Chem
import selfies as sf
from fvcore.nn import FlopCountAnalysis, parameter_count_table
import pickle
from dataloader import dataloader_gen
from dataloader import SELFIEVocab, RegExVocab, CharVocab
from model import RNN
import argparse
# suppress rdkit error
# RDKit logs a warning for every SMILES it fails to parse; compute_valid_rate()
# below relies on Chem.MolFromSmiles returning None quietly for invalid
# molecules, so silence that log channel up front.
from rdkit import rdBase
rdBase.DisableLog('rdApp.error')
def make_vocab(config):
    """Instantiate the vocabulary object selected by the configuration.

    Args:
        config: parsed arguments exposing ``which_vocab`` (one of
            ``"selfies"``, ``"regex"``, ``"char"``) and ``vocab_path``
            (path of the vocabulary file handed to the vocab class).

    Returns:
        A SELFIEVocab, RegExVocab, or CharVocab instance.

    Raises:
        ValueError: if ``which_vocab`` is not a recognized vocabulary name.
    """
    which_vocab = config.which_vocab
    vocab_path = config.vocab_path
    if which_vocab == "selfies":
        return SELFIEVocab(vocab_path)
    elif which_vocab == "regex":
        return RegExVocab(vocab_path)
    elif which_vocab == "char":
        return CharVocab(vocab_path)
    else:
        # fixed typo in the message: "vacab" -> "vocab"
        raise ValueError(
            "Wrong vocab name for configuration which_vocab!"
        )
def sample(model, vocab, batch_size):
    """Sample a batch of SMILES from current model."""
    model.eval()

    # draw integer token sequences from the model
    # (uses the module-level `device` set in the __main__ section)
    token_ids = model.sample(
        batch_size=batch_size,
        vocab=vocab,
        device=device
    ).tolist()

    # decode each integer sequence back to a token string,
    # truncating at the first <eos> token
    molecules = []
    for sequence in token_ids:
        tokens = []
        for idx in sequence:
            token = vocab.int2tocken[idx]
            if token == '<eos>':
                break
            tokens.append(token)
        molecules.append("".join(tokens))

    # SELFIES strings must be decoded back to SMILES
    if vocab.name == 'selfies':
        molecules = [sf.decoder(selfies) for selfies in molecules]

    return molecules
def compute_valid_rate(molecules):
    """Count how many of the given SMILES strings RDKit can parse.

    Args:
        molecules: list of SMILES strings.

    Returns:
        Tuple ``(num_valid, num_invalid)``.
    """
    # MolFromSmiles returns None for an unparsable SMILES
    num_valid = sum(
        1 for smiles in molecules
        if Chem.MolFromSmiles(smiles) is not None
    )
    return num_valid, len(molecules) - num_valid
if __name__ == "__main__":
    def _str2bool(value):
        # argparse's type=bool is a trap: bool('False') is True because any
        # non-empty string is truthy.  Parse the usual spellings explicitly.
        return str(value).lower() in ('true', '1', 'yes')

    # ---------------- command-line arguments ----------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--out_dir', type=str, default='./model_parameters')
    parser.add_argument('--dataset_dir', type=str, default="./chembl/database_smiles_0.5.pkl")
    # vocab choosing - "selfies", "regex" , "DeepSMILES"
    parser.add_argument('--which_vocab', type=str, default="selfies")
    parser.add_argument('--vocab_path', type=str, default="./vocab/chembl_selfies_vocab.yaml")
    parser.add_argument('--percentage', type=float, default=1)
    # RNN config
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--rnn_type', type=str, default='GRU')
    # vocabulary sizes: SELFIES - 148 , regex - 101, DeepSMILES - 129
    parser.add_argument('--num_embeddings', type=int, default=148)
    parser.add_argument('--embedding_dim', type=int, default=512)
    parser.add_argument('--input_size', type=int, default=512)
    parser.add_argument('--hidden_size', type=int, default=512)
    parser.add_argument('--num_layers', type=int, default=3)
    # dropout is a probability — was type=int, which truncated any
    # fractional value passed on the command line to 0
    parser.add_argument('--dropout', type=float, default=0)
    # was type=bool, under which "--shuffle False" silently parsed as True
    parser.add_argument('--shuffle', type=_str2bool, default=True)
    parser.add_argument('--num_epoch', type=int, default=10)
    parser.add_argument('--which_optimizer', type=str, default='adam')
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    config = parser.parse_args()

    # detect cpu or gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('device: ', device)

    out_dir = config.out_dir
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # training data
    dataset_dir = config.dataset_dir
    which_vocab = config.which_vocab
    vocab_path = config.vocab_path
    percentage = config.percentage

    # create dataloader
    batch_size = config.batch_size
    shuffle = config.shuffle
    num_workers = os.cpu_count()
    print('number of workers to load data: ', num_workers)
    print('which vocabulary to use: ', which_vocab)
    dataloader, train_size = dataloader_gen(
        dataset_dir, percentage, which_vocab,
        vocab_path, batch_size, shuffle, drop_last=False, mission_type='train'
    )

    # model and training configuration
    rnn_config = {'num_embeddings': config.num_embeddings, 'embedding_dim': config.embedding_dim,
                  'rnn_type': config.rnn_type, 'input_size': config.input_size,
                  'hidden_size': config.hidden_size, 'num_layers': config.num_layers,
                  'dropout': config.dropout}
    model = RNN(rnn_config).to(device)
    # show model parameters and structures
    print(parameter_count_table(model))

    learning_rate = config.learning_rate
    weight_decay = config.weight_decay
    # Making reduction="sum" makes huge difference
    # in valid rate of sampled molecules.
    loss_function = nn.CrossEntropyLoss(reduction='sum')

    # create optimizer
    if config.which_optimizer == "adam":
        optimizer = torch.optim.Adam(
            model.parameters(), lr=learning_rate,
            weight_decay=weight_decay, amsgrad=True
        )
    elif config.which_optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(), lr=learning_rate,
            weight_decay=weight_decay, momentum=0.9
        )
    else:
        raise ValueError(
            "Wrong optimizer! Select between 'adam' and 'sgd'."
        )

    # learning rate scheduler
    scheduler = ReduceLROnPlateau(
        optimizer, mode='min',
        factor=0.5, patience=5,
        cooldown=10, min_lr=0.0001,
        verbose=True
    )

    # vocabulary object used by the sample() function
    vocab = make_vocab(config)

    # train and validation, the results are saved.
    train_losses = []
    best_valid_rate = 0
    num_epoch = config.num_epoch

    # Resume from a previous checkpoint when one exists.  The checkpoint is
    # the dict saved at the bottom of the loop:
    # {'net': ..., 'optimizer': ..., 'epoch': ...}.
    # BUGFIX: the original code passed the whole dict to
    # model.load_state_dict(), which always raised and silently forced
    # de novo training; resume never actually worked.
    start_epoch = 0
    checkpoint_path = os.path.join(config.out_dir, 'trained_model.pth')
    if os.path.exists(checkpoint_path):
        try:
            # map_location lets a GPU-saved checkpoint load on a CPU-only box
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['net'])
            # also restore the optimizer so Adam/SGD momentum state survives
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            print(f'Continue training with file : {checkpoint_path}')
        except Exception:
            # Exception (not BaseException) so Ctrl-C is not swallowed here
            print('Parameters mismatch, start de novo training!')
            start_epoch = 0

    print('begin training...')
    for epoch in range(start_epoch + 1, 1 + num_epoch):
        model.train()
        train_loss = 0
        for data, lengths in tqdm(dataloader):
            # the lengths are decreased by 1 because we don't
            # use <eos> for input and we don't need <sos> for
            # output during training.
            lengths = [length - 1 for length in lengths]

            optimizer.zero_grad()
            data = data.to(device)
            preds = model(data, lengths)

            # The <sos> token is removed before packing, because
            # we don't need <sos> of output during training.
            # the image_captioning project uses the same method
            # which directly feeds the packed sequences to
            # the loss function:
            # https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/train.py
            targets = pack_padded_sequence(
                data[:, 1:],
                lengths,
                batch_first=True,
                enforce_sorted=False
            ).data

            loss = loss_function(preds, targets)
            print(f"loss:{loss}")
            loss.backward()
            optimizer.step()

            # accumulate loss over mini-batches
            train_loss += loss.item()  # * data.size()[0]

        train_losses.append(train_loss / train_size)
        print('epoch {}, train loss: {}.'.format(epoch, train_losses[-1]))
        scheduler.step(train_losses[-1])

        # sample 1024 SMILES each epoch
        sampled_molecules = sample(model, vocab, batch_size=1024)

        # print the valid rate each epoch
        num_valid, num_invalid = compute_valid_rate(sampled_molecules)
        valid_rate = num_valid / (num_valid + num_invalid)
        print('valid rate: {}'.format(valid_rate))
        torch.save(model.state_dict(), os.path.join(out_dir, f'epoch{epoch}.pth'))

        # update the saved model upon best validation loss
        if valid_rate >= best_valid_rate:
            best_valid_rate = valid_rate
            print('model saved at epoch {}'.format(epoch))
            checkpoint = {
                "net": model.state_dict(),
                'optimizer': optimizer.state_dict(),
                "epoch": epoch
            }
            torch.save(checkpoint, os.path.join(out_dir, 'trained_model.pth'))

    # save train and validation losses
    # BUGFIX: was out_dir + 'loss.yaml', which concatenated without a path
    # separator and wrote e.g. "./model_parametersloss.yaml"
    with open(os.path.join(out_dir, 'loss.yaml'), 'w') as f:
        yaml.dump(train_losses, f)