Skip to content

Commit

Permalink
attempt to improve node info subclass typing
Browse files Browse the repository at this point in the history
method from #26 (comment)
  • Loading branch information
dhimmel committed Jul 12, 2023
1 parent b7083ef commit fa1d381
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 58 deletions.
12 changes: 6 additions & 6 deletions nxontology/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from nxontology import NXOntology
from nxontology.exceptions import NodeNotFound
from nxontology.node import Node
from nxontology.node import NodeT

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,21 +86,21 @@ def from_file(handle: BinaryIO | str | PathLike[AnyStr]) -> NXOntology[str]:

def _pronto_edges_for_term(
term: Term, default_rel_type: str = "is a"
) -> list[tuple[Node, Node, str]]:
) -> list[tuple[NodeT, NodeT, str]]:
"""
Extract edges including "is a" relationships for a Pronto term.
https://github.com/althonos/pronto/issues/119#issuecomment-956541286
"""
rels = []
source_id = cast(Node, term.id)
source_id = cast(NodeT, term.id)
for target in term.superclasses(distance=1, with_self=False):
rels.append((source_id, cast(Node, target.id), default_rel_type))
rels.append((source_id, cast(NodeT, target.id), default_rel_type))
for rel_type, targets in term.relationships.items():
for target in sorted(targets):
rels.append(
(
cast(Node, term.id),
cast(Node, target.id),
cast(NodeT, term.id),
cast(NodeT, target.id),
rel_type.name or rel_type.id,
)
)
Expand Down
25 changes: 13 additions & 12 deletions nxontology/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

# Type definitions. networkx does not declare types.
# https://github.com/networkx/networkx/issues/3988#issuecomment-639969263
Node = TypeVar("Node", bound=Hashable)
NodeT = TypeVar("NodeT", bound=Hashable)
NodeInfoT = TypeVar("NodeInfoT")


class Node_Info(Freezable, Generic[Node]):
class Node_Info(Freezable, Generic[NodeT]):
"""
Compute metrics and values for a node of an NXOntology.
Includes intrinsic information content (IC) metrics.
Expand All @@ -35,7 +36,7 @@ class Node_Info(Freezable, Generic[Node]):
Each ic_metric has a scaled version accessible by adding a _scaled suffix.
"""

def __init__(self, nxo: NXOntology[Node], node: Node):
def __init__(self, nxo: NXOntology[NodeT, NodeInfoT], node: NodeT):
if node not in nxo.graph:
raise NodeNotFound(f"{node} not in graph.")
self.nxo = nxo
Expand Down Expand Up @@ -98,12 +99,12 @@ def data(self) -> dict[Any, Any]:
return data

@property
def parents(self) -> set[Node]:
def parents(self) -> set[NodeT]:
"""Direct parent nodes of this node."""
return set(self.nxo.graph.predecessors(self.node))

@property
def parent(self) -> Node | None:
def parent(self) -> NodeT | None:
"""
Sole parent of this node, or None if this node is a root.
If this node has multiple parents, raise ValueError.
Expand All @@ -118,13 +119,13 @@ def parent(self) -> Node | None:
raise ValueError(f"Node {self!r} has multiple parents.")

@property
def children(self) -> set[Node]:
def children(self) -> set[NodeT]:
"""Direct child nodes of this node."""
return set(self.nxo.graph.successors(self.node))

@property
@cache_on_frozen
def ancestors(self) -> set[Node]:
def ancestors(self) -> set[NodeT]:
"""
Get ancestors of node in graph, including the node itself.
Ancestors refers to more general concepts in an ontology,
Expand All @@ -137,7 +138,7 @@ def ancestors(self) -> set[Node]:

