Skip to content

Commit

Permalink
replace squidpy with liana nn in MistyGeneric
Browse files Browse the repository at this point in the history
  • Loading branch information
dbdimitrov committed Jun 11, 2024
1 parent b8e7fcc commit 5139e64
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 40 deletions.
39 changes: 36 additions & 3 deletions docs/source/notebooks/bivariate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,8 @@
"text": [
"Using `.X`!\n",
"Using resource `consensus`.\n",
"100%|██████████| 100/100 [01:56<00:00, 1.17s/it]\n",
"100%|██████████| 100/100 [01:01<00:00, 1.64it/s]\n"
"100%|██████████| 100/100 [01:06<00:00, 1.51it/s]\n",
"100%|██████████| 100/100 [00:38<00:00, 2.60it/s]\n"
]
}
],
Expand All @@ -651,6 +651,39 @@
"Now that this is done, we can extract and explore the newly-created AnnData object that counts our local scores"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AnnData object with n_obs × n_vars = 4113 × 17703\n",
" obs: 'in_tissue', 'array_row', 'array_col', 'sample', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'mt_frac', 'celltype_niche', 'molecular_niche'\n",
" var: 'gene_ids', 'feature_types', 'genome', 'SYMBOL', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'mt', 'rps', 'mrp', 'rpl', 'duplicated'\n",
" uns: 'spatial', 'log1p', 'celltype_niche_colors'\n",
" obsm: 'compositions', 'mt', 'spatial', 'local_scores'\n",
" layers: 'counts'\n",
" obsp: 'spatial_connectivities'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"adata"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 10,
Expand Down Expand Up @@ -2094,7 +2127,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.8.17"
},
"orig_nbformat": 4
},
Expand Down
3 changes: 2 additions & 1 deletion liana/method/sp/_misty/_Misty.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def __init__(self,
obs:(pd.DataFrame | None)=None,
spatial_key:str=K.spatial_key,
enforce_obs:bool=True,
**kwargs):
**kwargs
):
"""
Construct a MistyData object from a dictionary of views (anndatas).
Expand Down
24 changes: 9 additions & 15 deletions liana/method/sp/_misty/_misty_constructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from liana.resource import select_resource
from liana.method._pipe_utils import prep_check_adata
from liana.method.sp._utils import _add_complexes_to_var
from liana._logging import _check_if_installed

def _make_view(adata, nz_threshold=0.1, add_obs=False, use_raw=False,
layer=None, connecitivity=None, spatial_key=None, verbose=False):
Expand Down Expand Up @@ -57,8 +56,7 @@ def genericMistyData(intra,
cutoff = 0.1,
add_juxta=True,
n_neighs = 6,
verbose=False,
**kwargs,
verbose=False
):

"""
Expand Down Expand Up @@ -96,14 +94,12 @@ def genericMistyData(intra,
cutoff : `float`, optional (default: 0.1)
The cutoff for the connectivity matrix.
add_juxta : `bool`, optional (default: True)
Whether to add the juxtaview. The juxtaview is constructed using `squidpy.gr.spatial_neighbors`,
and should represent the direct spatial neighbors of each cell/spot.
Whether to add the juxtaview. The juxtaview is constructed using only the nearest neighbors.
A bandwidth of 5 times the bandwidth of the paraview is used to ensure that the nearest neighbors fall within the radius.
n_neighs : `int`, optional (default: 6)
The number of neighbors to consider when constructing the juxtaview.
verbose : `bool`, optional (default: False)
Whether to print progress.
**kwargs : `dict`, optional
Additional arguments to pass to `squidpy.gr.spatial_neighbors`.
Returns
-------
Expand All @@ -121,14 +117,12 @@ def genericMistyData(intra,
extra = intra

if add_juxta:
sq = _check_if_installed('squidpy')
neighbors, _ = sq.gr.spatial_neighbors(adata=extra,
copy=True,
spatial_key=spatial_key,
set_diag=set_diag,
n_neighs=n_neighs,
**kwargs
)
neighbors = spatial_neighbors(extra,
bandwidth=bandwidth*5,
spatial_key=spatial_key,
max_neighbours=n_neighs,
set_diag=set_diag,
inplace=False)

views['juxta'] = _make_view(adata=extra, nz_threshold=nz_threshold,
use_raw=extra_use_raw, layer=extra_layer,
Expand Down
2 changes: 1 addition & 1 deletion liana/tests/test_bivar.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def test_large_adata():
)
lrdata = adata.obsm['local_scores']
np.testing.assert_almost_equal(lrdata.X.mean(), 0.00048977, decimal=4)
np.testing.assert_almost_equal(lrdata.var['morans'].mean(), 0.00030397394, decimal=4)
np.testing.assert_almost_equal(lrdata.var['morans'].mean(), 0.00012773558, decimal=4)


def test_wrong_interactions():
Expand Down
15 changes: 5 additions & 10 deletions liana/tests/test_misty.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def test_misty_para():
cutoff=0,
add_juxta=False,
set_diag=False,
seed=133
)
misty(model=RandomForestModel, bypass_intra=False, seed=42, n_estimators=11)
assert np.isin(list(misty.uns.keys()), ['target_metrics', 'interactions']).all()
Expand All @@ -37,9 +36,7 @@ def test_misty_bypass():
bandwidth=10,
add_juxta=True,
set_diag=True,
cutoff=0,
coord_type="generic",
delaunay=True)
cutoff=0)
misty(model=RandomForestModel, alphas=1, bypass_intra=True, seed=42, n_estimators=11)
assert np.isin(['juxta', 'para'], misty.uns['target_metrics'].columns).all()
assert ~np.isin(['intra'], misty.uns['target_metrics'].columns).all()
Expand All @@ -51,7 +48,7 @@ def test_misty_bypass():
assert interactions['importances'].sum().round(10) == 22.0
np.testing.assert_almost_equal(interactions[(interactions['target']=='ligC') &
(interactions['predictor']=='ligA')]['importances'].values,
np.array([0.0444664, 0.0551506]), decimal=3)
np.array([0.095, 0.07]), decimal=3)


