Skip to content

Commit

Permalink
pp.highly_variable_genes handles subset and inplace consistently (scv…
Browse files Browse the repository at this point in the history
…erse#2757)

Co-authored-by: Philipp A <[email protected]>
  • Loading branch information
eroell and flying-sheep authored Dec 4, 2023
1 parent bc349b9 commit 822b151
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/1.9.7.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
- Fix handling of numpy array palettes (e.g. after write-read cycle) {pr}`2734` {smaller}`P Angerer`
- Specify correct version of `matplotlib` dependency {pr}`2733` {smaller}`P Fisher`
- Fix {func}`scanpy.pl.violin` usage of `seaborn.catplot` {pr}`2739` {smaller}`E Roellin`
- Fix {func}`scanpy.pp.highly_variable_genes` to handle the combinations of `inplace` and `subset` consistently {pr}`2757` {smaller}`E Roellin`
12 changes: 10 additions & 2 deletions scanpy/preprocessing/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _highly_variable_genes_seurat_v3(
df["highly_variable"] = False
df.loc[sorted_index[: int(n_top_genes)], "highly_variable"] = True

if inplace or subset:
if inplace:
adata.uns["hvg"] = {"flavor": "seurat_v3"}
logg.hint(
"added\n"
Expand All @@ -172,6 +172,9 @@ def _highly_variable_genes_seurat_v3(
else:
if batch_key is None:
df = df.drop(["highly_variable_nbatches"], axis=1)
if subset:
df = df.iloc[df.highly_variable.values, :]

return df


Expand Down Expand Up @@ -544,7 +547,7 @@ def highly_variable_genes(

logg.info(" finished", time=start)

if inplace or subset:
if inplace:
adata.uns["hvg"] = {"flavor": flavor}
logg.hint(
"added\n"
Expand All @@ -559,6 +562,7 @@ def highly_variable_genes(
adata.var["dispersions_norm"] = df["dispersions_norm"].values.astype(
"float32", copy=False
)

if batch_key is not None:
adata.var["highly_variable_nbatches"] = df[
"highly_variable_nbatches"
Expand All @@ -568,5 +572,9 @@ def highly_variable_genes(
].values
if subset:
adata._inplace_subset_var(df["highly_variable"].values)

else:
if subset:
df = df.iloc[df.highly_variable.values, :]

return df
35 changes: 35 additions & 0 deletions scanpy/tests/test_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,3 +516,38 @@ def test_cellranger_n_top_genes_warning():
match="`n_top_genes` > number of normalized dispersions, returning all genes with normalized dispersions.",
):
sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="cell_ranger")


@pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"])
@pytest.mark.parametrize("subset", [True, False])
@pytest.mark.parametrize("inplace", [True, False])
def test_highly_variable_genes_subset_inplace_consistency(
flavor,
subset,
inplace,
):
adata = sc.datasets.blobs(n_observations=20, n_variables=80, random_state=0)
adata.X = np.abs(adata.X).astype(int)

if flavor == "seurat" or flavor == "cell_ranger":
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

elif flavor == "seurat_v3":
pass

else:
raise ValueError(f"Unknown flavor {flavor}")

n_genes = adata.shape[1]

output_df = sc.pp.highly_variable_genes(
adata,
flavor=flavor,
n_top_genes=15,
subset=subset,
inplace=inplace,
)

assert (output_df is None) == inplace
assert len(adata.var if inplace else output_df) == (15 if subset else n_genes)

0 comments on commit 822b151

Please sign in to comment.