Skip to content

Commit

Permalink
Split test of handle() into multiple tests
Browse files Browse the repository at this point in the history
  • Loading branch information
De117 committed Dec 11, 2022
1 parent 3837311 commit 64b33cc
Showing 1 changed file with 93 additions and 56 deletions.
149 changes: 93 additions & 56 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import timedelta
from functools import partial
from hypothesis import given, settings, HealthCheck, assume
from hypothesis.strategies import data, integers, binary, floats, lists, builds, sets, tuples
from hypothesis.strategies import data, integers, binary, floats, lists, builds, sets, tuples, sampled_from

from typing import Set, Tuple, List

Expand Down Expand Up @@ -157,44 +157,41 @@ async def getnameinfo(self, sockaddr, flags):
return await trio.socket.getnameinfo(sockaddr, flags)


# TODO: instead of using bytes directly, use an h11 client?
async def connect(stream: trio.abc.Stream, host: Domain, port: Port, expected: bytes, method: str = "CONNECT") -> None:
"""Connect, assert it's OK, quit."""
hostname = f"{host}:{port}"
async with stream:
await stream.send_all(f"{method} {hostname} HTTP/1.1\r\nHost: {hostname}\r\n\r\n".encode())
resp = await stream.receive_some(10000)
assert resp.startswith(expected)

async def connect_slowly(stream: trio.abc.Stream, host: Domain, port: Port, expected: bytes) -> None:
async with stream:
await trio.sleep(10)
try:
# This will blow up if the server already closed the connection.
await stream.send_all(f"CONNECT {host}:{port} HTTP/1.1\r\nHost: whatever\r\n\r\n".encode())
except trio.BrokenResourceError:
pass
resp = await stream.receive_some(10000)
assert resp.startswith(expected)

async def connect_with_bytes(stream: trio.abc.Stream, host: Domain, port: Port, expected: bytes, to_send: bytes) -> None:
async with stream:
await stream.send_all(to_send)
resp = await stream.receive_some(10000)
assert resp.startswith(expected)

async def accept_and_close_connection(s: trio.socket.SocketType) -> None:
await s.accept()
s.close()


@given(domains=sets(domains(), min_size=1))
@settings(deadline=timedelta(seconds=1000), suppress_health_check=[HealthCheck.function_scoped_fixture])
async def test_handle(domains: Set[Domain], autojump_clock) -> None:
async def test_connect_to_whitelisted_host(domains: Set[Domain]) -> None:

# TODO: instead of using bytes directly, use an h11 client?
async def connect_OK(stream: trio.abc.Stream, host: Domain, port: Port, expected: bytes) -> None:
"""Connect, assert it's OK, quit."""
async with stream:
await stream.send_all(f"CONNECT {host}:{port} HTTP/1.1\r\nHost: whatever\r\n\r\n".encode())
resp = await stream.receive_some(10000)
assert resp.startswith(expected)

async def connect_slowly(stream: trio.abc.Stream, host: Domain, port: Port, expected: bytes) -> None:
async with stream:
await trio.sleep(10)
try:
# This will blow up if the server already closed the connection.
await stream.send_all(f"CONNECT {host}:{port} HTTP/1.1\r\nHost: whatever\r\n\r\n".encode())
except trio.BrokenResourceError:
pass
resp = await stream.receive_some(10000)
assert resp.startswith(expected)

async def connect_with_PUT(stream: trio.abc.Stream, host: Domain, port: Port, expected: bytes) -> None:
async with stream:
await stream.send_all(f"PUT {host}:{port} HTTP/1.1\r\nHost: whatever\r\n\r\n".encode())
resp = await stream.receive_some(10000)
assert resp.startswith(expected)

async def connect_with_bytes(stream: trio.abc.Stream, host: Domain, port: Port, expected: bytes, to_send: bytes) -> None:
async with stream:
await stream.send_all(to_send)
resp = await stream.receive_some(10000)
assert resp.startswith(expected)

