Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tests] utils: refactor type-hint signatures. #12144

12 changes: 6 additions & 6 deletions tests/test_builders/test_build_linkcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from sphinx.testing.util import strip_escseq
from sphinx.util import requests

from tests.utils import CERT_FILE, http_server, https_server
from tests.utils import CERT_FILE, http_server

ts_re = re.compile(r".*\[(?P<ts>.*)\].*")

Expand Down Expand Up @@ -633,7 +633,7 @@ def test_invalid_ssl(get_request, app):

@pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True)
def test_connect_to_selfsigned_fails(app):
with https_server(OKHandler):
with http_server(OKHandler, tls_enabled=True):
app.build()

with open(app.outdir / 'output.json', encoding='utf-8') as fp:
Expand All @@ -648,7 +648,7 @@ def test_connect_to_selfsigned_fails(app):
@pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True)
def test_connect_to_selfsigned_with_tls_verify_false(app):
app.config.tls_verify = False
with https_server(OKHandler):
with http_server(OKHandler, tls_enabled=True):
app.build()

with open(app.outdir / 'output.json', encoding='utf-8') as fp:
Expand All @@ -666,7 +666,7 @@ def test_connect_to_selfsigned_with_tls_verify_false(app):
@pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True)
def test_connect_to_selfsigned_with_tls_cacerts(app):
app.config.tls_cacerts = CERT_FILE
with https_server(OKHandler):
with http_server(OKHandler, tls_enabled=True):
app.build()

with open(app.outdir / 'output.json', encoding='utf-8') as fp:
Expand All @@ -684,7 +684,7 @@ def test_connect_to_selfsigned_with_tls_cacerts(app):
@pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True)
def test_connect_to_selfsigned_with_requests_env_var(monkeypatch, app):
monkeypatch.setenv("REQUESTS_CA_BUNDLE", CERT_FILE)
with https_server(OKHandler):
with http_server(OKHandler, tls_enabled=True):
app.build()

with open(app.outdir / 'output.json', encoding='utf-8') as fp:
Expand All @@ -702,7 +702,7 @@ def test_connect_to_selfsigned_with_requests_env_var(monkeypatch, app):
@pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True)
def test_connect_to_selfsigned_nonexistent_cert_file(app):
app.config.tls_cacerts = "does/not/exist"
with https_server(OKHandler):
with http_server(OKHandler, tls_enabled=True):
app.build()

with open(app.outdir / 'output.json', encoding='utf-8') as fp:
Expand Down
41 changes: 16 additions & 25 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from __future__ import annotations

import contextlib
__all__ = ["http_server"]
jayaddison marked this conversation as resolved.
Show resolved Hide resolved

Copy link
Member

@picnixz picnixz Mar 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there should be 1 line after and not two (but I'm not sure with the diff)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could always remove this file from the format exclude list 😄

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you don't mind waiting until this is merged, we could do that as part of #12146?


from contextlib import contextmanager
from http.server import ThreadingHTTPServer
from pathlib import Path
from ssl import PROTOCOL_TLS_SERVER, SSLContext
from threading import Thread
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING

import filelock

if TYPE_CHECKING:
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager
from collections.abc import Iterator
from socketserver import BaseRequestHandler
from typing import Any, Final

Expand Down Expand Up @@ -49,24 +51,13 @@ def __init__(
self.server.socket = sslcontext.wrap_socket(self.server.socket, server_side=True)


_T_co = TypeVar('_T_co', bound=HttpServerThread, covariant=True)


def create_server(
server_thread_class: type[_T_co],
) -> Callable[[type[BaseRequestHandler]], AbstractContextManager[_T_co]]:
@contextlib.contextmanager
def server(handler_class: type[BaseRequestHandler]) -> Generator[_T_co, None, None]:
lock = filelock.FileLock(LOCK_PATH)
with lock:
server_thread = server_thread_class(handler_class, daemon=True)
server_thread.start()
try:
yield server_thread
finally:
server_thread.terminate()
return server


http_server = create_server(HttpServerThread)
https_server = create_server(HttpsServerThread)
@contextmanager
def http_server(handler: type[BaseRequestHandler], *, tls_enabled: bool = False) -> Iterator[HttpServerThread]:
server_cls = HttpsServerThread if tls_enabled else HttpServerThread
with filelock.FileLock(LOCK_PATH):
server = server_cls(handler, daemon=True)
server.start()
try:
yield server
finally:
server.terminate()