Skip to content

Commit

Permalink
attempt to improve node info subclass typing
Browse files Browse the repository at this point in the history
method from #26 (comment)
  • Loading branch information
dhimmel committed Jul 12, 2023
1 parent b68411d commit f6134ac
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
3 changes: 2 additions & 1 deletion nxontology/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# Type definitions. networkx does not declare types.
# https://github.com/networkx/networkx/issues/3988#issuecomment-639969263
NodeT = TypeVar("NodeT", bound=Hashable)
NodeInfoT = TypeVar("NodeInfoT")


class NodeInfo(Freezable, Generic[NodeT]):
Expand All @@ -35,7 +36,7 @@ class NodeInfo(Freezable, Generic[NodeT]):
Each ic_metric has a scaled version accessible by adding a _scaled suffix.
"""

def __init__(self, nxo: NXOntology[NodeT], node: NodeT):
def __init__(self, nxo: NXOntology[NodeT, NodeInfoT], node: NodeT):
if node not in nxo.graph:
raise NodeNotFound(f"{node} not in graph.")
self.nxo = nxo
Expand Down
18 changes: 13 additions & 5 deletions nxontology/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import json
import logging
from abc import abstractmethod
from os import PathLike, fspath
from typing import Any, Generic, Iterable, cast

Expand All @@ -12,7 +13,7 @@
from networkx.algorithms.isolate import isolates
from networkx.readwrite.json_graph import node_link_data, node_link_graph

from nxontology.node import NodeT
from nxontology.node import NodeInfoT, NodeT

from .exceptions import DuplicateError, NodeNotFound
from .node import NodeInfo
Expand All @@ -22,7 +23,7 @@
logger = logging.getLogger(__name__)


class NXOntology(Freezable, Generic[NodeT]):
class NXOntologyBase(Freezable, Generic[NodeT, NodeInfoT]):
"""
Encapsulate a networkx.DiGraph to represent an ontology.
Regarding edge directionality, parent terms should point to child term.
Expand Down Expand Up @@ -77,7 +78,7 @@ def write_node_link_json(self, path: str | PathLike[str]) -> None:
write_file.write("\n") # json.dump does not include a trailing newline

@classmethod
def read_node_link_json(cls, path: str | PathLike[str]) -> NXOntology[NodeT]:
def read_node_link_json(cls, path: str | PathLike[str]) -> NXOntologyBase[NodeT]:
"""
Retrun a new graph from node-link format as written by `write_node_link_json`.
"""
Expand Down Expand Up @@ -213,7 +214,8 @@ def compute_similarities(
yield metrics

@classmethod
def _get_node_info_cls(cls) -> type[NodeInfo[NodeT]]:
@abstractmethod
def _get_node_info_cls(cls) -> type[NodeInfoT]:
"""
Return the Node_Info class to use for this ontology.
Subclasses can override this to use a custom Node_Info class.
Expand All @@ -222,7 +224,7 @@ def _get_node_info_cls(cls) -> type[NodeInfo[NodeT]]:
"""
return NodeInfo

def node_info(self, node: NodeT) -> NodeInfo[NodeT]:
def node_info(self, node: NodeT) -> NodeInfoT:
"""
Return Node_Info instance for `node`.
If frozen, cache node info in `self._node_info_cache`.
Expand Down Expand Up @@ -306,3 +308,9 @@ def set_graph_attributes(
self.graph.graph["node_identifier_attribute"] = node_identifier_attribute
if node_url_attribute:
self.graph.graph["node_url_attribute"] = node_url_attribute


class NXOntology(NXOntologyBase[NodeT, NodeInfo[NodeT]]):
@classmethod
def _get_node_info_cls(cls) -> type[NodeInfo[NodeT]]:
return NodeInfo
10 changes: 6 additions & 4 deletions nxontology/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from typing import TYPE_CHECKING, Any, Generic

if TYPE_CHECKING:
from nxontology.ontology import NXOntology
from nxontology.ontology import NXOntologyBase

from networkx import shortest_path_length

from nxontology.node import NodeInfo, NodeT
from nxontology.node import NodeInfo, NodeInfoT, NodeT
from nxontology.utils import Freezable, cache_on_frozen


Expand All @@ -29,7 +29,9 @@ class Similarity(Freezable, Generic[NodeT]):
"batet_log",
]

def __init__(self, nxo: NXOntology[NodeT], node_0: NodeT, node_1: NodeT):
def __init__(
self, nxo: NXOntologyBase[NodeT, NodeInfoT], node_0: NodeT, node_1: NodeT
):
self.nxo = nxo
self.node_0 = node_0
self.node_1 = node_1
Expand Down Expand Up @@ -125,7 +127,7 @@ class SimilarityIC(Similarity[NodeT]):

def __init__(
self,
nxo: NXOntology[NodeT],
nxo: NXOntologyBase[NodeT],
node_0: NodeT,
node_1: NodeT,
ic_metric: str = "intrinsic_ic_sanchez",
Expand Down

0 comments on commit f6134ac

Please sign in to comment.