def test_misty_groups():
Expand All @@ -60,8 +57,6 @@ def test_misty_groups():
add_juxta=True,
set_diag=False,
cutoff=0,
coord_type="generic",
delaunay=True
)
misty(model=RandomForestModel,
alphas=1,
Expand All @@ -82,7 +77,7 @@ def test_misty_groups():
# assert that there are self interactions = var_n * var_n
interactions = misty.uns['interactions']
self_interactions = interactions[(interactions['target']==interactions['predictor'])]
# 11 vars * 4 envs * 3 views = 132; NOTE: However, I drop NAs -> to be refactored...
# 11 vars * 4 envs * 3 views = 132; NOTE: However, I drop NAs
assert self_interactions.shape == (44, 5)
assert self_interactions[self_interactions['view']=='intra']['importances'].isna().all()

Expand Down Expand Up @@ -110,7 +105,7 @@ def test_linear_misty():

assert misty.uns['interactions'].shape == (330, 4)
actual = misty.uns['interactions']['importances'].values.mean()
np.testing.assert_almost_equal(actual, 0.4941761900911731, decimal=3)
np.testing.assert_almost_equal(actual, 0.5135328101662447, decimal=3)


def test_misty_mask():
Expand All @@ -126,7 +121,7 @@ def test_misty_mask():
np.testing.assert_almost_equal(misty.uns['target_metrics']['intra_R2'].mean(), 0.4248588250759459, decimal=3)

assert misty.uns['interactions'].shape == (330, 4)
np.testing.assert_almost_equal(misty.uns['interactions']['importances'].sum(), 141.05332654128952, decimal=0)
np.testing.assert_almost_equal(misty.uns['interactions']['importances'].sum(), 149.30560405771703, decimal=0)


def test_misty_custom():
Expand Down
9 changes: 4 additions & 5 deletions liana/utils/query_bandwidth.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import numpy as np
from sklearn.neighbors import BallTree
from plotnine import ggplot, aes, geom_line, geom_point, theme_bw, xlab, ylab, scale_y_continuous
from plotnine import ggplot, aes, geom_line, geom_point, theme_bw, xlab, ylab
from pandas import DataFrame

def query_bandwidth(coordinates: np.ndarray,
start: int = 0,
end: int = 500,
interval_n:int = 50,
interval_n: int = 50,
reference: np.ndarray = None
):
"""
Expand Down Expand Up @@ -49,13 +49,12 @@ def query_bandwidth(coordinates: np.ndarray,
num_neighbors = tree.query_radius(_reference, r=max_distance, count_only=True)

# calculate the average number of neighbors
avg_nn = np.mean(num_neighbors)
df.loc[n, 'neighbours'] = avg_nn
avg_nn = np.ceil(np.median(num_neighbors))
df.loc[n, 'neighbours'] = avg_nn - 1

p = (ggplot(df, aes(x='bandwith', y='neighbours')) +
geom_line() +
geom_point() +
scale_y_continuous(breaks=range(start, end, interval_n)) +
theme_bw(base_size=16) +
xlab("Bandwidth") +
ylab("Number of Neighbors")
Expand Down
10 changes: 5 additions & 5 deletions liana/utils/spatial_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _linear(distance_mtx, bandwidth):
@d.dedent
def spatial_neighbors(adata: AnnData,
bandwidth=None,
cutoff=None,
cutoff=0.1,
max_neighbours=None,
kernel='gaussian',
set_diag=False,
Expand Down Expand Up @@ -103,9 +103,9 @@ def spatial_neighbors(adata: AnnData,
if max_neighbours is None:
max_neighbours = int(adata.shape[0] / 10)

tree = NearestNeighbors(n_neighbors=max_neighbours,
algorithm='ball_tree',
metric='euclidean').fit(_reference)
tree = NearestNeighbors(n_neighbors=max_neighbours + 1, # +1 to exclude self
algorithm='ball_tree',
metric='euclidean').fit(_reference)
dist = tree.kneighbors_graph(coordinates, mode='distance')

# prevent float overflow
Expand All @@ -114,7 +114,7 @@ def spatial_neighbors(adata: AnnData,
# define zone of indifference
dist.data[dist.data < zoi] = np.inf

# NOTE: dist gets converted to a connectivity matrix
# NOTE: dist gets converted to a connectivity (proximity) matrix
if kernel == 'gaussian':
dist.data = _gaussian(dist.data, bandwidth)
elif kernel == 'misty_rbf':
Expand Down

0 comments on commit 5139e64

Please sign in to comment.