Skip to content

Commit

Permalink
Properly handle client-dropped connection (+test)
Browse files Browse the repository at this point in the history
It's not logged as a 500 Internal Server Error anymore.
  • Loading branch information
De117 committed Dec 5, 2022
1 parent 6b3c8f5 commit 332e011
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 18 deletions.
60 changes: 43 additions & 17 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ async def killer() -> None:

from whitelisting_proxy.proxy import Domain, Port, handle

# TESTS:
# ======
# TESTS TO DO:
# ============
# Connect to a whitelisted hostname:
# you should have access (200 OK)

Expand Down Expand Up @@ -157,8 +157,9 @@ async def getnameinfo(self, sockaddr, flags):
return await trio.socket.getnameinfo(sockaddr, flags)


@given(sets(domains(), min_size=1))
async def test_full_connect(domains: Set[Domain]) -> None:
@given(domains=sets(domains(), min_size=1))
@settings(deadline=timedelta(seconds=1000), suppress_health_check=[HealthCheck.function_scoped_fixture])
async def test_full_connect(domains: Set[Domain], autojump_clock) -> None:

# TODO: instead of using bytes directly, use an h11 client?
async def connect_OK(stream: trio.abc.Stream, host: Domain, port: Port, expected: bytes) -> None:
Expand All @@ -168,31 +169,56 @@ async def connect_OK(stream: trio.abc.Stream, host: Domain, port: Port, expected
resp = await stream.receive_some(10000)
assert resp.startswith(expected)

expected = b"HTTP/1.1 200 Connection established\r\n"

async def accept_and_close_connection(s: trio.socket.SocketType) -> None:
await s.accept()
s.close()

# We open a socket at an OS-chosen port, and override domain resolution.
with trio.socket.socket() as sock:
await sock.bind(("localhost", 0))
sock.listen() # type: ignore
# (As of trio-typing 0.7.0, the type for listen() is wrong.)

_, port = sock.getsockname()
host: Domain = random.choice(list(domains))
original_resolver = trio.socket.set_custom_hostname_resolver(ResolveAllToLocalhost())
try:
# Connect to a whitelisted hostname:
# you should have access (200 OK)
with trio.socket.socket() as sock:
await sock.bind(("localhost", 0))
sock.listen() # type: ignore
# (As of trio-typing 0.7.0, the type for listen() is wrong.)

_, port = sock.getsockname()
host: Domain = random.choice(list(domains))

original_resolver = trio.socket.set_custom_hostname_resolver(ResolveAllToLocalhost())
try:
whitelist = {(d, port) for d in domains}
is_whitelisted = lambda host, port: (host, port) in whitelist

# whitelisted hostname
expected = b"HTTP/1.1 200 Connection established\r\n"
client_stream, proxy_stream = trio.testing.memory_stream_pair()
async with trio.open_nursery() as nursery:
nursery.start_soon(connect_OK, client_stream, host, port, expected)
nursery.start_soon(handle, proxy_stream, is_whitelisted)
nursery.start_soon(accept_and_close_connection, sock)

finally:
trio.socket.set_custom_hostname_resolver(original_resolver)

# Connect to a non-whitelisted hostname:
# you should get back a 403 error
with trio.socket.socket() as sock:
await sock.bind(("localhost", 0))
sock.listen() # type: ignore
# (As of trio-typing 0.7.0, the type for listen() is wrong.)

_, port = sock.getsockname()
host: Domain = random.choice(list(domains))
print(host, port)

whitelist = {(d, port) for d in domains}
is_whitelisted = lambda host, port: (host, port) in whitelist

# non-whitelisted hostname
expected = b"HTTP/1.1 403"
client_stream, proxy_stream = trio.testing.memory_stream_pair()
async with trio.open_nursery() as nursery:
nursery.start_soon(connect_OK, client_stream, "non-whitelisted.domain", port, expected)
nursery.start_soon(handle, proxy_stream, is_whitelisted)
#nursery.start_soon(accept_and_close_connection, sock)

finally:
trio.socket.set_custom_hostname_resolver(original_resolver)
4 changes: 3 additions & 1 deletion whitelisting_proxy/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ async def handle(stream: trio.SocketStream, is_whitelisted: Callable[[str, int],
except Exception as e:
w.info(f"Handling exception: {e!r}")
try:
if isinstance(e, h11.RemoteProtocolError):
if isinstance(e, trio.BrokenResourceError):
w.info("Client abruptly closed connection; dropping request.")
elif isinstance(e, h11.RemoteProtocolError):
await w.send_error(e.error_status_hint, str(e))
elif isinstance(e, trio.TooSlowError):
if not client_request_completed:
Expand Down

0 comments on commit 332e011

Please sign in to comment.