Skip to content

Commit

Permalink
fix: Correctly check allowed host is unicode rather than str (#22)
Browse files Browse the repository at this point in the history
* correctly check allowed host if host is unicode rather than str
* reduce code complexity for parsing hostname
  • Loading branch information
sadams authored and miketheman committed Feb 9, 2019
1 parent b679e28 commit 8971510
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
23 changes: 20 additions & 3 deletions pytest_socket.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# -*- coding: utf-8 -*-
import socket

import sys
import pytest

_true_socket = socket.socket
_true_connect = socket.socket.connect
is_py2 = sys.version_info[0] == 2


class SocketBlockedError(RuntimeError):
Expand Down Expand Up @@ -101,10 +102,26 @@ def pytest_runtest_teardown():
remove_host_restrictions()


def host_from_address(address):
host = address[0]
if isinstance(host, str):
return host


def host_from_address_py2(address):
host = address[0]
if isinstance(host, str) or isinstance(host, unicode): # noqa F821
return host


def host_from_connect_args(args):
address = args[0]
if isinstance(address, tuple) and isinstance(address[0], str):
return address[0]

if isinstance(address, tuple):
if is_py2:
return host_from_address_py2(address)
else:
return host_from_address(address)


def socket_allow_hosts(allowed=None):
Expand Down
15 changes: 14 additions & 1 deletion tests/test_restrict_hosts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ def {2}():
socket.socket().connect(('{0}', {1}))
"""

connect_unicode_code_template = """
import socket
import pytest
{3}
def {2}():
socket.socket().connect((u'{0}', {1}))
"""

# `contextlib` used because otherwise 2.7 was occasionally hanging due to exception cases:
urlopen_code_template = """
import pytest
Expand Down Expand Up @@ -101,6 +110,10 @@ def test_single_cli_arg_connect_enabled(assert_connect):
assert_connect(True, cli_arg=localhost)


def test_single_cli_arg_connect_unicode_enabled(assert_connect):
assert_connect(True, cli_arg=localhost, code_template=connect_unicode_code_template)


def test_multiple_cli_arg_connect_enabled(assert_connect):
assert_connect(True, cli_arg=localhost + ',1.2.3.4')

Expand All @@ -122,7 +135,7 @@ def test_mark_cli_conflict_mark_wins_connect_enabled(assert_connect):


def test_single_cli_arg_connect_disabled(assert_connect):
assert_connect(False, cli_arg='1.2.3.4')
assert_connect(False, cli_arg='1.2.3.4', code_template=connect_unicode_code_template)


def test_multiple_cli_arg_connect_disabled(assert_connect):
Expand Down

0 comments on commit 8971510

Please sign in to comment.