@property
@cache_on_frozen
def descendants(self) -> set[Node]:
def descendants(self) -> set[NodeT]:
"""
Get descendants of node in graph, including the node itself.
Descendants refers to more specific concepts in an ontology,
Expand All @@ -160,12 +161,12 @@ def n_descendants(self) -> int:

@property
@cache_on_frozen
def roots(self) -> set[Node]:
def roots(self) -> set[NodeT]:
"""Ancestors of this node that are roots (top-level)."""
return self.ancestors & self.nxo.roots

@property
def leaves(self) -> set[Node]:
def leaves(self) -> set[NodeT]:
"""Descendents of this node that are leaves."""
return self.descendants & self.nxo.leaves

Expand All @@ -181,14 +182,14 @@ def depth(self) -> int:
return depth

@property
def paths_from_roots(self) -> Iterator[list[Node]]:
def paths_from_roots(self) -> Iterator[list[NodeT]]:
for root in self.roots:
yield from nx.all_simple_paths(
self.nxo.graph, source=root, target=self.node
)

@property
def paths_to_leaves(self) -> Iterator[list[Node]]:
def paths_to_leaves(self) -> Iterator[list[NodeT]]:
yield from nx.all_simple_paths(
self.nxo.graph, source=self.node, target=self.leaves
)
Expand Down
54 changes: 31 additions & 23 deletions nxontology/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import json
import logging
from abc import abstractmethod
from os import PathLike, fspath
from typing import Any, Generic, Iterable, cast

Expand All @@ -12,7 +13,7 @@
from networkx.algorithms.isolate import isolates
from networkx.readwrite.json_graph import node_link_data, node_link_graph

from nxontology.node import Node
from nxontology.node import NodeInfoT, NodeT

from .exceptions import DuplicateError, NodeNotFound
from .node import Node_Info
Expand All @@ -22,7 +23,7 @@
logger = logging.getLogger(__name__)


class NXOntology(Freezable, Generic[Node]):
class NXOntologyBase(Freezable, Generic[NodeT, NodeInfoT]):
"""
Encapsulate a networkx.DiGraph to represent an ontology.
Regarding edge directionality, parent terms should point to child term.
Expand All @@ -39,7 +40,7 @@ def __init__(
# in case there are compatibility issues in the future.
self._add_nxontology_metadata()
self.check_is_dag()
self._node_info_cache: dict[Node, Node_Info[Node]] = {}
self._node_info_cache: dict[NodeT, Node_Info[NodeT]] = {}

def _add_nxontology_metadata(self) -> None:
self.graph.graph["nxontology_version"] = get_nxontology_version()
Expand Down Expand Up @@ -77,7 +78,7 @@ def write_node_link_json(self, path: str | PathLike[str]) -> None:
write_file.write("\n") # json.dump does not include a trailing newline

@classmethod
def read_node_link_json(cls, path: str | PathLike[str]) -> NXOntology[Node]:
def read_node_link_json(cls, path: str | PathLike[str]) -> NXOntologyBase[NodeT]:
"""
Return a new graph from node-link format as written by `write_node_link_json`.
"""
Expand All @@ -90,7 +91,7 @@ def read_node_link_json(cls, path: str | PathLike[str]) -> NXOntology[Node]:
nxo = cls(digraph)
return nxo

def add_node(self, node_for_adding: Node, **attr: Any) -> None:
def add_node(self, node_for_adding: NodeT, **attr: Any) -> None:
"""
Like networkx.DiGraph.add_node but raises a DuplicateError
if the node already exists.
Expand All @@ -99,7 +100,7 @@ def add_node(self, node_for_adding: Node, **attr: Any) -> None:
raise DuplicateError(f"node already in graph: {node_for_adding}")
self.graph.add_node(node_for_adding, **attr)

def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr: Any) -> None:
def add_edge(self, u_of_edge: NodeT, v_of_edge: NodeT, **attr: Any) -> None:
"""
Like networkx.DiGraph.add_edge but
raises a NodeNotFound if either node does not exist
Expand All @@ -116,7 +117,7 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr: Any) -> None:

@property
@cache_on_frozen
def roots(self) -> set[Node]:
def roots(self) -> set[NodeT]:
"""
Return all top-level nodes, including isolates.
"""
Expand All @@ -127,7 +128,7 @@ def roots(self) -> set[Node]:
return roots

