Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

skeleton for LR encoding ED model with extension of estimator #44

Merged
merged 22 commits into from
Nov 7, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
f99ff95
skeleton for LR encoding ED model with extension of estimator
davidsebfischer Oct 12, 2021
f84f8a7
enabled extraction of scaled adj matrix
davidsebfischer Oct 12, 2021
162ba4e
debugged unit tests
davidsebfischer Oct 13, 2021
37ec983
added disclaimer
davidsebfischer Oct 13, 2021
64ae03e
Merge branch 'feature/lr_encoder' of github.com:theislab/ncem into fe…
AnnaChristina Oct 14, 2021
3dc14a3
simplify output layers with get_out in remaining models
AnnaChristina Oct 14, 2021
427ffb5
simplify output layers with get_out in remaining models
AnnaChristina Oct 14, 2021
833ec07
add disclaimer to cond layers
AnnaChristina Oct 14, 2021
21f54c9
fix max and gcn layer for single gnn
AnnaChristina Oct 19, 2021
998c8cc
Bump version from 0.3.2 to 0.4.0
AnnaChristina Oct 25, 2021
01438aa
added node embedding and output weight saving in EDncem models
davidsebfischer Oct 28, 2021
7405baa
Merge pull request #53 from theislab/development
AnnaChristina Nov 3, 2021
e26132a
fix conflicts
AnnaChristina Nov 3, 2021
9209a78
Merge branch 'release' of github.com:theislab/ncem into release
AnnaChristina Nov 3, 2021
e61d30c
Merge pull request #54 from theislab/release
AnnaChristina Nov 3, 2021
0c2839e
Bump version from 0.4.6 to 0.4.7
AnnaChristina Nov 3, 2021
0e2711f
add n_top_genes to dataloader
AnnaChristina Nov 4, 2021
2f03c77
Bump version from 0.4.7 to 0.4.8
AnnaChristina Nov 4, 2021
be89c0b
add hgnc names for schuerch
AnnaChristina Nov 7, 2021
803063d
Merge branch 'feature/embedding_saving' of github.com:theislab/ncem i…
AnnaChristina Nov 7, 2021
8a4bcfa
add saving of LR names to model class
AnnaChristina Nov 7, 2021
bd2f900
Merge pull request #51 from theislab/feature/embedding_saving
AnnaChristina Nov 7, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
skeleton for LR encoding ED model with extension of estimator
added estimators that yield neighborhood tensors of features and avoid using the full adjacency matrix; this can be used for single-layer embeddings and breaks the scaling of attention coefficient computation with the number of nodes in the graph
  • Loading branch information
