Skip to content

Commit

Permalink
use subclass to set _get_node_info_cls
Browse files Browse the repository at this point in the history
  • Loading branch information
dhimmel committed Jul 11, 2023
1 parent 756eb51 commit c503f1c
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
17 changes: 13 additions & 4 deletions nxontology/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ class NXOntology(Freezable, Generic[Node]):
def __init__(
self,
graph: nx.DiGraph | None = None,
node_info_class: type[Node_Info[Node]] = Node_Info,
):
self.graph = nx.DiGraph(graph)
self.node_info_class = node_info_class
if graph is None:
# Store the nxontology version that created the graph as metadata,
# in case there are compatability issues in the future.
Expand Down Expand Up @@ -214,15 +212,26 @@ def compute_similarities(
metrics = self.similarity_metrics(node_0, node_1, ic_metric=ic_metric)
yield metrics

@classmethod
def _get_node_info_cls(cls) -> type[Node_Info[Node]]:
"""
Return the Node_Info class to use for this ontology.
Subclasses can override this to use a custom Node_Info class.
For the complexity of typing this method, see
<https://github.com/related-sciences/nxontology/pull/26>.
"""
return Node_Info

def node_info(self, node: Node) -> Node_Info[Node]:
"""
Return Node_Info instance for `node`.
If frozen, cache node info in `self._node_info_cache`.
"""
node_info_cls = self._get_node_info_cls()
if not self.frozen:
return self.node_info_class(self, node)
return node_info_cls(self, node)
if node not in self._node_info_cache:
self._node_info_cache[node] = self.node_info_class(self, node)
self._node_info_cache[node] = node_info_cls(self, node)
return self._node_info_cache[node]

@cache_on_frozen
Expand Down
6 changes: 3 additions & 3 deletions nxontology/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ def __init__(
ic_metric: str = "intrinsic_ic_sanchez",
):
super().__init__(nxo, node_0, node_1)

if ic_metric not in nxo.node_info_class.ic_metrics:
ic_metrics = nxo._get_node_info_cls().ic_metrics
if ic_metric not in ic_metrics:
raise ValueError(
f"{ic_metric!r} is not a supported ic_metric. "
f"Choose from: {', '.join(nxo.node_info_class.ic_metrics)}."
f"Choose from: {', '.join(ic_metrics)}."
)
self.ic_metric = ic_metric
self.ic_metric_scaled = f"{ic_metric}_scaled"
Expand Down
15 changes: 13 additions & 2 deletions nxontology/tests/ontology_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pathlib
from datetime import date
from typing import Type

import networkx
import pytest
Expand Down Expand Up @@ -160,9 +161,19 @@ class CustomNodeInfo(Node_Info[str]):
def custom_property(self) -> str:
return "custom"

nxo: NXOntology[str] = NXOntology(node_info_class=CustomNodeInfo)
class CustomNxontology(NXOntology[str]):
@classmethod
def _get_node_info_cls(cls) -> Type[CustomNodeInfo]:
return CustomNodeInfo

def node_info(self, node: str) -> CustomNodeInfo:
info = super().node_info(node)
assert isinstance(info, CustomNodeInfo)
return info

nxo = CustomNxontology()
nxo.add_node("a", name="a_name")
assert nxo.node_info("a").custom_property == "custom" # type: ignore [attr-defined]
assert nxo.node_info("a").custom_property == "custom"
assert nxo.node_info_by_name("a_name").custom_property == "custom" # type: ignore [attr-defined]
similarity = nxo.similarity("a", "a")
assert similarity.info_0.custom_property == "custom" # type: ignore [attr-defined]
Expand Down

0 comments on commit c503f1c

Please sign in to comment.