@property
def root(self) -> Node:
def root(self) -> NodeT:
"""
Sole root of this directed acyclic graph.
If this ontology has multiple roots, raise ValueError.
Expand All @@ -142,7 +143,7 @@ def root(self) -> Node:

@property
@cache_on_frozen
def leaves(self) -> set[Node]:
def leaves(self) -> set[NodeT]:
"""
Return all bottom-level nodes, including isolates.
"""
Expand All @@ -154,7 +155,7 @@ def leaves(self) -> set[Node]:

@property
@cache_on_frozen
def isolates(self) -> set[Node]:
def isolates(self) -> set[NodeT]:
"""
Return disconnected nodes.
"""
Expand All @@ -175,17 +176,17 @@ def frozen(self) -> bool:

def similarity(
self,
node_0: Node,
node_1: Node,
node_0: NodeT,
node_1: NodeT,
ic_metric: str = "intrinsic_ic_sanchez",
) -> SimilarityIC[Node]:
) -> SimilarityIC[NodeT]:
"""SimilarityIC instance for the specified nodes"""
return SimilarityIC(self, node_0, node_1, ic_metric)

def similarity_metrics(
self,
node_0: Node,
node_1: Node,
node_0: NodeT,
node_1: NodeT,
ic_metric: str = "intrinsic_ic_sanchez",
keys: list[str] | None = None,
) -> dict[str, Any]:
Expand All @@ -197,8 +198,8 @@ def similarity_metrics(

def compute_similarities(
self,
source_nodes: Iterable[Node],
target_nodes: Iterable[Node],
source_nodes: Iterable[NodeT],
target_nodes: Iterable[NodeT],
ic_metrics: list[str] | tuple[str, ...] = ("intrinsic_ic_sanchez",),
) -> Iterable[dict[str, Any]]:
"""
Expand All @@ -213,16 +214,17 @@ def compute_similarities(
yield metrics

@classmethod
def _get_node_info_cls(cls) -> type[Node_Info[Node]]:
@abstractmethod
def _get_node_info_cls(cls) -> type[NodeInfoT]:
"""
Return the Node_Info class to use for this ontology.
Subclasses can override this to use a custom Node_Info class.
For the complexity of typing this method, see
<https://github.com/related-sciences/nxontology/pull/26>.
"""
return Node_Info
...

def node_info(self, node: Node) -> Node_Info[Node]:
def node_info(self, node: NodeT) -> NodeInfoT:
"""
Return Node_Info instance for `node`.
If frozen, cache node info in `self._node_info_cache`.
Expand All @@ -235,8 +237,8 @@ def node_info(self, node: Node) -> Node_Info[Node]:
return self._node_info_cache[node]

@cache_on_frozen
def _get_name_to_node_info(self) -> dict[str, Node_Info[Node]]:
name_to_node_info: dict[str, Node_Info[Node]] = {}
def _get_name_to_node_info(self) -> dict[str, Node_Info[NodeT]]:
name_to_node_info: dict[str, Node_Info[NodeT]] = {}
for node in self.graph:
info = self.node_info(node)
name = info.name
Expand All @@ -249,7 +251,7 @@ def _get_name_to_node_info(self) -> dict[str, Node_Info[Node]]:
name_to_node_info[name] = info
return name_to_node_info

def node_info_by_name(self, name: str) -> Node_Info[Node]:
def node_info_by_name(self, name: str) -> Node_Info[NodeT]:
"""
Return Node_Info instance using a lookup by name.
"""
Expand Down Expand Up @@ -306,3 +308,9 @@ def set_graph_attributes(
self.graph.graph["node_identifier_attribute"] = node_identifier_attribute
if node_url_attribute:
self.graph.graph["node_url_attribute"] = node_url_attribute


# Concrete ontology class: specializes NXOntologyBase by binding its second
# type parameter (NodeInfoT) to Node_Info[NodeT], leaving NodeT generic.
class NXOntology(NXOntologyBase[NodeT, Node_Info[NodeT]]):
@classmethod
def _get_node_info_cls(cls) -> type[Node_Info[NodeT]]:
# Implements the abstract hook declared on NXOntologyBase by returning
# the Node_Info class itself (the class object, not an instance).
return Node_Info
Loading

0 comments on commit fa1d381

Please sign in to comment.