Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tests] utils: refactor type-hint signatures. #12144

12 changes: 6 additions & 6 deletions tests/test_builders/test_build_linkcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from sphinx.testing.util import strip_escseq
from sphinx.util import requests

from tests.utils import CERT_FILE, http_server, https_server
from tests.utils import CERT_FILE, http_server

ts_re = re.compile(r".*\[(?P<ts>.*)\].*")

Expand Down Expand Up @@ -633,7 +633,7 @@ def test_invalid_ssl(get_request, app):

@pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True)
def test_connect_to_selfsigned_fails(app):
with https_server(OKHandler):
with http_server(OKHandler, tls_enabled=True):
app.build()

with open(app.outdir / 'output.json', encoding='utf-8') as fp:
Expand All @@ -648,7 +648,7 @@ def test_connect_to_selfsigned_fails(app):
@pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True)
def test_connect_to_selfsigned_with_tls_verify_false(app):
app.config.tls_verify = False
with https_server(OKHandler):
with http_server(OKHandler, tls_enabled=True):
app.build()

with open(app.outdir / 'output.json', encoding='utf-8') as fp:
Expand All @@ -666,7 +666,7 @@ def test_connect_to_selfsigned_with_tls_verify_false(app):
@pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True)
def test_connect_to_selfsigned_with_tls_cacerts(app):
app.config.tls_cacerts = CERT_FILE
with https_server(OKHandler):
with http_server(OKHandler, tls_enabled=True):
app.build()

with open(app.outdir / 'output.json', encoding='utf-8') as fp:
Expand All @@ -684,7 +684,7 @@ def test_connect_to_selfsigned_with_tls_cacerts(app):
@pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True)
def test_connect_to_selfsigned_with_requests_env_var(monkeypatch, app):
monkeypatch.setenv("REQUESTS_CA_BUNDLE", CERT_FILE)
with https_server(OKHandler):
with http_server(OKHandler, tls_enabled=True):
app.build()

with open(app.outdir / 'output.json', encoding='utf-8') as fp:
Expand All @@ -702,7 +702,7 @@ def test_connect_to_selfsigned_with_requests_env_var(monkeypatch, app):
@pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True)
def test_connect_to_selfsigned_nonexistent_cert_file(app):
app.config.tls_cacerts = "does/not/exist"
with https_server(OKHandler):
with http_server(OKHandler, tls_enabled=True):
app.build()

with open(app.outdir / 'output.json', encoding='utf-8') as fp:
Expand Down
37 changes: 14 additions & 23 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from __future__ import annotations

import contextlib
from contextlib import contextmanager
from http.server import ThreadingHTTPServer
from pathlib import Path
from ssl import PROTOCOL_TLS_SERVER, SSLContext
from threading import Thread
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING

import filelock

if TYPE_CHECKING:
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager
from collections.abc import Iterator
from socketserver import BaseRequestHandler
from typing import Any, Final

Expand Down Expand Up @@ -49,24 +48,16 @@ def __init__(
self.server.socket = sslcontext.wrap_socket(self.server.socket, server_side=True)


_T_co = TypeVar('_T_co', bound=HttpServerThread, covariant=True)
@contextmanager
def http_server(handler: type[BaseRequestHandler], tls_enabled: bool = False) -> Iterator[HttpServerThread]:
jayaddison marked this conversation as resolved.
Show resolved Hide resolved
server_cls = HttpsServerThread if tls_enabled else HttpServerThread
with filelock.FileLock(LOCK_PATH):
server = server_cls(handler, daemon=True)
server.start()
try:
yield server
finally:
server.terminate()


def create_server(
server_thread_class: type[_T_co],
) -> Callable[[type[BaseRequestHandler]], AbstractContextManager[_T_co]]:
@contextlib.contextmanager
def server(handler_class: type[BaseRequestHandler]) -> Generator[_T_co, None, None]:
lock = filelock.FileLock(LOCK_PATH)
with lock:
server_thread = server_thread_class(handler_class, daemon=True)
server_thread.start()
try:
yield server_thread
finally:
server_thread.terminate()
return server


http_server = create_server(HttpServerThread)
https_server = create_server(HttpsServerThread)
__all__ = ["http_server"]
jayaddison marked this conversation as resolved.
Show resolved Hide resolved