davidsebfischer committed Oct 12, 2021
commit f99ff95346d8ccb0fe50b09bd4609e890125ca7c
3 changes: 2 additions & 1 deletion ncem/estimators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Importing estimator classes."""
from ncem.estimators.base_estimator import (Estimator, EstimatorGraph,
EstimatorNoGraph)
from ncem.estimators.base_estimator_neighbors import EstimatorNeighborhood
from ncem.estimators.estimator_cvae import EstimatorCVAE
from ncem.estimators.estimator_cvae_ncem import EstimatorCVAEncem
from ncem.estimators.estimator_ed import EstimatorED
from ncem.estimators.estimator_ed_ncem import EstimatorEDncem
from ncem.estimators.estimator_ed_ncem import EstimatorEDncem, EstimatorEdNcemNeighborhood
from ncem.estimators.estimator_interactions import EstimatorInteractions
from ncem.estimators.estimator_linear import EstimatorLinear
220 changes: 220 additions & 0 deletions ncem/estimators/base_estimator_neighbors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ncem.estimators.base_estimator import Estimator


class EstimatorNeighborhood(Estimator):
    """Estimator for spatial models that consume per-node neighborhoods instead of the full graph.

    For every evaluated node the dataset yields the node's own features, a zero-padded tensor holding the
    features of its neighbors and a 0/1 mask marking which padded neighbor slots are real, so that no
    full adjacency matrix has to be passed to the model.
    """

    n_features_in: int
    # Cache for the n_neighbors_padded property. It needs a real default: a bare annotation creates no
    # attribute, and the property's `is None` check would then raise AttributeError.
    _n_neighbors_padded: Optional[int] = None
    h0_in: bool
    idx_target_features: np.ndarray
    idx_neighbor_features: np.ndarray

    def set_input_features(self, h0_in=True, target_feature_names=None, neighbor_feature_names=None):
        """Choose the input features of the model; needs to be run before compiling the model.

        Parameters
        ----------
        h0_in : bool
            If True, target and neighbor inputs are read from ``h_0`` and ``n_features_in`` is set to
            ``n_features_0``. Feature names must not be given in this case.
        target_feature_names : optional
            Names of the ``h_1`` features to use as input of the target nodes (only for ``h0_in=False``).
        neighbor_feature_names : optional
            Names of the ``h_1`` features to use as input of the neighbor nodes (only for ``h0_in=False``).

        Raises
        ------
        NotImplementedError
            If ``h0_in`` is False, because the matching of names to ``h_1`` feature indices is not
            implemented yet.
        """
        self.h0_in = h0_in
        if self.h0_in:
            assert target_feature_names is None, "target_feature_names must be None if h0_in is True"
            assert neighbor_feature_names is None, "neighbor_feature_names must be None if h0_in is True"
            self.n_features_in = self.n_features_0
            return
        # TODO match target_feature_names / neighbor_feature_names to the feature names of h_1 and store the
        #  results as np index arrays in self.idx_target_features / self.idx_neighbor_features. Both index
        #  arrays must have the same length and must not overlap; n_features_in is then
        #  len(self.idx_target_features).
        # Fail with an explicit message rather than an AttributeError on a None placeholder.
        raise NotImplementedError(
            "set_input_features(h0_in=False) is not supported yet: selecting h_1 features by name is "
            "not implemented"
        )

    @property
    def n_neighbors_padded(self):
        """Largest row sum over all adjacency matrices in ``self.a``, computed once and cached.

        Used as the padded length of every neighborhood. Cast to a Python int because the value is used
        as an array / tensor shape dimension and the row sums may be a floating point type.
        """
        if self._n_neighbors_padded is None:
            self._n_neighbors_padded = int(np.max(np.asarray([
                np.max(np.asarray(np.sum(a, axis=1)).flatten()) for a in self.a.values()
            ])))
        return self._n_neighbors_padded

    def _get_output_signature(self, resampled: bool = False):
        """Get output signatures.

        Parameters
        ----------
        resampled : bool
            Whether dataset is resampled or not. If True, the (inputs, targets) pair is repeated twice.

        Returns
        -------
        output_signature
            Nested tuple of ``tf.TensorSpec`` in the order in which the dataset generator yields its arrays.
        """
        n_nodes = self.n_eval_nodes_per_graph
        n_pad = self.n_neighbors_padded

        h_targets = tf.TensorSpec(shape=(n_nodes, self.n_features_in), dtype=tf.float32)  # target node features
        # One padded neighborhood per evaluated node, hence a rank-3 tensor.
        h_neighbors = tf.TensorSpec(
            shape=(n_nodes, n_pad, self.n_features_in), dtype=tf.float32
        )  # neighbor node features
        sf = tf.TensorSpec(shape=(n_nodes, 1), dtype=tf.float32)  # input node size factors
        node_covar = tf.TensorSpec(shape=(n_nodes, self.n_node_covariates), dtype=tf.float32)  # node-level covariates
        a = tf.TensorSpec(shape=(n_nodes, n_pad), dtype=tf.float32)  # mask of real (non-padded) neighbor slots
        domain = tf.TensorSpec(shape=(self.n_domains,), dtype=tf.int32)  # domain
        reconstruction = tf.TensorSpec(
            shape=(n_nodes, self.n_features_1), dtype=tf.float32
        )  # node features to reconstruct
        kl_dummy = tf.TensorSpec(shape=(n_nodes,), dtype=tf.float32)  # dummy for kl loss

        inputs = (h_targets, h_neighbors, sf, a, node_covar, domain)
        targets = (reconstruction, kl_dummy) if self.vi_model else reconstruction
        if resampled:
            return inputs, targets, inputs, targets
        return inputs, targets

    def _get_dataset_base(
        self,
        image_keys: List[str],
        nodes_idx: Dict[str, np.ndarray],
        batch_size: int,
        shuffle_buffer_size: Optional[int],
        train: bool = True,
        seed: Optional[int] = None,
        prefetch: int = 100,
        reinit_n_eval: Optional[int] = None,
    ):
        """Prepare a dataset.

        Parameters
        ----------
        image_keys : np.array
            Image keys in partition.
        nodes_idx : dict, str
            Dictionary of nodes per image in partition.
        batch_size : int
            Batch size.
        shuffle_buffer_size : int, optional
            Shuffle buffer size.
        train : bool
            Whether dataset is used for training or not (influences shuffling of nodes).
        seed : int, optional
            Random seed.
        prefetch: int
            Prefetch of dataset.
        reinit_n_eval : int, optional
            Used if model is reinitialized to different number of nodes per graph.

        Returns
        -------
        A tensorflow dataset.
        """
        np.random.seed(seed)
        if reinit_n_eval is not None and reinit_n_eval != self.n_eval_nodes_per_graph:
            print(
                "ATTENTION: specifying reinit_n_eval will change class argument n_eval_nodes_per_graph "
                "from %i to %i" % (self.n_eval_nodes_per_graph, reinit_n_eval)
            )
            self.n_eval_nodes_per_graph = reinit_n_eval

        def generator():
            n_eval = self.n_eval_nodes_per_graph
            n_pad = self.n_neighbors_padded
            for key in image_keys:
                if nodes_idx[key].size == 0:  # needed for images where no nodes are selected
                    continue
                idx_nodes = np.arange(0, self.a[key].shape[0])

                if train:
                    # One random draw (with replacement) of n_eval nodes per image.
                    index_list = [
                        np.asarray(
                            np.random.choice(a=nodes_idx[key], size=n_eval, replace=True),
                            dtype=np.int32,
                        )
                    ]
                else:
                    # Consecutive full chunks of n_eval nodes; an incomplete last chunk is dropped.
                    index_list = [
                        np.asarray(nodes_idx[key][n_eval * i: n_eval * (i + 1)], dtype=np.int32)
                        for i in range(len(nodes_idx[key]) // n_eval)
                    ]

                for indices in index_list:
                    target_rows = idx_nodes[indices]
                    h_out = self.h_1[key][target_rows, :]
                    if self.h0_in:
                        h_targets = self.h_0[key][target_rows, :]
                    else:
                        # Select rows first and columns second: indexing with both index arrays at once
                        # would pair them element-wise instead of extracting the sub-matrix.
                        h_targets = self.h_1[key][target_rows, :][:, self.idx_target_features]

                    h_neighbors = []
                    a_neighborhood = np.zeros((n_eval, n_pad), "float32")
                    for i, j in enumerate(target_rows):
                        idx_neighbors = np.where(np.asarray(self.a[key][j, :].todense()).flatten() == 1.)[0]
                        if self.h0_in:
                            h_neighbors_j = self.h_0[key][idx_neighbors, :]
                        else:
                            h_neighbors_j = self.h_1[key][idx_neighbors, :][:, self.idx_neighbor_features]
                        h_neighbors_j = np.expand_dims(h_neighbors_j, axis=0)
                        # Pad neighborhoods:
                        diff = n_pad - h_neighbors_j.shape[1]
                        zeros = np.zeros((1, diff, h_neighbors_j.shape[2]), dtype="float32")
                        h_neighbors_j = np.concatenate([h_neighbors_j, zeros], axis=1)
                        h_neighbors.append(h_neighbors_j)
                        # Real neighbors occupy the leading slots, the remaining slots are padding.
                        a_neighborhood[i, :len(idx_neighbors)] = 1.
                    # Stack the per-node (1, n_pad, features) blocks to (n_eval, n_pad, features).
                    h_neighbors = np.concatenate(h_neighbors, axis=0)
                    if self.log_transform:
                        h_targets = np.log(h_targets + 1.0)
                        h_neighbors = np.log(h_neighbors + 1.0)

                    node_covar = self.node_covar[key][idx_nodes]
                    node_covar = node_covar[indices, :]

                    sf = np.expand_dims(self.size_factors[key][idx_nodes], axis=1)
                    sf = sf[indices, :]

                    # One-hot encoding of the domain of this image.
                    g = np.zeros((self.n_domains,), dtype="int32")
                    g[self.domains[key]] = 1

                    if self.vi_model:
                        kl_dummy = np.zeros((n_eval,), dtype="float32")
                        yield (h_targets, h_neighbors, sf, a_neighborhood, node_covar, g), (h_out, kl_dummy)
                    else:
                        yield (h_targets, h_neighbors, sf, a_neighborhood, node_covar, g), h_out

        output_signature = self._get_output_signature(resampled=False)

        dataset = tf.data.Dataset.from_generator(generator=generator, output_signature=output_signature)
        if train:
            if shuffle_buffer_size is not None:
                dataset = dataset.shuffle(buffer_size=shuffle_buffer_size, seed=None, reshuffle_each_iteration=True)
            dataset = dataset.repeat()
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(prefetch)
        return dataset
92 changes: 90 additions & 2 deletions ncem/estimators/estimator_ed_ncem.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import tensorflow as tf
from typing import Tuple

from ncem.estimators import EstimatorGraph
from ncem.models import ModelEDncem
from ncem.estimators import EstimatorGraph, EstimatorNeighborhood
from ncem.models import ModelEDncem, ModelEd2Ncem


class EstimatorEDncem(EstimatorGraph):
Expand Down Expand Up @@ -161,3 +162,90 @@ def init_model(
self.max_beta = max_beta
self.pre_warm_up = pre_warm_up
self._compile_model(optimizer=optimizer, output_layer=output_layer)


class EstimatorEdNcemNeighborhood(EstimatorNeighborhood):
    """Estimator for encoder-decoder NCEM models with a single graph layer, built on EstimatorNeighborhood."""

    def __init__(
        self,
        cond_type: str,
        use_type_cond: bool,
        log_transform: bool = False,
    ):
        """Initialize an EstimatorEdNcemNeighborhood object.

        Parameters
        ----------
        cond_type : str
            Graph layer used in conditional; only "gat" and "lr_gat" are accepted.
        use_type_cond : bool
            Whether to use the categorical cell type label in conditional.
        log_transform : bool
            Stored as ``self.log_transform`` for use when the dataset is prepared.

        Raises
        ------
        ValueError
            If `cond_type` is not recognized.
        """
        super().__init__()
        self.model_type = "ed_ncem"
        if cond_type not in ("gat", "lr_gat"):
            raise ValueError("cond_type %s not recognized" % cond_type)
        # Both supported layer types use the "full" adjacency type.
        self.adj_type = "full"
        self.cond_type = cond_type
        self.use_type_cond = use_type_cond
        self.log_transform = log_transform
        self.metrics = {"np": [], "tf": []}
        self.n_eval_nodes_per_graph = None

    def init_model(
        self,
        optimizer: str,
        learning_rate: float,
        latent_dim: Tuple[int],
        dropout_rate: float,
        l2_coef: float,
        l1_coef: float,
        n_eval_nodes_per_graph: int,
        use_domain: bool,
        scale_node_size: bool,
        output_layer: str,
        dec_intermediate_dim: int,
        dec_n_hidden: int,
        dec_dropout_rate: float,
        dec_l1_coef: float,
        dec_l2_coef: float,
        dec_use_batch_norm: bool,
        **kwargs
    ):
        """Build a ``ModelEd2Ncem`` and compile it.

        Parameters
        ----------
        optimizer : str
            Optimizer identifier, resolved with ``tf.keras.optimizers.get``.
        learning_rate : float
            Learning rate written to the ``lr`` variable of the resolved optimizer.
        n_eval_nodes_per_graph : int
            Number of nodes per graph; stored on the estimator and used in the model input shapes.
        output_layer : str
            Output layer identifier, passed to both the model and ``_compile_model``.
        latent_dim, dropout_rate, l2_coef, l1_coef, use_domain, scale_node_size, dec_intermediate_dim, \
