Skip to content

Commit

Permalink
allow custom Node_Info sublasses
Browse files Browse the repository at this point in the history
  • Loading branch information
dhimmel committed Jul 11, 2023
1 parent 5550fcb commit 05e7c4e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 7 deletions.
11 changes: 8 additions & 3 deletions nxontology/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ class NXOntology(Freezable, Generic[Node]):
Edges should go from general to more specific.
"""

def __init__(self, graph: nx.DiGraph | None = None):
def __init__(
self,
graph: nx.DiGraph | None = None,
node_info_class: type[Node_Info[Node]] = Node_Info,
):
self.graph = nx.DiGraph(graph)
self.node_info_class = node_info_class
if graph is None:
# Store the nxontology version that created the graph as metadata,
# in case there are compatability issues in the future.
Expand Down Expand Up @@ -215,9 +220,9 @@ def node_info(self, node: Node) -> Node_Info[Node]:
If frozen, cache node info in `self._node_info_cache`.
"""
if not self.frozen:
return Node_Info(self, node)
return self.node_info_class(self, node)
if node not in self._node_info_cache:
self._node_info_cache[node] = Node_Info(self, node)
self._node_info_cache[node] = self.node_info_class(self, node)
return self._node_info_cache[node]

@cache_on_frozen
Expand Down
8 changes: 4 additions & 4 deletions nxontology/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,17 +125,17 @@ class SimilarityIC(Similarity[Node]):

def __init__(
self,
graph: NXOntology[Node],
nxo: NXOntology[Node],
node_0: Node,
node_1: Node,
ic_metric: str = "intrinsic_ic_sanchez",
):
super().__init__(graph, node_0, node_1)
super().__init__(nxo, node_0, node_1)

if ic_metric not in Node_Info.ic_metrics:
if ic_metric not in nxo.node_info_class.ic_metrics:
raise ValueError(
f"{ic_metric!r} is not a supported ic_metric. "
f"Choose from: {', '.join(Node_Info.ic_metrics)}."
f"Choose from: {', '.join(nxo.node_info_class.ic_metrics)}."
)
self.ic_metric = ic_metric
self.ic_metric_scaled = f"{ic_metric}_scaled"
Expand Down
16 changes: 16 additions & 0 deletions nxontology/tests/ontology_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from nxontology.exceptions import DuplicateError, NodeNotFound
from nxontology.node import Node_Info
from nxontology.ontology import NXOntology


Expand Down Expand Up @@ -151,3 +152,18 @@ def test_node_info_by_name() -> None:
def test_node_info_not_found(metal_nxo_frozen: NXOntology[str]) -> None:
with pytest.raises(NodeNotFound, match="not-a-metal not in graph"):
metal_nxo_frozen.node_info("not-a-metal")


def test_custom_node_info_class() -> None:
class CustomNodeInfo(Node_Info[str]):
@property
def custom_property(self) -> str:
return "custom"

nxo: NXOntology[str] = NXOntology(node_info_class=CustomNodeInfo)
nxo.add_node("a", name="a_name")
assert nxo.node_info("a").custom_property == "custom" # type: ignore [attr-defined]
assert nxo.node_info_by_name("a_name").custom_property == "custom" # type: ignore [attr-defined]
similarity = nxo.similarity("a", "a")
assert similarity.info_0.custom_property == "custom" # type: ignore [attr-defined]
assert similarity.info_1.custom_property == "custom" # type: ignore [attr-defined]

0 comments on commit 05e7c4e

Please sign in to comment.