diff --git a/nxontology/ontology.py b/nxontology/ontology.py index 4a9b58b..6e1417a 100644 --- a/nxontology/ontology.py +++ b/nxontology/ontology.py @@ -29,8 +29,13 @@ class NXOntology(Freezable, Generic[Node]): Edges should go from general to more specific. """ - def __init__(self, graph: nx.DiGraph | None = None): + def __init__( + self, + graph: nx.DiGraph | None = None, + node_info_class: type[Node_Info[Node]] = Node_Info, + ): self.graph = nx.DiGraph(graph) + self.node_info_class = node_info_class if graph is None: # Store the nxontology version that created the graph as metadata, # in case there are compatability issues in the future. @@ -215,9 +220,9 @@ def node_info(self, node: Node) -> Node_Info[Node]: If frozen, cache node info in `self._node_info_cache`. """ if not self.frozen: - return Node_Info(self, node) + return self.node_info_class(self, node) if node not in self._node_info_cache: - self._node_info_cache[node] = Node_Info(self, node) + self._node_info_cache[node] = self.node_info_class(self, node) return self._node_info_cache[node] @cache_on_frozen diff --git a/nxontology/similarity.py b/nxontology/similarity.py index 48c2298..3ef46f1 100644 --- a/nxontology/similarity.py +++ b/nxontology/similarity.py @@ -125,17 +125,17 @@ class SimilarityIC(Similarity[Node]): def __init__( self, - graph: NXOntology[Node], + nxo: NXOntology[Node], node_0: Node, node_1: Node, ic_metric: str = "intrinsic_ic_sanchez", ): - super().__init__(graph, node_0, node_1) + super().__init__(nxo, node_0, node_1) - if ic_metric not in Node_Info.ic_metrics: + if ic_metric not in nxo.node_info_class.ic_metrics: raise ValueError( f"{ic_metric!r} is not a supported ic_metric. " - f"Choose from: {', '.join(Node_Info.ic_metrics)}." + f"Choose from: {', '.join(nxo.node_info_class.ic_metrics)}." ) self.ic_metric = ic_metric self.ic_metric_scaled = f"{ic_metric}_scaled" diff --git a/nxontology/tests/ontology_test.py b/nxontology/tests/ontology_test.py index 3a67f75..c96a882 100644 --- a/nxontology/tests/ontology_test.py +++ b/nxontology/tests/ontology_test.py @@ -5,6 +5,7 @@ import pytest from nxontology.exceptions import DuplicateError, NodeNotFound +from nxontology.node import Node_Info from nxontology.ontology import NXOntology @@ -151,3 +152,18 @@ def test_node_info_by_name() -> None: def test_node_info_not_found(metal_nxo_frozen: NXOntology[str]) -> None: with pytest.raises(NodeNotFound, match="not-a-metal not in graph"): metal_nxo_frozen.node_info("not-a-metal") + + +def test_custom_node_info_class() -> None: + class CustomNodeInfo(Node_Info[str]): + @property + def custom_property(self) -> str: + return "custom" + + nxo: NXOntology[str] = NXOntology(node_info_class=CustomNodeInfo) + nxo.add_node("a", name="a_name") + assert nxo.node_info("a").custom_property == "custom" # type: ignore [attr-defined] + assert nxo.node_info_by_name("a_name").custom_property == "custom" # type: ignore [attr-defined] + similarity = nxo.similarity("a", "a") + assert similarity.info_0.custom_property == "custom" # type: ignore [attr-defined] + assert similarity.info_1.custom_property == "custom" # type: ignore [attr-defined]