Skip to content

Commit

Permalink
Merge pull request pedrokiefer#6 from krishna-kashyap/master
Browse files Browse the repository at this point in the history
Improved frame parsing and added protocol compliance checks
  • Loading branch information
pedrokiefer committed May 29, 2018
2 parents 2f7e00f + 65c60f9 commit 93beebd
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 9 deletions.
19 changes: 14 additions & 5 deletions aiostomp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, log_name='StompProtocol'):
def _decode(self, byte_data):
try:
if isinstance(byte_data, (bytes, bytearray)):
return byte_data.decode('latin-1')
return byte_data.decode('utf-8')

return byte_data
except UnicodeDecodeError:
Expand Down Expand Up @@ -66,12 +66,21 @@ def _feed_data(self, data):
return after_eof

def _process_frame(self, data):
data = self._decode(data)
command, remaing = data.split('\n', 1)
command, remaining = data.split(b'\n', 1)
command = self._decode(command)

raw_headers, remaing = remaing.split('\n\n', 1)
raw_headers, remaining = remaining.split(b'\n\n', 1)
raw_headers = self._decode(raw_headers)
headers = dict([l.split(':', 1) for l in raw_headers.split('\n')])
body = remaing.encode('latin-1') if remaing else None

body = None

# Only SEND, MESSAGE and ERROR frames can have body
if remaining and command in ('SEND', 'MESSAGE', 'ERROR'):
if 'content-length' in headers:
body = remaining[:int(headers['content-length'])]
else:
body = remaining

self._frames_ready.append(Frame(command, headers=headers, body=body))

Expand Down
47 changes: 43 additions & 4 deletions tests/test_recv_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,23 @@ def test_single_packet(self):

self.assertEqual(self.protocol._pending_parts, [])

def test_parcial_packet(self):
def test_no_body_command_packet(self):
self.protocol.feed_data(
b'CONNECT\n'
b'accept-version:1.0\n\n'
b'Hey dude\x00',
)

frames = self.protocol.pop_frames()

self.assertEqual(len(frames), 1)
self.assertEqual(frames[0].command, u'CONNECT')
self.assertEqual(frames[0].headers, {u'accept-version': u'1.0'})
self.assertEqual(frames[0].body, None)

self.assertEqual(self.protocol._pending_parts, [])

def test_partial_packet(self):
stream_data = (
b'CONNECT\n',
b'accept-version:1.0\n\n\x00',
Expand All @@ -75,7 +91,7 @@ def test_parcial_packet(self):
self.assertEqual(frames[0].headers, {u'accept-version': u'1.0'})
self.assertEqual(frames[0].body, None)

def test_multi_parcial_packet1(self):
def test_multi_partial_packet1(self):
stream_data = (
b'CONNECT\n',
b'accept-version:1.0\n\n\x00\n',
Expand Down Expand Up @@ -103,7 +119,30 @@ def test_multi_parcial_packet1(self):

self.assertEqual(self.protocol._pending_parts, [])

def test_multi_parcial_packet2(self):
def test_read_content_by_length(self):
stream_data = (
b'ERROR\n',
b'header:1.0\n',
b'content-length:3\n\n'
b'Hey dude\x00\n',
)

for data in stream_data:
self.protocol.feed_data(data)

frames = self.protocol.pop_frames()
self.assertEqual(len(frames), 2)

self.assertEqual(frames[0].command, u'ERROR')
self.assertEqual(frames[0].headers, {u'header': u'1.0',
u'content-length': u'3'})
self.assertEqual(frames[0].body.decode(), u'Hey')

self.assertEqual(frames[1].command, u'HEARTBEAT')

self.assertEqual(self.protocol._pending_parts, [])

def test_multi_partial_packet2(self):
stream_data = (
b'CONNECTED\n'
b'version:1.0\n\n',
Expand Down Expand Up @@ -132,7 +171,7 @@ def test_multi_parcial_packet2(self):

self.assertEqual(self.protocol._pending_parts, [])

def test_multi_parcial_packet_with_utf8(self):
def test_multi_partial_packet_with_utf8(self):
stream_data = (
b'CONNECTED\n'
b'accept-version:1.0\n\n',
Expand Down

0 comments on commit 93beebd

Please sign in to comment.