Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add trvae Perturbation prediction notebook #243

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
remove functions from notebook
  • Loading branch information
chelseabright96 committed Jun 11, 2024
commit 085a48b4d1374c64721dde8e2a9add31c660b9e6
283 changes: 10 additions & 273 deletions notebooks/trvae_perturbation_prediction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -20,21 +20,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:root:In order to use the mouse gastrulation seqFISH datsets, please install squidpy (see https://github.com/scverse/squidpy).\n",
"WARNING:root:In order to use sagenet models, please install pytorch geometric (see https://pytorch-geometric.readthedocs.io) and \n",
" captum (see https://github.com/pytorch/captum).\n",
"WARNING:root:mvTCR is not installed. To use mvTCR models, please install it first using \"pip install mvtcr\"\n",
"WARNING:root:multigrate is not installed. To use multigrate models, please install it first using \"pip install multigrate\".\n"
]
}
],
"outputs": [],
"source": [
"import scanpy as sc\n",
"import torch\n",
Expand All @@ -43,11 +31,7 @@
"import gdown\n",
"from scipy.sparse import issparse\n",
"import numpy as np\n",
"import seaborn as sns\n",
"import pandas as pd\n",
"from matplotlib import pyplot\n",
"from scipy import stats\n",
"# from scarches.models.trvae._utils import reg_mean_plot, reg_var_plot"
"from scarches.models.trvae._utils import reg_mean_plot, reg_var_plot"
]
},
{
Expand All @@ -71,7 +55,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -252,7 +236,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -263,7 +247,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -286,7 +270,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -452,6 +436,7 @@
" source_adata = adata[adata.obs[condition_key] == source_cond]\n",
"\n",
" from scarches.dataset.trvae._utils import label_encoder\n",
"\n",
" encoder_labels = label_encoder(source_adata, model.model.condition_encoder, condition_key)\n",
" decoder_labels = np.zeros_like(encoder_labels) + model.model.condition_encoder[target_cond]\n",
"\n",
Expand Down Expand Up @@ -668,254 +653,6 @@
"Compare trVAE-predicted and real infected Tuft cells mean expression with the top 10 differentially expressed genes highlighted in red"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def reg_mean_plot(\n",
" adata,\n",
" condition_key,\n",
" target_condition,\n",
" labels,\n",
" path_to_save=\"./reg_mean.pdf\",\n",
" save=False,\n",
" gene_list=None,\n",
" show=False,\n",
" top_100_genes=None,\n",
" verbose=False,\n",
" legend=True,\n",
" title=None,\n",
" x_coeff=0.30,\n",
" y_coeff=0.8,\n",
" fontsize=14,\n",
" **kwargs,\n",
"):\n",
" \"\"\"\n",
" Plots mean matching figure for a set of specific genes.\n",
"\n",
" Parameters\n",
" ----------\n",
" adata: `~anndata.AnnData`\n",
" AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the\n",
" AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,\n",
" corresponding to batch and cell type metadata, respectively.\n",
" axis_keys: dict\n",
" Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:\n",
" `{\"x\": \"Key for x-axis\", \"y\": \"Key for y-axis\"}`.\n",
" labels: dict\n",
" Dictionary of axes labels of the form `{\"x\": \"x-axis-name\", \"y\": \"y-axis name\"}`.\n",
" path_to_save: basestring\n",
" path to save the plot.\n",
" save: boolean\n",
" Specify if the plot should be saved or not.\n",
" gene_list: list\n",
" list of gene names to be plotted.\n",
" show: bool\n",
" if `True`: will show to the plot after saving it.\n",
"\n",
" \"\"\"\n",
"\n",
" sns.set_theme()\n",
" sns.set_theme(color_codes=True)\n",
"\n",
" axis_keys = {\"x\":\"other\", \"y\":target_condition}\n",
"\n",
" diff_genes = top_100_genes\n",
" target_cd = adata[adata.obs[condition_key] == target_condition]\n",
" other_cd = adata[adata.obs[condition_key] != target_condition]\n",
" if diff_genes is not None:\n",
" if hasattr(diff_genes, \"tolist\"):\n",
" diff_genes = diff_genes.tolist()\n",
" adata_diff = adata[:, diff_genes]\n",
" target_diff = adata_diff[adata_diff.obs[condition_key] == target_condition]\n",
" other_diff = adata_diff[adata_diff.obs[condition_key] != target_condition]\n",
" x_diff = np.asarray(np.mean(target_diff.X, axis=0)).ravel()\n",
" y_diff = np.asarray(np.mean(other_diff.X, axis=0)).ravel()\n",
" m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(\n",
" x_diff, y_diff\n",
" )\n",
" if verbose:\n",
" print(\"top_100 DEGs mean: \", r_value_diff**2)\n",
" x = np.asarray(np.mean(other_cd.X, axis=0)).ravel()\n",
" y = np.asarray(np.mean(target_cd.X, axis=0)).ravel()\n",
" m, b, r_value, p_value, std_err = stats.linregress(x, y)\n",
" if verbose: \n",
" print(\"All genes mean: \", r_value**2)\n",
" df = pd.DataFrame({axis_keys[\"x\"]: x, axis_keys[\"y\"]: y})\n",
" ax = sns.regplot(x=axis_keys[\"x\"], y=axis_keys[\"y\"], data=df)\n",
" ax.tick_params(labelsize=fontsize)\n",
" if \"range\" in kwargs:\n",
" start, stop, step = kwargs.get(\"range\")\n",
" ax.set_xticks(np.arange(start, stop, step))\n",
" ax.set_yticks(np.arange(start, stop, step))\n",
" ax.set_xlabel(labels[\"x\"], fontsize=fontsize)\n",
" ax.set_ylabel(labels[\"y\"], fontsize=fontsize)\n",
" if gene_list is not None:\n",
" texts = []\n",
" for i in gene_list:\n",
" j = adata.var_names.tolist().index(i)\n",
" x_bar = x[j]\n",
" y_bar = y[j]\n",
" texts.append(pyplot.text(x_bar, y_bar, i, fontsize=11, color=\"black\"))\n",
" pyplot.plot(x_bar, y_bar, \"o\", color=\"red\", markersize=5)\n",
"\n",
" if legend:\n",
" pyplot.legend(loc=\"center left\", bbox_to_anchor=(1, 0.5))\n",
" if title is None:\n",
" pyplot.title(\"\", fontsize=fontsize)\n",
" else:\n",
" pyplot.title(title, fontsize=fontsize)\n",
" ax.text(\n",
" max(x) - max(x) * x_coeff,\n",
" max(y) - y_coeff * max(y),\n",
" r\"$\\mathrm{R^2_{\\mathrm{\\mathsf{all\\ genes}}}}$= \" + f\"{r_value ** 2:.2f}\",\n",
" fontsize=kwargs.get(\"textsize\", fontsize),\n",
" )\n",
" if diff_genes is not None:\n",
" ax.text(\n",
" max(x) - max(x) * x_coeff,\n",
" max(y) - (y_coeff + 0.15) * max(y),\n",
" r\"$\\mathrm{R^2_{\\mathrm{\\mathsf{top\\ 100\\ DEGs}}}}$= \"\n",
" + f\"{r_value_diff ** 2:.2f}\",\n",
" fontsize=kwargs.get(\"textsize\", fontsize),\n",
" )\n",
" if save:\n",
" pyplot.savefig(f\"{path_to_save}\", bbox_inches=\"tight\", dpi=100)\n",
" if show:\n",
" pyplot.show()\n",
" pyplot.close()\n",
" if diff_genes is not None:\n",
" return r_value**2, r_value_diff**2\n",
" else:\n",
" return r_value**2"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"def reg_var_plot(\n",
" adata,\n",
" condition_key,\n",
" target_condition,\n",
" labels,\n",
" path_to_save=\"./reg_var.pdf\",\n",
" save=False,\n",
" gene_list=None,\n",
" show=False,\n",
" top_100_genes=None,\n",
" verbose=False,\n",
" legend=True,\n",
" title=None,\n",
" x_coeff=0.30,\n",
" y_coeff=0.8,\n",
" fontsize=14,\n",
" **kwargs,\n",
"):\n",
" \"\"\"\n",
" Plots variance matching figure for a set of specific genes.\n",
"\n",
" Parameters\n",
" ----------\n",
" adata: `~anndata.AnnData`\n",
" AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the\n",
" AnnData object used to initialize the model. Must have been setup with `batch_key` and `labels_key`,\n",
" corresponding to batch and cell type metadata, respectively.\n",
" axis_keys: dict\n",
" Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:\n",
" `{\"x\": \"Key for x-axis\", \"y\": \"Key for y-axis\"}`.\n",
" labels: dict\n",
" Dictionary of axes labels of the form `{\"x\": \"x-axis-name\", \"y\": \"y-axis name\"}`.\n",
" path_to_save: basestring\n",
" path to save the plot.\n",
" save: boolean\n",
" Specify if the plot should be saved or not.\n",
" gene_list: list\n",
" list of gene names to be plotted.\n",
" show: bool\n",
" if `True`: will show to the plot after saving it.\n",
"\n",
" \"\"\"\n",
"\n",
" sns.set_theme()\n",
" sns.set_theme(color_codes=True)\n",
"\n",
" axis_keys = {\"x\":\"other\", \"y\":target_condition}\n",
"\n",
" diff_genes = top_100_genes\n",
" target_cd = adata[adata.obs[condition_key] == target_condition]\n",
" other_cd = adata[adata.obs[condition_key] != target_condition]\n",
" if diff_genes is not None:\n",
" if hasattr(diff_genes, \"tolist\"):\n",
" diff_genes = diff_genes.tolist()\n",
" adata_diff = adata[:, diff_genes]\n",
" target_diff = adata_diff[adata_diff.obs[condition_key] == target_condition]\n",
" other_diff = adata_diff[adata_diff.obs[condition_key] != target_condition]\n",
" x_diff = np.asarray(np.var(target_diff.X, axis=0)).ravel()\n",
" y_diff = np.asarray(np.var(other_diff.X, axis=0)).ravel()\n",
" m, b, r_value_diff, p_value_diff, std_err_diff = stats.linregress(\n",
" x_diff, y_diff\n",
" )\n",
" if verbose:\n",
" print(\"top_100 DEGs var: \", r_value_diff**2)\n",
" x = np.asarray(np.var(other_cd.X, axis=0)).ravel()\n",
" y = np.asarray(np.var(target_cd.X, axis=0)).ravel()\n",
" m, b, r_value, p_value, std_err = stats.linregress(x, y)\n",
" if verbose: \n",
" print(\"All genes var: \", r_value**2)\n",
" df = pd.DataFrame({axis_keys[\"x\"]: x, axis_keys[\"y\"]: y})\n",
" ax = sns.regplot(x=axis_keys[\"x\"], y=axis_keys[\"y\"], data=df)\n",
" ax.tick_params(labelsize=fontsize)\n",
" if \"range\" in kwargs:\n",
" start, stop, step = kwargs.get(\"range\")\n",
" ax.set_xticks(np.arange(start, stop, step))\n",
" ax.set_yticks(np.arange(start, stop, step))\n",
" ax.set_xlabel(labels[\"x\"], fontsize=fontsize)\n",
" ax.set_ylabel(labels[\"y\"], fontsize=fontsize)\n",
" if gene_list is not None:\n",
" texts = []\n",
" for i in gene_list:\n",
" j = adata.var_names.tolist().index(i)\n",
" x_bar = x[j]\n",
" y_bar = y[j]\n",
" texts.append(pyplot.text(x_bar, y_bar, i, fontsize=11, color=\"black\"))\n",
" pyplot.plot(x_bar, y_bar, \"o\", color=\"red\", markersize=5)\n",
"\n",
" if legend:\n",
" pyplot.legend(loc=\"center left\", bbox_to_anchor=(1, 0.5))\n",
" if title is None:\n",
" pyplot.title(\"\", fontsize=fontsize)\n",
" else:\n",
" pyplot.title(title, fontsize=fontsize)\n",
" ax.text(\n",
" max(x) - max(x) * x_coeff,\n",
" max(y) - y_coeff * max(y),\n",
" r\"$\\mathrm{R^2_{\\mathrm{\\mathsf{all\\ genes}}}}$= \" + f\"{r_value ** 2:.2f}\",\n",
" fontsize=kwargs.get(\"textsize\", fontsize),\n",
" )\n",
" if diff_genes is not None:\n",
" ax.text(\n",
" max(x) - max(x) * x_coeff,\n",
" max(y) - (y_coeff + 0.15) * max(y),\n",
" r\"$\\mathrm{R^2_{\\mathrm{\\mathsf{top\\ 100\\ DEGs}}}}$= \"\n",
" + f\"{r_value_diff ** 2:.2f}\",\n",
" fontsize=kwargs.get(\"textsize\", fontsize),\n",
" )\n",
" if save:\n",
" pyplot.savefig(f\"{path_to_save}\", bbox_inches=\"tight\", dpi=100)\n",
" if show:\n",
" pyplot.show()\n",
" pyplot.close()\n",
" if diff_genes is not None:\n",
" return r_value**2, r_value_diff**2\n",
" else:\n",
" return r_value**2\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
Expand Down Expand Up @@ -999,7 +736,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.9.19"
}
},
"nbformat": 4,
Expand Down