Skip to content

Commit

Permalink
Move configuration parsing to separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
De117 committed Jan 10, 2023
1 parent 5a4ce7c commit 9754cea
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 80 deletions.
3 changes: 2 additions & 1 deletion tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ async def killer() -> None:
nursery.start_soon(runner)
nursery.start_soon(killer)

from whitelisting_proxy.proxy import Domain, Port, handle
from whitelisting_proxy.proxy import handle
from whitelisting_proxy.config import Domain, Port

# Tests for the main handler follow.
#
Expand Down
2 changes: 2 additions & 0 deletions whitelisting_proxy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from .proxy import (
WhitelistingProxy,
SynchronousWhitelistingProxy,
)
from .config import (
Port,
Domain,
Configuration,
Expand Down
3 changes: 2 additions & 1 deletion whitelisting_proxy/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import trio, threading, signal, argparse, sys
from functools import partial
from typing import Union, Tuple, Set
from .proxy import WhitelistingProxy, SynchronousWhitelistingProxy, load_configuration_from_file
from .proxy import WhitelistingProxy, SynchronousWhitelistingProxy
from .config import load_configuration_from_file

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down
78 changes: 78 additions & 0 deletions whitelisting_proxy/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import re, json
from dataclasses import dataclass
from typing import Union, Set, Tuple

################################################################
# Configuration parsing
################################################################

class Port(int):
    """
    A TCP/UDP port number: an int restricted to the range 0..65535.

    Accepts an int or a decimal string.
    Raises ValueError if the value is not a valid port number.
    """

    # Validation lives in __new__ rather than __init__: int is immutable, so
    # the value is converted by int.__new__ before __init__ ever runs. With
    # the check in __init__, a malformed string such as "abc" escaped with
    # int()'s own error message and the custom message was unreachable.
    def __new__(cls, n: Union[int, str]) -> "Port":
        value = n
        if isinstance(n, str):
            try:
                value = int(n)
            except ValueError:
                raise ValueError(f"Invalid port number: {n}") from None

        if value not in range(65536):
            raise ValueError(f"Invalid port number: {n}")

        return super().__new__(cls, value)

class Domain(str):
    """
    A domain name: a str validated against the RFC 1035 "subdomain" grammar.

    Raises ValueError if the string is not a well-formed domain.
    """

    def __init__(self, s: str):
        # Grammar taken from RFC 1035, §2.3.1
        # (The "subdomain" node, because we do not allow the empty string.)
        letter_digit_hyphen = "[a-zA-Z0-9-]"
        letter_digit = "[a-zA-Z0-9]"
        letter = "[a-zA-Z]"
        # <label> ::= <letter> [ [ <ldh-str> ] <let-dig> ]
        # The tail is a single optional group. Writing it as a repeated group
        # with an inner "+" accepts the same strings but nests overlapping
        # quantifiers, which backtracks exponentially on near-miss input.
        label = f"{letter}(?:{letter_digit_hyphen}*{letter_digit})?"
        domain = f"(?:{label}\\.)*{label}"

        # fullmatch, not match with "^...$": "$" also matches just before a
        # trailing newline, which would let "example.com\n" through.
        if not isinstance(s, str) or not re.fullmatch(domain, s):
            raise ValueError(f"Malformed domain: {s}")

@dataclass
class Configuration:
    """Parsed proxy configuration."""
    # Configuration format version; parse_configuration_v1 always sets 1.
    version: int
    # (domain, port) pairs built from the "allowed_hosts" entries.
    whitelist: Set[Tuple[Domain, Port]]


def parse_host_and_port(s: str) -> Tuple[Domain, Port]:
    """
    Parses a "host[:port]" string into a (host, port) pair.
    Default port is 80.
    Raises ValueError if parsing fails.
    """
    # Non-string input (for example a number in a JSON host list) has to
    # surface as ValueError, as documented, not as a TypeError.
    if not isinstance(s, str):
        raise ValueError(f"Malformed host: {s!r}")

    # Split on the first ":" only; anything after it is handed to Port,
    # which rejects it if it is not a valid port number.
    host, separator, port = s.partition(":")
    if separator:
        return Domain(host), Port(port)
    return Domain(host), Port(80)


def parse_configuration_v1(b: bytes) -> Configuration:
    """
    Parses configuration. Raises ValueError if it fails.

    The input must be a JSON object with "version" equal to 1 and
    "allowed_hosts" set to a list of "host[:port]" strings.
    """
    try:
        config = json.loads(b)
        # Explicit raises instead of asserts: asserts are stripped when
        # Python runs with -O, which would silently disable validation.
        if not isinstance(config, dict):
            raise ValueError("Malformed configuration: expected a JSON object")
        if "version" not in config:
            raise ValueError("Missing field: version")
        if config["version"] != 1:
            raise ValueError("Unsupported version")
        if "allowed_hosts" not in config:
            raise ValueError("Missing field: allowed_hosts")
        hostnames = config["allowed_hosts"]
        if not isinstance(hostnames, list):
            raise ValueError("Malformed field: allowed_hosts")
        for h in hostnames:
            # JSON permits numbers, objects, etc. here; only strings can be
            # parsed as "host[:port]".
            if not isinstance(h, str):
                raise ValueError(f"Malformed host entry: {h!r}")
        whitelist = {parse_host_and_port(h) for h in hostnames}
        return Configuration(version=1, whitelist=whitelist)
    except ValueError as e:
        # json.JSONDecodeError is a ValueError subclass, so it is covered too.
        raise ValueError("Could not parse v1 configuration") from e


def load_configuration_from_file(filename: str) -> Configuration:
    """
    Loads configuration from file. Raises RuntimeError if it fails.
    """
    try:
        with open(filename, "rb") as f:
            return parse_configuration_v1(f.read())
    # OSError rather than just FileNotFoundError: permission problems,
    # directories and read errors must also end up as RuntimeError.
    except (OSError, ValueError) as e:
        raise RuntimeError(f"Could not open or parse configuration file: {e}") from e
81 changes: 3 additions & 78 deletions whitelisting_proxy/proxy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import trio, h11, threading, re, json
import trio, h11, threading
from .adapter import TrioHTTPConnection
from .config import parse_host_and_port
from functools import partial
from dataclasses import dataclass
from typing import Callable, Union, Iterable, Tuple, Set
from typing import Callable, Union, Iterable, Tuple

################################################################
# The proxy itself
Expand Down Expand Up @@ -254,78 +254,3 @@ def stop(self) -> None:
"""Stop the proxy, if not already stopped."""
self._stop.set()
self._thread.join()

################################################################
# Configuration parsing
################################################################

class Port(int):
    """
    A TCP/UDP port number: an int restricted to the range 0..65535.

    Accepts an int or a decimal string.
    Raises ValueError if the value is not a valid port number.
    """

    # Validation lives in __new__ rather than __init__: int is immutable, so
    # the value is converted by int.__new__ before __init__ ever runs. With
    # the check in __init__, a malformed string such as "abc" escaped with
    # int()'s own error message and the custom message was unreachable.
    def __new__(cls, n: Union[int, str]) -> "Port":
        value = n
        if isinstance(n, str):
            try:
                value = int(n)
            except ValueError:
                raise ValueError(f"Invalid port number: {n}") from None

        if value not in range(65536):
            raise ValueError(f"Invalid port number: {n}")

        return super().__new__(cls, value)

class Domain(str):
    """
    A domain name: a str validated against the RFC 1035 "subdomain" grammar.

    Raises ValueError if the string is not a well-formed domain.
    """

    def __init__(self, s: str):
        # Grammar taken from RFC 1035, §2.3.1
        # (The "subdomain" node, because we do not allow the empty string.)
        letter_digit_hyphen = "[a-zA-Z0-9-]"
        letter_digit = "[a-zA-Z0-9]"
        letter = "[a-zA-Z]"
        # <label> ::= <letter> [ [ <ldh-str> ] <let-dig> ]
        # The tail is a single optional group. Writing it as a repeated group
        # with an inner "+" accepts the same strings but nests overlapping
        # quantifiers, which backtracks exponentially on near-miss input.
        label = f"{letter}(?:{letter_digit_hyphen}*{letter_digit})?"
        domain = f"(?:{label}\\.)*{label}"

        # fullmatch, not match with "^...$": "$" also matches just before a
        # trailing newline, which would let "example.com\n" through.
        if not isinstance(s, str) or not re.fullmatch(domain, s):
            raise ValueError(f"Malformed domain: {s}")

@dataclass
class Configuration:
    """Parsed proxy configuration."""
    # Configuration format version; parse_configuration_v1 always sets 1.
    version: int
    # (domain, port) pairs built from the "allowed_hosts" entries.
    whitelist: Set[Tuple[Domain, Port]]


def parse_host_and_port(s: str) -> Tuple[Domain, Port]:
    """
    Parses a "host[:port]" string into a (host, port) pair.
    Default port is 80.
    Raises ValueError if parsing fails.
    """
    # Non-string input (for example a number in a JSON host list) has to
    # surface as ValueError, as documented, not as a TypeError.
    if not isinstance(s, str):
        raise ValueError(f"Malformed host: {s!r}")

    # Split on the first ":" only; anything after it is handed to Port,
    # which rejects it if it is not a valid port number.
    host, separator, port = s.partition(":")
    if separator:
        return Domain(host), Port(port)
    return Domain(host), Port(80)


def parse_configuration_v1(b: bytes) -> Configuration:
    """
    Parses configuration. Raises ValueError if it fails.

    The input must be a JSON object with "version" equal to 1 and
    "allowed_hosts" set to a list of "host[:port]" strings.
    """
    try:
        config = json.loads(b)
        # Explicit raises instead of asserts: asserts are stripped when
        # Python runs with -O, which would silently disable validation.
        if not isinstance(config, dict):
            raise ValueError("Malformed configuration: expected a JSON object")
        if "version" not in config:
            raise ValueError("Missing field: version")
        if config["version"] != 1:
            raise ValueError("Unsupported version")
        if "allowed_hosts" not in config:
            raise ValueError("Missing field: allowed_hosts")
        hostnames = config["allowed_hosts"]
        if not isinstance(hostnames, list):
            raise ValueError("Malformed field: allowed_hosts")
        for h in hostnames:
            # JSON permits numbers, objects, etc. here; only strings can be
            # parsed as "host[:port]".
            if not isinstance(h, str):
                raise ValueError(f"Malformed host entry: {h!r}")
        whitelist = {parse_host_and_port(h) for h in hostnames}
        return Configuration(version=1, whitelist=whitelist)
    except ValueError as e:
        # json.JSONDecodeError is a ValueError subclass, so it is covered too.
        raise ValueError("Could not parse v1 configuration") from e


def load_configuration_from_file(filename: str) -> Configuration:
    """
    Loads configuration from file. Raises RuntimeError if it fails.
    """
    try:
        with open(filename, "rb") as f:
            return parse_configuration_v1(f.read())
    # OSError rather than just FileNotFoundError: permission problems,
    # directories and read errors must also end up as RuntimeError.
    except (OSError, ValueError) as e:
        raise RuntimeError(f"Could not open or parse configuration file: {e}") from e

0 comments on commit 9754cea

Please sign in to comment.