async def accept_and_close_connection(s: trio.socket.SocketType) -> None:
await s.accept()
s.close()
expected = b"HTTP/1.1 200 Connection established\r\n"

# We open a socket at an OS-chosen port, and override domain resolution.
original_resolver = trio.socket.set_custom_hostname_resolver(ResolveAllToLocalhost())
Expand All @@ -212,15 +209,21 @@ async def accept_and_close_connection(s: trio.socket.SocketType) -> None:
whitelist = {(d, port) for d in domains}
is_whitelisted = lambda host, port: (host, port) in whitelist

# whitelisted hostname
expected = b"HTTP/1.1 200 Connection established\r\n"
client_stream, proxy_stream = trio.testing.memory_stream_pair()
async with trio.open_nursery() as nursery:
nursery.start_soon(connect_OK, client_stream, host, port, expected)
nursery.start_soon(connect, client_stream, host, port, expected)
nursery.start_soon(handle, proxy_stream, is_whitelisted)
nursery.start_soon(accept_and_close_connection, sock)
finally:
trio.socket.set_custom_hostname_resolver(original_resolver)


@given(domains=sets(domains(), min_size=1))
async def test_connect_to_non_whitelisted_host(domains: Set[Domain]) -> None:
expected = b"HTTP/1.1 403"

original_resolver = trio.socket.set_custom_hostname_resolver(ResolveAllToLocalhost())
try:
# Connect to a non-whitelisted hostname:
# you should get back a 403 error
with trio.socket.socket() as sock:
Expand All @@ -234,14 +237,20 @@ async def accept_and_close_connection(s: trio.socket.SocketType) -> None:
whitelist = {(d, port) for d in domains}
is_whitelisted = lambda host, port: (host, port) in whitelist

# non-whitelisted hostname
expected = b"HTTP/1.1 403"
client_stream, proxy_stream = trio.testing.memory_stream_pair()
async with trio.open_nursery() as nursery:
nursery.start_soon(connect_OK, client_stream, "non-whitelisted.domain", port, expected)
nursery.start_soon(connect, client_stream, "non-whitelisted.domain", port, expected)
nursery.start_soon(handle, proxy_stream, is_whitelisted)
finally:
trio.socket.set_custom_hostname_resolver(original_resolver)


@given(domains=sets(domains(), min_size=1))
async def test_connect_to_whitelisted_nonexistent_upstream_host(domains: Set[Domain]) -> None:
expected = b"HTTP/1.1 502"

original_resolver = trio.socket.set_custom_hostname_resolver(ResolveAllToLocalhost())
try:
# Connect to a whitelisted, but non-existent upstream:
# you should get back a 502 error (if upstream is down)
with trio.socket.socket() as sock:
Expand All @@ -253,20 +262,26 @@ async def accept_and_close_connection(s: trio.socket.SocketType) -> None:
whitelist = {(d, port) for d in domains}
is_whitelisted = lambda host, port: (host, port) in whitelist

# whitelisted hostname
expected = b"HTTP/1.1 502"
client_stream, proxy_stream = trio.testing.memory_stream_pair()
async with trio.open_nursery() as nursery:
nursery.start_soon(connect_OK, client_stream, host, port, expected)
nursery.start_soon(connect, client_stream, host, port, expected)
nursery.start_soon(handle, proxy_stream, is_whitelisted)
finally:
trio.socket.set_custom_hostname_resolver(original_resolver)


@given(domains=sets(domains(), min_size=1))
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
async def test_connect_to_whitelisted_slow_upstream_host(domains: Set[Domain], autojump_clock) -> None:
expected = b"HTTP/1.1 504"

original_resolver = trio.socket.set_custom_hostname_resolver(ResolveAllToLocalhost())
try:
# Connect to a whitelisted, but non-existent upstream:
# you should get back a 504 error (if upstream is up but times out)
#
# XXX: we cannot use raw IP sockets as a normal user. But at least on
# UNIX-like systems, we can ensure a slow TCP handshake.

