Skip to content

Commit

Permalink
GENTRL
Browse files Browse the repository at this point in the history
  • Loading branch information
danpol committed Sep 1, 2019
0 parents commit a0264b5
Show file tree
Hide file tree
Showing 14 changed files with 1,582 additions and 0 deletions.
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Generative Tensorial Reinforcement Learning (GENTRL)
Supporting Information for the paper _"Deep learning enables rapid identification of potent DDR1 kinase inhibitors"_.

The GENTRL model is a variational autoencoder with a rich prior distribution of the latent space. We used tensor decompositions to encode the relations between molecular structures and their properties and to learn on data with missing values. We train the model in two steps. First, we learn a mapping of a chemical space on the latent manifold by maximizing the evidence lower bound. We then freeze all the parameters except for the learnable prior and explore the chemical space to find molecules with a high reward.

![GENTRL](images/gentrl.png)


## Repository
In this repository, we provide an implementation of a GENTRL model with an example trained on a [MOSES](https://github.com/molecularsets/moses) dataset.

To run the training procedure,
1. [Install RDKit](https://www.rdkit.org/docs/Install.html) to process molecules
2. Install GENTRL model: `python setup.py install`
3. Install MOSES from the [repository](https://github.com/molecularsets/moses)
4. Run the [pretrain.ipynb](./examples/pretrain.ipynb) to train an autoencoder
5. Run the [train_rl.ipynb](./examples/train_rl.ipynb) to optimize a reward function
140 changes: 140 additions & 0 deletions examples/pretrain.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import gentrl\n",
"import torch\n",
"import pandas as pd\n",
"torch.cuda.set_device(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from moses.metrics import mol_passes_filters, QED, SA, logP\n",
"from moses.metrics.utils import get_n_rings, get_mol\n",
"\n",
"\n",
"def get_num_rings_6(mol):\n",
" r = mol.GetRingInfo()\n",
" return len([x for x in r.AtomRings() if len(x) > 6])\n",
"\n",
"\n",
"def penalized_logP(mol_or_smiles, masked=False, default=-5):\n",
" mol = get_mol(mol_or_smiles)\n",
" if mol is None:\n",
" return default\n",
" reward = logP(mol) - SA(mol) - get_num_rings_6(mol)\n",
" if masked and not mol_passes_filters(mol):\n",
" return default\n",
" return reward"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! wget https://media.githubusercontent.com/media/molecularsets/moses/master/data/dataset_v1.csv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('dataset_v1.csv')\n",
"df = df[df['SPLIT'] == 'train']\n",
"df['plogP'] = df['SMILES'].apply(penalized_logP)\n",
"df.to_csv('train_plogp_plogpm.csv', index=None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"enc = gentrl.RNNEncoder(latent_size=50)\n",
"dec = gentrl.DilConvDecoder(latent_input_size=50)\n",
"model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)\n",
"model.cuda();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"md = gentrl.MolecularDataset(sources=[\n",
" {'path':'train_plogp_plogpm.csv',\n",
" 'smiles': 'SMILES',\n",
" 'prob': 1,\n",
" 'plogP' : 'plogP',\n",
" }], \n",
" props=['plogP'])\n",
"\n",
"from torch.utils.data import DataLoader\n",
"train_loader = DataLoader(md, batch_size=50, shuffle=True, num_workers=1, drop_last=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.train_as_vaelp(train_loader, lr=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! mkdir -p saved_gentrl"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.save('./saved_gentrl/')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
108 changes: 108 additions & 0 deletions examples/sampling.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import gentrl\n",
"import torch\n",
"from rdkit.Chem import Draw\n",
"from moses.metrics import mol_passes_filters, QED, SA, logP\n",
"from moses.metrics.utils import get_n_rings, get_mol\n",
"\n",
"torch.cuda.set_device(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"enc = gentrl.RNNEncoder(latent_size=50)\n",
"dec = gentrl.DilConvDecoder(latent_input_size=50)\n",
"model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)\n",
"model.cuda();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.load('saved_gentrl_after_rl/')\n",
"model.cuda();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_num_rings_6(mol):\n",
" r = mol.GetRingInfo()\n",
" return len([x for x in r.AtomRings() if len(x) > 6])\n",
"\n",
"\n",
"def penalized_logP(mol_or_smiles, masked=True, default=-5):\n",
" mol = get_mol(mol_or_smiles)\n",
" if mol is None:\n",
" return default\n",
" reward = logP(mol) - SA(mol) - get_num_rings_6(mol)\n",
" if masked and not mol_passes_filters(mol):\n",
" return default\n",
" return reward"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"generated = []\n",
"\n",
"while len(generated) < 1000:\n",
" sampled = model.sample(100)\n",
" sampled_valid = [s for s in sampled if get_mol(s)]\n",
" \n",
" generated += sampled_valid"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Draw.MolsToGridImage([get_mol(s) for s in sampled_valid], \n",
" legends=[str(penalized_logP(s)) for s in sampled_valid])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
112 changes: 112 additions & 0 deletions examples/train_rl.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import gentrl\n",
"import torch\n",
"torch.cuda.set_device(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"enc = gentrl.RNNEncoder(latent_size=50)\n",
"dec = gentrl.DilConvDecoder(latent_input_size=50)\n",
"model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)\n",
"model.cuda();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.load('saved_gentrl/')\n",
"model.cuda();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from moses.metrics import mol_passes_filters, QED, SA, logP\n",
"from moses.metrics.utils import get_n_rings, get_mol\n",
"\n",
"from moses.utils import disable_rdkit_log\n",
"disable_rdkit_log()\n",
"\n",
"def get_num_rings_6(mol):\n",
" r = mol.GetRingInfo()\n",
" return len([x for x in r.AtomRings() if len(x) > 6])\n",
"\n",
"\n",
"def penalized_logP(mol_or_smiles, masked=False, default=-5):\n",
" mol = get_mol(mol_or_smiles)\n",
" if mol is None:\n",
" return default\n",
" reward = logP(mol) - SA(mol) - get_num_rings_6(mol)\n",
" if masked and not mol_passes_filters(mol):\n",
" return default\n",
" return reward"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.train_as_rl(penalized_logP)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! mkdir -p saved_gentrl_after_rl"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.save('./saved_gentrl_after_rl/')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
7 changes: 7 additions & 0 deletions gentrl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .encoder import RNNEncoder
from .decoder import DilConvDecoder
from .gentrl import GENTRL
from .dataloader import MolecularDataset


__all__ = ['RNNEncoder', 'DilConvDecoder', 'GENTRL', 'MolecularDataset']
Loading

0 comments on commit a0264b5

Please sign in to comment.