dec_n_hidden, dec_dropout_rate, dec_l1_coef, dec_l2_coef, dec_use_batch_norm
            Forwarded unchanged to ``ModelEd2Ncem``.
        kwargs
            Accepted but not used.

        Notes
        -----
        ``self.cond_type`` is not passed to the model here.
        """
        self.n_eval_nodes_per_graph = n_eval_nodes_per_graph
        # Attribute reads stay in this order; n_features_in must already be set on the estimator.
        input_shapes = (
            self.n_features_in,
            self.n_features_1,
            self.n_eval_nodes_per_graph,
            self.n_neighbors_padded,
            self.n_node_covariates,
            self.n_domains,
        )
        self.model = ModelEd2Ncem(
            input_shapes=input_shapes,
            latent_dim=latent_dim,
            dropout_rate=dropout_rate,
            l2_coef=l2_coef,
            l1_coef=l1_coef,
            use_domain=use_domain,
            use_type_cond=self.use_type_cond,
            scale_node_size=scale_node_size,
            output_layer=output_layer,
            dec_intermediate_dim=dec_intermediate_dim,
            dec_n_hidden=dec_n_hidden,
            dec_dropout_rate=dec_dropout_rate,
            dec_l1_coef=dec_l1_coef,
            dec_l2_coef=dec_l2_coef,
            dec_use_batch_norm=dec_use_batch_norm,
        )
        optimizer = tf.keras.optimizers.get(optimizer)
        tf.keras.backend.set_value(optimizer.lr, learning_rate)
        self._compile_model(optimizer=optimizer, output_layer=output_layer)
1 change: 1 addition & 0 deletions ncem/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
from ncem.models.model_cvae_ncem import ModelCVAEncem
from ncem.models.model_ed import ModelED
from ncem.models.model_ed_ncem import ModelEDncem
from ncem.models.model_ed_single_ncem import ModelEd2Ncem
from ncem.models.model_interactions import ModelInteractions
from ncem.models.model_linear import ModelLinear
4 changes: 3 additions & 1 deletion ncem/models/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@
LinearOutput,
NegBinConstDispOutput,
NegBinOutput,
NegBinSharedDispOutput)
NegBinSharedDispOutput,
get_out)
from ncem.models.layers.preproc_input import DenseInteractions, PreprocInput
from ncem.models.layers.single_gnn_layers import SingleLrGatLayer, SingleGatLayer
30 changes: 30 additions & 0 deletions ncem/models/layers/output_layers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,36 @@
import tensorflow as tf


def get_out(output_layer: str, out_feature_dim, scale_node_size):
    """Create the decoder output layer selected by ``output_layer``.

    Parameters
    ----------
    output_layer : str
        One of "gaussian", "nb", "nb_shared_disp" or "nb_const_disp".
    out_feature_dim
        Passed to the layer as ``original_dim``.
    scale_node_size
        Passed to the layer as ``use_node_scale``.

    Returns
    -------
    The instantiated output layer.

    Raises
    ------
    ValueError
        If ``output_layer`` is not one of the supported identifiers.
    """
    # Constructor arguments shared by every output layer.
    common = {"original_dim": out_feature_dim, "use_node_scale": scale_node_size}
    if output_layer == "gaussian":
        return GaussianOutput(name="GaussianOutput_decoder", **common)
    if output_layer == "nb":
        return NegBinOutput(name="NegBinOutput_decoder", **common)
    if output_layer == "nb_shared_disp":
        return NegBinSharedDispOutput(name="NegBinSharedDispOutput_decoder", **common)
    if output_layer == "nb_const_disp":
        return NegBinConstDispOutput(name="NegBinConstDispOutput_decoder", **common)
    raise ValueError("tried to access a non-supported output layer %s" % output_layer)


class LinearOutput(tf.keras.layers.Layer):
"""Linear output layer."""

Expand Down
Loading