Skip to content

Commit

Permalink
feat: add eof and close
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrokiefer committed Apr 7, 2018
1 parent 96eefd1 commit c90e567
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 4 deletions.
19 changes: 17 additions & 2 deletions aiostomp/aiostomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ async def reconnect(self, username=None, password=None):
else:
logger.error('All connections attempts failed.')

def close(self):
self._protocol.close()

def connection_lost(self, exc):
self._connected = False
asyncio.ensure_future(self.reconnect())
Expand Down Expand Up @@ -155,7 +158,16 @@ def __init__(self, frame_handler,
self._connect_headers['passcode'] = password

def close(self):
pass
self._transport = None

if self.heartbeater:
self.heartbeater.shutdown()
self.heartbeater = None

if self._task_handler:
self._task_handler.cancel()

self._task_handler = None

def connect(self):
buf = self._protocol.build_frame(
Expand Down Expand Up @@ -266,7 +278,7 @@ def data_received(self, data):
self._waiter.set_result(None)

def eof_received(self):
pass
self.connection_lost(Exception('Got EOF from server'))

async def start(self):
loop = self._loop
Expand Down Expand Up @@ -328,6 +340,9 @@ async def connect(self, username=None, password=None):
self._transport = trans
self._protocol = proto

def close(self):
self._protocol.close()

def subscribe(self, subscription):
headers = {
'id': subscription.id,
Expand Down
58 changes: 56 additions & 2 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ async def test_connection_can_be_made(self, connect_mock):

connect_mock.assert_called_once()

@unittest_run_loop
async def test_connection_can_be_lost(self):
def test_connection_can_be_lost(self):
frame_handler = Mock()
heartbeater = Mock()

Expand All @@ -44,6 +43,48 @@ async def test_connection_can_be_lost(self):
heartbeater.shutdown.assert_called_once()
frame_handler.connection_lost.assert_called_with(exc)

def test_connection_can_be_lost_no_heartbeat(self):
frame_handler = Mock()
heartbeater = Mock()

stomp = StompReader(frame_handler, self.loop)
stomp.heartbeater = None
exc = Exception()

stomp.connection_lost(exc)

heartbeater.shutdown.assert_not_called()
frame_handler.connection_lost.assert_called_with(exc)

def test_can_close_connection(self):
frame_handler = Mock()
heartbeater = Mock()

stomp = StompReader(frame_handler, self.loop)
stomp.heartbeater = heartbeater

stomp.close()

heartbeater.shutdown.assert_called_once()

def test_can_close_connection_no_heartbeat(self):
frame_handler = Mock()
heartbeater = Mock()

stomp = StompReader(frame_handler, self.loop)
stomp.heartbeater = None

stomp.close()

heartbeater.shutdown.assert_not_called()

@patch('aiostomp.aiostomp.StompReader.connection_lost')
def test_can_receive_eof(self, connection_lost_mock):
stomp = StompReader(None, self.loop)
stomp.eof_received()

connection_lost_mock.assert_called_once()

@unittest_run_loop
async def test_send_frame_can_raise_error(self):
stomp = StompReader(None, self.loop)
Expand Down Expand Up @@ -426,6 +467,12 @@ async def test_can_reconnect_on_connection_lost(self):

self.stomp.reconnect.assert_called_once()

@patch('aiostomp.aiostomp.StompProtocol.close')
def test_can_close_connection(self, close_mock):
self.stomp.close()

close_mock.assert_called_once()

def test_can_subscribe(self):
self.stomp._protocol.subscribe = Mock()

Expand Down Expand Up @@ -576,6 +623,13 @@ async def test_can_create_a_connection(self):
self.protocol._factory, host='127.0.0.1', port=61613,
ssl=None)

@unittest_run_loop
async def test_can_close(self):
await self.protocol.connect()
self.protocol.close()

self._protocol.close.assert_called_once()

@unittest_run_loop
async def test_can_create_a_connection_with_ssl_context(self):
ssl_context = ssl.create_default_context()
Expand Down
14 changes: 14 additions & 0 deletions tests/test_recv_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ def test_on_decode_error_show_string(self):
with self.assertRaises(UnicodeDecodeError):
self.protocol._decode(data)

def test_can_reset(self):
self.protocol.feed_data(
b'CONNECT\n'
b'accept-version:1.0\n\n\x00'
)

self.assertEqual(len(self.protocol._pending_parts), 0)
self.assertEqual(len(self.protocol._frames_ready), 1)

self.protocol.reset()

self.assertEqual(len(self.protocol._pending_parts), 0)
self.assertEqual(len(self.protocol._frames_ready), 0)

def test_single_packet(self):
self.protocol.feed_data(
b'CONNECT\n'
Expand Down

0 comments on commit c90e567

Please sign in to comment.