Skip to content

Commit

Permalink
[internals] improve type safety when using NodeMatcher (#12034)
Browse files Browse the repository at this point in the history
  • Loading branch information
danieleades committed Mar 2, 2024
1 parent 265ffee commit bea8b6b
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 31 deletions.
2 changes: 1 addition & 1 deletion sphinx/builders/html/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def run(self, **kwargs: Any) -> None:
matcher = NodeMatcher(nodes.literal, classes=["kbd"])
# this list must be pre-created as during iteration new nodes
# are added which match the condition in the NodeMatcher.
for node in list(self.document.findall(matcher)): # type: nodes.literal
for node in list(matcher.findall(self.document)):
parts = self.pattern.split(node[-1].astext())
if len(parts) == 1 or self.is_multiwords_key(parts):
continue
Expand Down
6 changes: 3 additions & 3 deletions sphinx/builders/latex/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class FootnoteDocnameUpdater(SphinxTransform):

def apply(self, **kwargs: Any) -> None:
matcher = NodeMatcher(*self.TARGET_NODES)
for node in self.document.findall(matcher): # type: Element
for node in matcher.findall(self.document):
node['docname'] = self.env.docname


Expand Down Expand Up @@ -538,7 +538,7 @@ class CitationReferenceTransform(SphinxPostTransform):
def run(self, **kwargs: Any) -> None:
domain = cast(CitationDomain, self.env.get_domain('citation'))
matcher = NodeMatcher(addnodes.pending_xref, refdomain='citation', reftype='ref')
for node in self.document.findall(matcher): # type: addnodes.pending_xref
for node in matcher.findall(self.document):
docname, labelid, _ = domain.citations.get(node['reftarget'], ('', '', 0))
if docname:
citation_ref = nodes.citation_reference('', '', *node.children,
Expand Down Expand Up @@ -574,7 +574,7 @@ class LiteralBlockTransform(SphinxPostTransform):

def run(self, **kwargs: Any) -> None:
matcher = NodeMatcher(nodes.container, literal_block=True)
for node in self.document.findall(matcher): # type: nodes.container
for node in matcher.findall(self.document):
newnode = captioned_literal_block('', *node.children, **node.attributes)
node.replace_self(newnode)

Expand Down
28 changes: 12 additions & 16 deletions sphinx/transforms/i18n.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def update_title_mapping(self) -> bool:

# replace target's refname to new target name
matcher = NodeMatcher(nodes.target, refname=old_name)
for old_target in self.document.findall(matcher): # type: nodes.target
for old_target in matcher.findall(self.document):
old_target['refname'] = new_name

processed = True
Expand All @@ -198,10 +198,8 @@ def list_replace_or_append(lst: list[N], old: N, new: N) -> None:
lst.append(new)

is_autofootnote_ref = NodeMatcher(nodes.footnote_reference, auto=Any)
old_foot_refs: list[nodes.footnote_reference] = [
*self.node.findall(is_autofootnote_ref)]
new_foot_refs: list[nodes.footnote_reference] = [
*self.patch.findall(is_autofootnote_ref)]
old_foot_refs = list(is_autofootnote_ref.findall(self.node))
new_foot_refs = list(is_autofootnote_ref.findall(self.patch))
self.compare_references(old_foot_refs, new_foot_refs,
__('inconsistent footnote references in translated message.' +
' original: {0}, translated: {1}'))
Expand Down Expand Up @@ -240,8 +238,8 @@ def update_refnamed_references(self) -> None:
# * use translated refname for section refname.
# * inline reference "`Python <...>`_" has no 'refname'.
is_refnamed_ref = NodeMatcher(nodes.reference, refname=Any)
old_refs: list[nodes.reference] = [*self.node.findall(is_refnamed_ref)]
new_refs: list[nodes.reference] = [*self.patch.findall(is_refnamed_ref)]
old_refs = list(is_refnamed_ref.findall(self.node))
new_refs = list(is_refnamed_ref.findall(self.patch))
self.compare_references(old_refs, new_refs,
__('inconsistent references in translated message.' +
' original: {0}, translated: {1}'))
Expand All @@ -264,10 +262,8 @@ def update_refnamed_references(self) -> None:
def update_refnamed_footnote_references(self) -> None:
# refnamed footnote should use original 'ids'.
is_refnamed_footnote_ref = NodeMatcher(nodes.footnote_reference, refname=Any)
old_foot_refs: list[nodes.footnote_reference] = [*self.node.findall(
is_refnamed_footnote_ref)]
new_foot_refs: list[nodes.footnote_reference] = [*self.patch.findall(
is_refnamed_footnote_ref)]
old_foot_refs = list(is_refnamed_footnote_ref.findall(self.node))
new_foot_refs = list(is_refnamed_footnote_ref.findall(self.patch))
refname_ids_map: dict[str, list[str]] = {}
self.compare_references(old_foot_refs, new_foot_refs,
__('inconsistent footnote references in translated message.' +
Expand All @@ -282,8 +278,8 @@ def update_refnamed_footnote_references(self) -> None:
def update_citation_references(self) -> None:
# citation should use original 'ids'.
is_citation_ref = NodeMatcher(nodes.citation_reference, refname=Any)
old_cite_refs: list[nodes.citation_reference] = [*self.node.findall(is_citation_ref)]
new_cite_refs: list[nodes.citation_reference] = [*self.patch.findall(is_citation_ref)]
old_cite_refs = list(is_citation_ref.findall(self.node))
new_cite_refs = list(is_citation_ref.findall(self.patch))
self.compare_references(old_cite_refs, new_cite_refs,
__('inconsistent citation references in translated message.' +
' original: {0}, translated: {1}'))
Expand Down Expand Up @@ -549,7 +545,7 @@ def apply(self, **kwargs: Any) -> None:
return

total = translated = 0
for node in self.document.findall(NodeMatcher(translated=Any)): # type: nodes.Element
for node in NodeMatcher(nodes.Element, translated=Any).findall(self.document):
total += 1
if node['translated']:
translated += 1
Expand Down Expand Up @@ -588,7 +584,7 @@ def apply(self, **kwargs: Any) -> None:
'True, False, "translated" or "untranslated"')
raise ConfigError(msg)

for node in self.document.findall(NodeMatcher(translated=Any)): # type: nodes.Element
for node in NodeMatcher(nodes.Element, translated=Any).findall(self.document):
if node['translated']:
if add_translated:
node.setdefault('classes', []).append('translated')
Expand All @@ -610,7 +606,7 @@ def apply(self, **kwargs: Any) -> None:
return

matcher = NodeMatcher(nodes.inline, translatable=Any)
for inline in list(self.document.findall(matcher)): # type: nodes.inline
for inline in matcher.findall(self.document):
inline.parent.remove(inline)
inline.parent += inline.children

Expand Down
30 changes: 20 additions & 10 deletions sphinx/util/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@
import contextlib
import re
import unicodedata
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar

from docutils import nodes
from docutils.nodes import Node

from sphinx import addnodes
from sphinx.locale import __
from sphinx.util import logging

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterable, Iterator

from docutils.nodes import Element, Node
from docutils.nodes import Element
from docutils.parsers.rst import Directive
from docutils.parsers.rst.states import Inliner
from docutils.statemachine import StringList
Expand All @@ -33,7 +34,10 @@
caption_ref_re = explicit_title_re # b/w compat alias


class NodeMatcher:
N = TypeVar("N", bound=Node)


class NodeMatcher(Generic[N]):
"""A helper class for Node.findall().
It checks that the given node is an instance of the specified node-classes and
Expand All @@ -43,20 +47,18 @@ class NodeMatcher:
and ``reftype`` attributes::
matcher = NodeMatcher(nodes.reference, refdomain='std', reftype='citation')
doctree.findall(matcher)
matcher.findall(doctree)
# => [<reference ...>, <reference ...>, ...]
A special value ``typing.Any`` matches any kind of node-attributes. For example,
following example searches ``reference`` node having ``refdomain`` attributes::
from __future__ import annotations
from typing import TYPE_CHECKING, Any
matcher = NodeMatcher(nodes.reference, refdomain=Any)
doctree.findall(matcher)
matcher.findall(doctree)
# => [<reference ...>, <reference ...>, ...]
"""

def __init__(self, *node_classes: type[Node], **attrs: Any) -> None:
def __init__(self, *node_classes: type[N], **attrs: Any) -> None:
self.classes = node_classes
self.attrs = attrs

Expand Down Expand Up @@ -85,6 +87,14 @@ def match(self, node: Node) -> bool:
def __call__(self, node: Node) -> bool:
return self.match(node)

def findall(self, node: Node) -> Iterator[N]:
"""An alternative to `Node.findall` with improved type safety.
While the `NodeMatcher` object can be used as an argument to `Node.findall`, doing so
confounds type checkers' ability to determine the return type of the iterator.
"""
return node.findall(self)


def get_full_module_name(node: Node) -> str:
"""
Expand Down Expand Up @@ -308,7 +318,7 @@ def traverse_translatable_index(
) -> Iterable[tuple[Element, list[tuple[str, str, str, str, str | None]]]]:
"""Traverse translatable index node from a document tree."""
matcher = NodeMatcher(addnodes.index, inline=False)
for node in doctree.findall(matcher): # type: addnodes.index
for node in matcher.findall(doctree):
if 'raw_entries' in node:
entries = node['raw_entries']
else:
Expand Down
2 changes: 1 addition & 1 deletion sphinx/writers/manpage.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, document: nodes.document) -> None:

def apply(self, **kwargs: Any) -> None:
matcher = NodeMatcher(nodes.literal, nodes.emphasis, nodes.strong)
for node in list(self.document.findall(matcher)): # type: nodes.TextElement
for node in list(matcher.findall(self.document)):
if any(matcher(subnode) for subnode in node):
pos = node.parent.index(node)
for subnode in reversed(list(node)):
Expand Down

0 comments on commit bea8b6b

Please sign in to comment.