Skip to content

Commit

Permalink
BUG fix size factor fitting when zero in each gene (owkin#98)
Browse files Browse the repository at this point in the history
* docs: correct type hint for vec_nb_nll

* feat: implement mean absolute deviation

* feat: implement iterative size factor fitting method

* docs: remove undefined reference

* test: add test for iterative size factors

* fix: correct nb_nll when several dispersions are provided

* fix: use Powell minimization for iterative size factors

* test: add correctness test for iterative size factors

* refactor: move the definition of the size factor objective function outside of the loop

* typo

Co-authored-by: Maria Telenczuk <[email protected]>

* docs: improve fit_size_factors() docstring

* refactor: raise a warning when each gene has a zero instead of a print statement

* test: catch RuntimeWarning when each gene has a sample with zero counts

---------

Co-authored-by: Maria Telenczuk <[email protected]>
  • Loading branch information
BorisMuzellec and maikia authored Mar 24, 2023
1 parent 1974042 commit 67a9f02
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 24 deletions.
122 changes: 113 additions & 9 deletions pydeseq2/dds.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import time
import warnings
from typing import List
from typing import Literal
from typing import Optional
from typing import Union
from typing import cast
Expand All @@ -12,9 +13,9 @@
from joblib import Parallel # type: ignore
from joblib import delayed
from joblib import parallel_backend
from scipy.optimize import minimize
from scipy.special import polygamma # type: ignore
from scipy.stats import f # type: ignore
from scipy.stats import norm
from statsmodels.tools.sm_exceptions import DomainWarning # type: ignore

from pydeseq2.preprocessing import deseq2_norm
Expand All @@ -26,6 +27,8 @@
from pydeseq2.utils import fit_rough_dispersions
from pydeseq2.utils import get_num_processes
from pydeseq2.utils import irls_solver
from pydeseq2.utils import mean_absolute_deviation
from pydeseq2.utils import nb_nll
from pydeseq2.utils import robust_method_of_moments_disp
from pydeseq2.utils import test_valid_counts
from pydeseq2.utils import trimmed_mean
Expand Down Expand Up @@ -288,14 +291,36 @@ def deseq2(self) -> None:
# for genes that had outliers replaced
self.refit()

def fit_size_factors(self) -> None:
def fit_size_factors(
self, fit_type: Literal["ratio", "iterative"] = "ratio"
) -> None:
"""Fit sample-wise deseq2 normalization (size) factors.
Uses the median-of-ratios method: see :func:`pydeseq2.preprocessing.deseq2_norm`.
Uses the median-of-ratios method: see :func:`pydeseq2.preprocessing.deseq2_norm`,
unless each gene has at least one sample with zero read counts, in which case it
switches to the ``iterative`` method.
Parameters
----------
fit_type : str
The normalization method to use (default: ``"ratio"``).
"""
print("Fitting size factors...")
start = time.time()
self.layers["normed_counts"], self.obsm["size_factors"] = deseq2_norm(self.X)
if fit_type == "iterative":
self._fit_iterate_size_factors()
# Test whether it is possible to use median-of-ratios.
elif (self.X == 0).any(0).all():
# There is at least a zero for each gene
warnings.warn(
"Every gene contains at least one zero, "
"cannot compute log geometric means. Switching to iterative mode.",
RuntimeWarning,
stacklevel=2,
)
self._fit_iterate_size_factors()
else:
self.layers["normed_counts"], self.obsm["size_factors"] = deseq2_norm(self.X)
end = time.time()
print(f"... done in {end - start:.2f} seconds.\n")

Expand Down Expand Up @@ -484,7 +509,7 @@ def fit_dispersion_prior(self) -> None:
"""

# Check that the dispersion trend curve was fitted. If not, fit it.
if "trend_coeffs" not in self.uns:
if "fitted_dispersions" not in self.varm:
self.fit_dispersion_trend()

# Exclude genes with all zeroes
Expand All @@ -502,9 +527,9 @@ def fit_dispersion_prior(self) -> None:
100 * self.min_disp
)

self.uns["_squared_logres"] = np.median(
np.abs(disp_residuals[above_min_disp])
) ** 2 / norm.ppf(0.75)
self.uns["_squared_logres"] = (
mean_absolute_deviation(disp_residuals[above_min_disp]) ** 2
)
self.uns["prior_disp_var"] = np.maximum(
self.uns["_squared_logres"] - polygamma(1, (num_genes - num_vars) / 2),
0.25,
Expand Down Expand Up @@ -559,7 +584,7 @@ def fit_MAP_dispersions(self) -> None:

# Filter outlier genes for which we won't apply shrinkage
self.varm["dispersions"] = self.varm["MAP_dispersions"].copy()
self.varm["_outlier_genes"] = np.log(self.varm["dispersions"]) > np.log(
self.varm["_outlier_genes"] = np.log(self.varm["genewise_dispersions"]) > np.log(
self.varm["fitted_dispersions"]
) + 2 * np.sqrt(self.uns["_squared_logres"])
self.varm["dispersions"][self.varm["_outlier_genes"]] = self.varm[
Expand Down Expand Up @@ -854,3 +879,82 @@ def _refit_without_outliers(
new_all_zeroes.sum()
)
self[:, self.new_all_zeroes_genes].varm["LFC"] = np.zeros(new_all_zeroes.sum())

def _fit_iterate_size_factors(self, niter: int = 10, quant: float = 0.95) -> None:
"""
Fit size factors using the ``iterative`` method.
Used when each gene has at least one zero.
Parameters
----------
niter : int
Maximum number of iterations to perform (default: ``10``).
quant : float
Quantile value at which negative likelihood is cut in the optimization
(default: ``0.95``).
"""

self.obsm["size_factors"] = np.ones(self.n_obs)

# Reduce the design matrix to an intercept and reconstruct at the end
self.obsm["design_matrix_buffer"] = self.obsm["design_matrix"].copy()
self.obsm["design_matrix"] = pd.DataFrame(
1, index=self.obs_names, columns=[["intercept"]]
)

# Fit size factors using MLE
def objective(p):
sf = np.exp(p - np.mean(p))
nll = nb_nll(
counts=self[:, self.non_zero_genes].X,
mu=self[:, self.non_zero_genes].layers["_mu_hat"]
/ self.obsm["size_factors"][:, None]
* sf[:, None],
alpha=self[:, self.non_zero_genes].varm["dispersions"],
)
# Take out the lowest likelihoods (highest neg) from the sum
return np.sum(nll[nll < np.quantile(nll, quant)])

for i in range(niter):
# Estimate dispersions based on current size factors
self.fit_genewise_dispersions()

# Use a mean trend curve
use_for_mean_genes = self.var_names[
(self.varm["genewise_dispersions"] > 10 * self.min_disp)
& self.varm["non_zero"]
]

mean_disp = trimmed_mean(
self[:, use_for_mean_genes].varm["genewise_dispersions"], trim=0.001
)
self.varm["fitted_dispersions"] = np.ones(self.n_vars) * mean_disp
self.fit_dispersion_prior()
self.fit_MAP_dispersions()
old_sf = self.obsm["size_factors"].copy()

# Fit size factors using MLE
res = minimize(objective, np.log(old_sf), method="Powell")

self.obsm["size_factors"] = np.exp(res.x - np.mean(res.x))

if not res.success:
print("A size factor fitting iteration failed.")
break

if (i > 1) and np.sum(
(np.log(old_sf) - np.log(self.obsm["size_factors"])) ** 2
) < 1e-4:
break
elif i == niter - 1:
print("Iterative size factor fitting did not converge.")

# Restore the design matrix and free buffer
self.obsm["design_matrix"] = self.obsm["design_matrix_buffer"].copy()
del self.obsm["design_matrix_buffer"]

# Store normalized counts
self.layers["normed_counts"] = self.X / self.obsm["size_factors"][:, None]
4 changes: 2 additions & 2 deletions pydeseq2/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pydeseq2.utils


def vec_nb_nll(counts: np.ndarray, mu: np.ndarray, alpha: float) -> np.ndarray:
def vec_nb_nll(counts: np.ndarray, mu: np.ndarray, alpha: np.ndarray) -> np.ndarray:
"""Return the negative log-likelihood of a negative binomial.
Vectorized version.
Expand All @@ -19,7 +19,7 @@ def vec_nb_nll(counts: np.ndarray, mu: np.ndarray, alpha: float) -> np.ndarray:
mu : ndarray
Mean of the distribution.
alpha : float
alpha : ndarray
Dispersion of the distribution, s.t. the variance is
:math:`\\mu + \\alpha * \\mu^2`.
Expand Down
54 changes: 41 additions & 13 deletions pydeseq2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,11 @@ def dispersion_trend(
return coeffs[0] + coeffs[1] / normed_mean


def nb_nll(counts: np.ndarray, mu: np.ndarray, alpha: float) -> float:
def nb_nll(
counts: np.ndarray, mu: np.ndarray, alpha: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
"""Negative log-likelihood of a negative binomial of parameters ``mu`` and ``alpha``.
Unvectorized version.
Mathematically, if ``counts`` is a vector of counting entries :math:`y_i`
then the likelihood of each entry :math:`y_i` to be drawn from a negative
binomial :math:`NB(\\mu, \\alpha)` is [1]
Expand Down Expand Up @@ -302,13 +302,13 @@ def nb_nll(counts: np.ndarray, mu: np.ndarray, alpha: float) -> float:
mu : ndarray
Mean of the distribution :math:`\\mu`.
alpha : float
alpha : float or ndarray
Dispersion of the distribution :math:`\\alpha`,
s.t. the variance is :math:`\\mu + \\alpha \\mu^2`.
Returns
-------
float
float or ndarray
Negative log likelihood of the observations counts
following :math:`NB(\\mu, \\alpha)`.
Expand All @@ -320,14 +320,22 @@ def nb_nll(counts: np.ndarray, mu: np.ndarray, alpha: float) -> float:
n = len(counts)
alpha_neg1 = 1 / alpha
logbinom = gammaln(counts + alpha_neg1) - gammaln(counts + 1) - gammaln(alpha_neg1)
return (
n * alpha_neg1 * np.log(alpha)
+ (
-logbinom
+ (counts + alpha_neg1) * np.log(alpha_neg1 + mu)
- counts * np.log(mu)
).sum()
)
if hasattr(alpha, "__len__") and len(alpha) > 1:
return (
alpha_neg1 * np.log(alpha)
- logbinom
+ (counts + alpha_neg1) * np.log(mu + alpha_neg1)
- (counts * np.log(mu))
).sum(0)
else:
return (
n * alpha_neg1 * np.log(alpha)
+ (
-logbinom
+ (counts + alpha_neg1) * np.log(alpha_neg1 + mu)
- counts * np.log(mu)
).sum()
)


def dnb_nll(counts: np.ndarray, mu: np.ndarray, alpha: float) -> float:
Expand Down Expand Up @@ -1214,3 +1222,23 @@ def nbinomFn(
).sum(0)

return prior - nll


def mean_absolute_deviation(x: np.ndarray) -> float:
"""
Compute a scaled estimator of the mean absolute deviation.
Used in :meth:`pydeseq2.dds.DeseqDataSet.fit_dispersion_prior()`.
Parameters
----------
x : ndarray
1D array whose MAD to compute.
Returns
-------
float
Mean absolute deviation estimator.
"""
center = np.median(x)
return np.median(np.abs(x - center)) / norm.ppf(0.75)
101 changes: 101 additions & 0 deletions tests/data/single_factor/r_iterative_size_factors.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"","x"
"1",0.915421087077681
"2",1.31102331795221
"3",0.924544213405754
"4",0.801555451850839
"5",0.710694444355712
"6",0.917422160590354
"7",0.961734817171922
"8",1.22268646315809
"9",0.751715285059725
"10",1.04604633383823
"11",0.683500883174763
"12",1.16532627640849
"13",0.939244912412512
"14",1.12075200944084
"15",0.731882201969824
"16",1.05980206820024
"17",0.97077374574086
"18",0.863424019395309
"19",1.22190797224712
"20",1.17200283874692
"21",1.02223686048299
"22",0.798539539602114
"23",0.837701291627756
"24",1.43278452449518
"25",1.10795891591146
"26",1.04680955987751
"27",0.98660267590864
"28",0.994562698605772
"29",0.937632673373463
"30",0.911552273465719
"31",1.13242766493449
"32",1.29763606154318
"33",1.18554898543046
"34",1.03113856840632
"35",1.01051188705132
"36",0.87935239871871
"37",1.47990655974392
"38",1.07541995248229
"39",1.16740559705306
"40",1.0216994869537
"41",0.609344035608844
"42",1.00959522557959
"43",0.966165497570495
"44",1.08582288944723
"45",0.827310816341124
"46",0.811169482064657
"47",0.997085886648129
"48",1.21173821426595
"49",0.947758268767542
"50",0.857615357715104
"51",0.809340649308735
"52",1.44269630031412
"53",0.972576809915263
"54",0.84792976175583
"55",0.692327244926196
"56",1.22285070330531
"57",1.1623542416027
"58",1.64987395544704
"59",1.06587092843215
"60",0.874530415226887
"61",0.930718120979676
"62",1.35768681883417
"63",0.729051760045237
"64",0.924290900430038
"65",0.892732547374792
"66",1.49873680518914
"67",1.01789327922605
"68",1.32116101959647
"69",1.20819418467418
"70",1.23266653080425
"71",0.772147576987224
"72",1.11893535651906
"73",1.03134035796004
"74",1.46823801195759
"75",0.842730736302201
"76",0.815794866580617
"77",1.14640743872935
"78",1.00136232160933
"79",1.08159997028843
"80",1.18245551056573
"81",1.0561566721125
"82",0.852730453186549
"83",0.860000197037354
"84",0.853830751373939
"85",0.58969508089454
"86",1.10089573023988
"87",0.790464773428743
"88",1.37631233809585
"89",0.864637438989214
"90",0.949299623039802
"91",0.863505562118642
"92",1.02614640509524
"93",0.901585518996127
"94",1.43245824373654
"95",1.07374116029894
"96",1.02247346221386
"97",0.743243647020171
"98",0.983103509398641
"99",1.28620165406526
"100",0.978569887650991
Loading

0 comments on commit 67a9f02

Please sign in to comment.