-
Notifications
You must be signed in to change notification settings - Fork 6
/
decoding.py
30 lines (27 loc) · 968 Bytes
/
decoding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Copyright (C) 2019 Computational Science Lab, UPF <https://www.compscience.org/>
# Copying and distribution is allowed under AGPLv3 license
# Token vocabulary. A token's position in this list is its integer id.
vocab_list = [
    # Special tokens.
    "pad", "start", "end",
    # Atom symbols.
    "C", "c", "N", "n", "S", "s", "P", "p", "O", "o",
    "B", "F", "I",
    # Single-character placeholders for "Cl", "[nH]" and "Br".
    "X", "Y", "Z",
    # Digits.
    "1", "2", "3", "4", "5", "6",
    # Misc
    "#", "=", "-", "(", ")", "/", "\\", "@", "[", "]", "H", "+", "7",
]

# Integer id -> token.
vocab_i2c_v1 = dict(enumerate(vocab_list))
# Token -> integer id (inverse of vocab_i2c_v1).
vocab_c2i_v1 = {token: index for index, token in vocab_i2c_v1.items()}
def decode_smiles(in_tensor):
    """
    Decode a batch of token-index sequences into a list of strings.

    For every sequence the first position is skipped (presumably the
    "start" token -- confirm against the caller), the following indices are
    mapped to characters through ``vocab_i2c_v1``, and decoding stops at the
    first index equal to 2 (the position of "end" in ``vocab_list``) or at
    the end of the sequence, whichever comes first. The single-character
    placeholders "X", "Y" and "Z" are then expanded to "Cl", "[nH]" and
    "Br" respectively.

    :param in_tensor: iterable of sliceable sequences of integer token
        indices, e.g. nested lists, a 2-D NumPy array or a 2-D torch tensor.
    :return: list of decoded strings, one per sequence in ``in_tensor``.
    :raises KeyError: if an index before the end marker is not a key of
        ``vocab_i2c_v1``.
    """
    end_index = 2  # position of "end" in vocab_list

    gen_smiles = []
    for sample in in_tensor:
        chars = []
        for raw_index in sample[1:]:
            # Coerce to a plain int so the dict lookup works for any
            # integer-like scalar. Iterating a torch tensor yields 0-dim
            # tensors, which hash by identity and would never match a key.
            index = int(raw_index)
            if index == end_index:
                break
            chars.append(vocab_i2c_v1[index])
        # Join once instead of growing the string with += in the loop.
        csmile = "".join(chars)
        # Expand the one-character placeholders back to multi-character tokens.
        csmile = csmile.replace("X", "Cl").replace("Y", "[nH]").replace("Z", "Br")
        gen_smiles.append(csmile)
    return gen_smiles