FEAT: Implement Likelihood-Ratio test #178

Draft: wants to merge 8 commits into main
Changes from 1 commit
implemented lrt test
PABannier committed Oct 17, 2023
commit 4a663d0e7a964863f89e3b5cffdcf8e6004fffd6
64 changes: 63 additions & 1 deletion pydeseq2/ds.py
@@ -3,6 +3,7 @@
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple

# import anndata as ad
import numpy as np
@@ -20,6 +21,7 @@
from pydeseq2.utils import get_num_processes
from pydeseq2.utils import make_MA_plot
from pydeseq2.utils import nbinomGLM
from pydeseq2.utils import lrt_test
from pydeseq2.utils import wald_test


@@ -372,7 +374,67 @@ def run_likelihood_ratio_test(self) -> None:

Get gene-wise p-values for gene over/under-expression.
"""
raise NotImplementedError

num_genes = self.dds.n_vars
num_vars = self.design_matrix.shape[1]

# XXX: Raise a warning if LFCs are shrunk.

def reduce(
design_matrix: np.ndarray, ridge_factor: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
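# Build the reduced (null) design by dropping the tested coefficient at contrast_idx.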
indices = np.full(design_matrix.shape[1], True, dtype=bool)
indices[self.contrast_idx] = False
return design_matrix[:, indices], ridge_factor[indices]

# Set regularization factors.
if self.prior_LFC_var is not None:
ridge_factor = np.diag(1 / self.prior_LFC_var**2)
else:
ridge_factor = np.diag(np.repeat(1e-6, num_vars))

design_matrix = self.design_matrix.values
LFCs = self.LFC.values

reduced_design_matrix, reduced_ridge_factor = reduce(design_matrix, ridge_factor)
self.dds.obsm["reduced_design_matrix"] = reduced_design_matrix

if not self.quiet:
print("Running LRT tests...", file=sys.stderr)
start = time.time()
with parallel_backend("loky", inner_max_num_threads=1):
res = Parallel(
n_jobs=self.n_processes,
verbose=self.joblib_verbosity,
batch_size=self.batch_size,
)(
delayed(lrt_test)(
counts=self.dds.X[:, i],
design_matrix=design_matrix,
reduced_design_matrix=reduced_design_matrix,
size_factors=self.dds.obsm["size_factors"],
disp=self.dds.varm["dispersions"][i],
lfc=LFCs[i],
min_mu=self.dds.min_mu,
ridge_factor=ridge_factor,
reduced_ridge_factor=reduced_ridge_factor,
beta_tol=self.dds.beta_tol,
)
for i in range(num_genes)
)
end = time.time()
if not self.quiet:
print(f"... done in {end-start:.2f} seconds.\n", file=sys.stderr)

pvals, stats = zip(*res)

self.p_values: pd.Series = pd.Series(pvals, index=self.dds.var_names)
self.statistics: pd.Series = pd.Series(stats, index=self.dds.var_names)

# Account for possible all_zeroes due to outlier refitting in DESeqDataSet
if self.dds.refit_cooks and self.dds.varm["replaced"].sum() > 0:
self.statistics.loc[self.dds.new_all_zeroes_genes] = 0.0
self.p_values.loc[self.dds.new_all_zeroes_genes] = 1.0

def lfc_shrink(self, coeff: Optional[str] = None) -> None:
"""LFC shrinkage with an apeGLM prior :cite:p:`DeseqStats-zhu2019heavy`.
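A minimal usage sketch of the new method, assuming dds is a DeseqDataSet on which deseq2() has already been run; "condition" and the levels "B" and "A" are illustrative placeholders, not part of this diff:

from pydeseq2.ds import DeseqStats

# Hypothetical contrast -- factor and level names are placeholders.
stat_res = DeseqStats(dds, contrast=["condition", "B", "A"])
stat_res.run_likelihood_ratio_test()

# Populated by this commit:
print(stat_res.p_values.head())    # chi-square (df=1) p-values, one per gene
print(stat_res.statistics.head())  # LRT statistics, 2 * (full_ll - reduced_ll)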
56 changes: 56 additions & 0 deletions pydeseq2/utils.py
@@ -17,6 +17,7 @@
from scipy.special import gammaln # type: ignore
from scipy.special import polygamma # type: ignore
from scipy.stats import norm # type: ignore
from scipy.stats import chi2 # type: ignore
from sklearn.linear_model import LinearRegression # type: ignore

import pydeseq2
@@ -979,6 +980,61 @@ def less_abs(lfc_null):
return wald_p_value, wald_statistic, wald_se


def lrt_test(
counts: np.ndarray,
design_matrix: np.ndarray,
reduced_design_matrix: np.ndarray,
size_factors: np.ndarray,
disp: float,
lfc: np.ndarray,
min_mu: float,
ridge_factor: np.ndarray,
reduced_ridge_factor: np.ndarray,
beta_tol: float,
) -> Tuple[float, float]:
"""Run likelihood ratio test for differential expression.

Compute likelihood ratio test statistics and p-values from
dispersion and LFC estimates.

Parameters
----------
counts : ndarray
Raw counts for a given gene.

design_matrix : ndarray
Design matrix of the full model.

reduced_design_matrix : ndarray
Design matrix of the reduced model, with the tested coefficient removed.

size_factors : ndarray
DESeq2 normalization factors.

disp : float
Gene-wise dispersion estimate.

lfc : ndarray
Log fold change estimate for the full model (in natural log scale).

min_mu : float
Lower bound on estimated means, for numerical stability.

ridge_factor : ndarray
Diagonal ridge regularization matrix for the full model.

reduced_ridge_factor : ndarray
Diagonal ridge regularization matrix for the reduced model.

beta_tol : float
Stopping criterion for the IRLS procedure fitting the reduced model.

Returns
-------
lrt_p_value : float
Estimated p-value.

lrt_statistic : float
LRT statistic.
"""
def reg_nb_nll(
beta: np.ndarray, design_matrix: np.ndarray, ridge_factor: np.ndarray
) -> float:
# Penalized negative log-likelihood of the NB GLM, evaluated at beta.
mu_ = np.maximum(size_factors * np.exp(design_matrix @ beta), min_mu)
val = nb_nll(counts, mu_, disp) + 0.5 * (ridge_factor @ beta**2).sum()
return -1.0 * val  # negate so the closure returns the penalized log-likelihood

beta_reduced, *_ = irls_solver(
counts=counts,
size_factors=size_factors,
design_matrix=reduced_design_matrix,
disp=disp,
min_mu=min_mu,
beta_tol=beta_tol,
)

reduced_ll = reg_nb_nll(beta_reduced, reduced_design_matrix, reduced_ridge_factor)
full_ll = reg_nb_nll(lfc, design_matrix, ridge_factor)

lrt_statistic = 2 * (full_ll - reduced_ll)
# df = 1 since contrast_idx is the only variable removed
lrt_p_value = chi2.sf(lrt_statistic, df=1)

return lrt_p_value, lrt_statistic


def fit_rough_dispersions(
normed_counts: np.ndarray, design_matrix: pd.DataFrame
) -> np.ndarray:
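The test itself follows the standard likelihood-ratio construction: twice the gap between the maximized (penalized) log-likelihoods of the full and reduced negative binomial models, referred to a chi-square distribution whose degrees of freedom equal the number of coefficients dropped (one here, since only the contrast column is removed). A self-contained sketch of that final step, with made-up log-likelihood values:

from scipy.stats import chi2

# Hypothetical penalized log-likelihoods for one gene (illustrative values only).
full_ll = -1201.3
reduced_ll = -1208.9

lrt_statistic = 2 * (full_ll - reduced_ll)   # 15.2
lrt_p_value = chi2.sf(lrt_statistic, df=1)   # survival function of chi2(1), roughly 9.7e-5

print(lrt_statistic, lrt_p_value)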