-
Notifications
You must be signed in to change notification settings - Fork 216
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit a0264b5
Showing
14 changed files
with
1,582 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Generative Tensorial Reinforcement Learning (GENTRL) | ||
Supporting Information for the paper _"Deep learning enables rapid identification of potent DDR1 kinase inhibitors"_. | ||
|
||
The GENTRL model is a variational autoencoder with a rich prior distribution of the latent space. We used tensor decompositions to encode the relations between molecular structures and their properties and to learn on data with missing values. We train the model in two steps. First, we learn a mapping of a chemical space on the latent manifold by maximizing the evidence lower bound. We then freeze all the parameters except for the learnable prior and explore the chemical space to find molecules with a high reward. | ||
|
||
![GENTRL](images/gentrl.png) | ||
|
||
|
||
## Repository | ||
In this repository, we provide an implementation of a GENTRL model with an example trained on a [MOSES](https://github.com/molecularsets/moses) dataset. | ||
|
||
To run the training procedure, | ||
1. [Install RDKit](https://www.rdkit.org/docs/Install.html) to process molecules | ||
2. Install GENTRL model: `python setup.py install` | ||
3. Install MOSES from the [repository](https://github.com/molecularsets/moses) | ||
4. Run the [pretrain.ipynb](./examples/pretrain.ipynb) to train an autoencoder | ||
5. Run the [train_rl.ipynb](./examples/train_rl.ipynb) to optimize a reward function |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import gentrl\n", | ||
"import torch\n", | ||
"import pandas as pd\n", | ||
"torch.cuda.set_device(0)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from moses.metrics import mol_passes_filters, QED, SA, logP\n", | ||
"from moses.metrics.utils import get_n_rings, get_mol\n", | ||
"\n", | ||
"\n", | ||
"def get_num_rings_6(mol):\n", | ||
" r = mol.GetRingInfo()\n", | ||
" return len([x for x in r.AtomRings() if len(x) > 6])\n", | ||
"\n", | ||
"\n", | ||
"def penalized_logP(mol_or_smiles, masked=False, default=-5):\n", | ||
" mol = get_mol(mol_or_smiles)\n", | ||
" if mol is None:\n", | ||
" return default\n", | ||
" reward = logP(mol) - SA(mol) - get_num_rings_6(mol)\n", | ||
" if masked and not mol_passes_filters(mol):\n", | ||
" return default\n", | ||
" return reward" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"! wget https://media.githubusercontent.com/media/molecularsets/moses/master/data/dataset_v1.csv" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"df = pd.read_csv('dataset_v1.csv')\n", | ||
"df = df[df['SPLIT'] == 'train']\n", | ||
"df['plogP'] = df['SMILES'].apply(penalized_logP)\n", | ||
"df.to_csv('train_plogp_plogpm.csv', index=None)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"enc = gentrl.RNNEncoder(latent_size=50)\n", | ||
"dec = gentrl.DilConvDecoder(latent_input_size=50)\n", | ||
"model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)\n", | ||
"model.cuda();" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"md = gentrl.MolecularDataset(sources=[\n", | ||
" {'path':'train_plogp_plogpm.csv',\n", | ||
" 'smiles': 'SMILES',\n", | ||
" 'prob': 1,\n", | ||
" 'plogP' : 'plogP',\n", | ||
" }], \n", | ||
" props=['plogP'])\n", | ||
"\n", | ||
"from torch.utils.data import DataLoader\n", | ||
"train_loader = DataLoader(md, batch_size=50, shuffle=True, num_workers=1, drop_last=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model.train_as_vaelp(train_loader, lr=1e-4)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"! mkdir -p saved_gentrl" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model.save('./saved_gentrl/')" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.7" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import gentrl\n", | ||
"import torch\n", | ||
"from rdkit.Chem import Draw\n", | ||
"from moses.metrics import mol_passes_filters, QED, SA, logP\n", | ||
"from moses.metrics.utils import get_n_rings, get_mol\n", | ||
"\n", | ||
"torch.cuda.set_device(0)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"enc = gentrl.RNNEncoder(latent_size=50)\n", | ||
"dec = gentrl.DilConvDecoder(latent_input_size=50)\n", | ||
"model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)\n", | ||
"model.cuda();" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model.load('saved_gentrl_after_rl/')\n", | ||
"model.cuda();" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def get_num_rings_6(mol):\n", | ||
" r = mol.GetRingInfo()\n", | ||
" return len([x for x in r.AtomRings() if len(x) > 6])\n", | ||
"\n", | ||
"\n", | ||
"def penalized_logP(mol_or_smiles, masked=True, default=-5):\n", | ||
" mol = get_mol(mol_or_smiles)\n", | ||
" if mol is None:\n", | ||
" return default\n", | ||
" reward = logP(mol) - SA(mol) - get_num_rings_6(mol)\n", | ||
" if masked and not mol_passes_filters(mol):\n", | ||
" return default\n", | ||
" return reward" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"generated = []\n", | ||
"\n", | ||
"while len(generated) < 1000:\n", | ||
" sampled = model.sample(100)\n", | ||
" sampled_valid = [s for s in sampled if get_mol(s)]\n", | ||
" \n", | ||
" generated += sampled_valid" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"Draw.MolsToGridImage([get_mol(s) for s in sampled_valid], \n", | ||
" legends=[str(penalized_logP(s)) for s in sampled_valid])" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.7" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import gentrl\n", | ||
"import torch\n", | ||
"torch.cuda.set_device(0)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"enc = gentrl.RNNEncoder(latent_size=50)\n", | ||
"dec = gentrl.DilConvDecoder(latent_input_size=50)\n", | ||
"model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)\n", | ||
"model.cuda();" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model.load('saved_gentrl/')\n", | ||
"model.cuda();" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from moses.metrics import mol_passes_filters, QED, SA, logP\n", | ||
"from moses.metrics.utils import get_n_rings, get_mol\n", | ||
"\n", | ||
"from moses.utils import disable_rdkit_log\n", | ||
"disable_rdkit_log()\n", | ||
"\n", | ||
"def get_num_rings_6(mol):\n", | ||
" r = mol.GetRingInfo()\n", | ||
" return len([x for x in r.AtomRings() if len(x) > 6])\n", | ||
"\n", | ||
"\n", | ||
"def penalized_logP(mol_or_smiles, masked=False, default=-5):\n", | ||
" mol = get_mol(mol_or_smiles)\n", | ||
" if mol is None:\n", | ||
" return default\n", | ||
" reward = logP(mol) - SA(mol) - get_num_rings_6(mol)\n", | ||
" if masked and not mol_passes_filters(mol):\n", | ||
" return default\n", | ||
" return reward" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model.train_as_rl(penalized_logP)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"! mkdir -p saved_gentrl_after_rl" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model.save('./saved_gentrl_after_rl/')" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .encoder import RNNEncoder | ||
from .decoder import DilConvDecoder | ||
from .gentrl import GENTRL | ||
from .dataloader import MolecularDataset | ||
|
||
|
||
__all__ = ['RNNEncoder', 'DilConvDecoder', 'GENTRL', 'MolecularDataset'] |
Oops, something went wrong.