with trio.socket.socket() as sock:
await sock.bind(("localhost", 0))
sock.listen(0) # type: ignore
Expand All @@ -293,14 +308,21 @@ async def accept_and_close_connection(s: trio.socket.SocketType) -> None:
whitelist = {(d, port) for d in domains}
is_whitelisted = lambda host, port: (host, port) in whitelist

# whitelisted hostname
expected = b"HTTP/1.1 504"
client_stream, proxy_stream = trio.testing.memory_stream_pair()
async with trio.open_nursery() as nursery:
nursery.start_soon(connect_OK, client_stream, host, port, expected)
nursery.start_soon(connect, client_stream, host, port, expected)
nursery.start_soon(handle, proxy_stream, is_whitelisted)
finally:
trio.socket.set_custom_hostname_resolver(original_resolver)


@given(domains=sets(domains(), min_size=1))
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
async def test_connect_where_client_times_out(domains: Set[Domain], autojump_clock) -> None:
expected = b"HTTP/1.1 408"

original_resolver = trio.socket.set_custom_hostname_resolver(ResolveAllToLocalhost())
try:
# Anything else should fail with a 4xx error
# client timeouts: 408 (Too Slow)
with trio.socket.socket() as sock:
Expand All @@ -314,14 +336,23 @@ async def accept_and_close_connection(s: trio.socket.SocketType) -> None:
whitelist = {(d, port) for d in domains}
is_whitelisted = lambda host, port: (host, port) in whitelist

# whitelisted hostname
expected = b"HTTP/1.1 408"
client_stream, proxy_stream = trio.testing.memory_stream_pair()
async with trio.open_nursery() as nursery:
nursery.start_soon(connect_slowly, client_stream, host, port, expected)
nursery.start_soon(handle, proxy_stream, is_whitelisted)
finally:
trio.socket.set_custom_hostname_resolver(original_resolver)


HTTP_METHODS = ["GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"]

@given(domains=sets(domains(), min_size=1), method=sampled_from(HTTP_METHODS + ["DJWAODUIJAW"]))
async def test_connect_with_bad_method(domains: Set[Domain], method: str) -> None:
assume(method != "CONNECT")
expected = b"HTTP/1.1 405"

original_resolver = trio.socket.set_custom_hostname_resolver(ResolveAllToLocalhost())
try:
# bad method: 405 (Not Allowed)
with trio.socket.socket() as sock:
await sock.bind(("localhost", 0))
Expand All @@ -334,13 +365,20 @@ async def accept_and_close_connection(s: trio.socket.SocketType) -> None:
whitelist = {(d, port) for d in domains}
is_whitelisted = lambda host, port: (host, port) in whitelist

# whitelisted hostname
expected = b"HTTP/1.1 405"
client_stream, proxy_stream = trio.testing.memory_stream_pair()
async with trio.open_nursery() as nursery:
nursery.start_soon(connect_with_PUT, client_stream, host, port, expected)
nursery.start_soon(connect, client_stream, host, port, expected, method)
nursery.start_soon(handle, proxy_stream, is_whitelisted)
finally:
trio.socket.set_custom_hostname_resolver(original_resolver)


@given(domains=sets(domains(), min_size=1))
async def test_connect_with_random_input(domains: Set[Domain]) -> None:
expected = b"HTTP/1.1 400"

original_resolver = trio.socket.set_custom_hostname_resolver(ResolveAllToLocalhost())
try:
# malformed request is 400 (Bad Request)
random_length = random.randint(0, 100)
random_bytes = bytes(random.getrandbits(8) for _ in range(random_length))
Expand All @@ -357,7 +395,6 @@ async def accept_and_close_connection(s: trio.socket.SocketType) -> None:
whitelist = {(d, port) for d in domains}
is_whitelisted = lambda host, port: (host, port) in whitelist

expected = b"HTTP/1.1 400"
client_stream, proxy_stream = trio.testing.memory_stream_pair()
async with trio.open_nursery() as nursery:
nursery.start_soon(connect_with_bytes, client_stream, host, port, expected, random_bytes)
Expand Down

0 comments on commit 64b33cc

Please sign in to comment.