Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
- Updated protocol handlers to more reliably remove active waiters when task cancellation occurs
- Fixed checks where expecting a KeyError when it should be checking if not None
- Updated next_packet_id property to correctly check if there are any packet_ids available. Avoids infinite loop if all packet ids are used.
  • Loading branch information
pazzarpj authored and FlorianLudwig committed Feb 27, 2023
1 parent 1a2812c commit d0eb64d
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 35 deletions.
35 changes: 20 additions & 15 deletions amqtt/mqtt/protocol/client_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,18 @@ async def mqtt_subscribe(self, topics, packet_id):
# Wait for SUBACK is received
waiter = futures.Future()
self._subscriptions_waiter[subscribe.variable_header.packet_id] = waiter
return_codes = await waiter

del self._subscriptions_waiter[subscribe.variable_header.packet_id]
try:
return_codes = await waiter
finally:
del self._subscriptions_waiter[subscribe.variable_header.packet_id]
return return_codes

async def handle_suback(self, suback: SubackPacket):
packet_id = suback.variable_header.packet_id
try:
waiter = self._subscriptions_waiter.get(packet_id)
waiter = self._subscriptions_waiter.get(packet_id)
if waiter is not None:
waiter.set_result(suback.payload.return_codes)
except KeyError:
else:
self.logger.warning(
"Received SUBACK for unknown pending subscription with Id: %s"
% packet_id
Expand All @@ -132,15 +133,17 @@ async def mqtt_unsubscribe(self, topics, packet_id):
await self._send_packet(unsubscribe)
waiter = futures.Future()
self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] = waiter
await waiter
del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id]
try:
await waiter
finally:
del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id]

async def handle_unsuback(self, unsuback: UnsubackPacket):
packet_id = unsuback.variable_header.packet_id
try:
waiter = self._unsubscriptions_waiter.get(packet_id)
waiter = self._unsubscriptions_waiter.get(packet_id)
if waiter is not None:
waiter.set_result(None)
except KeyError:
else:
self.logger.warning(
"Received UNSUBACK for unknown pending subscription with Id: %s"
% packet_id
Expand All @@ -152,10 +155,12 @@ async def mqtt_disconnect(self):

async def mqtt_ping(self):
ping_packet = PingReqPacket()
await self._send_packet(ping_packet)
resp = await self._pingresp_queue.get()
if self._ping_task:
self._ping_task = None
try:
await self._send_packet(ping_packet)
resp = await self._pingresp_queue.get()
finally:
if self._ping_task:
self._ping_task = None
return resp

async def handle_pingresp(self, pingresp: PingRespPacket):
Expand Down
33 changes: 19 additions & 14 deletions amqtt/mqtt/protocol/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,13 @@ async def _handle_qos1_message_flow(self, app_message):
# Wait for puback
waiter = asyncio.Future()
self._puback_waiters[app_message.packet_id] = waiter
await waiter
del self._puback_waiters[app_message.packet_id]
app_message.puback_packet = waiter.result()

# Discard inflight message
del self.session.inflight_out[app_message.packet_id]
try:
await waiter
app_message.puback_packet = waiter.result()
finally:
self._puback_waiters.pop(app_message.packet_id, None)
# Discard inflight message
self.session.inflight_out.pop(app_message.packet_id, None)
elif app_message.direction == INCOMING:
# Initiate delivery
self.logger.debug("Add message to delivery")
Expand Down Expand Up @@ -351,21 +352,25 @@ async def _handle_qos2_message_flow(self, app_message):
raise AMQTTException(message)
waiter = asyncio.Future()
self._pubrec_waiters[app_message.packet_id] = waiter
await waiter
del self._pubrec_waiters[app_message.packet_id]
app_message.pubrec_packet = waiter.result()
try:
await waiter
app_message.pubrec_packet = waiter.result()
finally:
self._pubrec_waiters.pop(app_message.packet_id, None)
self.session.inflight_out.pop(app_message.packet_id, None)
if not app_message.pubcomp_packet:
# Send pubrel
app_message.pubrel_packet = PubrelPacket.build(app_message.packet_id)
await self._send_packet(app_message.pubrel_packet)
# Wait for PUBCOMP
waiter = asyncio.Future()
self._pubcomp_waiters[app_message.packet_id] = waiter
await waiter
del self._pubcomp_waiters[app_message.packet_id]
app_message.pubcomp_packet = waiter.result()
# Discard inflight message
del self.session.inflight_out[app_message.packet_id]
try:
await waiter
app_message.pubcomp_packet = waiter.result()
finally:
self._pubcomp_waiters.pop(app_message.packet_id, None)
self.session.inflight_out.pop(app_message.packet_id, None)
elif app_message.direction == INCOMING:
self.session.inflight_in[app_message.packet_id] = app_message
# Send pubrec
Expand Down
11 changes: 5 additions & 6 deletions amqtt/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,15 @@ def _init_states(self):

@property
def next_packet_id(self):
self._packet_id += 1
if self._packet_id > 65535:
self._packet_id = 1
self._packet_id = (self._packet_id % 65535) + 1
limit = self._packet_id
while (
self._packet_id in self.inflight_in or self._packet_id in self.inflight_out
):
self._packet_id += 1
if self._packet_id > 65535:
self._packet_id = (self._packet_id % 65535) + 1
if self._packet_id == limit:
raise AMQTTException(
"More than 65525 messages pending. No free packet ID"
"More than 65535 messages pending. No free packet ID"
)

return self._packet_id
Expand Down
77 changes: 77 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,80 @@ async def test_deliver_timeout():
await client.unsubscribe(["$SYS/broker/uptime"])
await client.disconnect()
await broker.shutdown()


@pytest.mark.asyncio
async def test_cancel_publish_qos1():
"""
Tests that timeouts on published messages will clean up in flight messages
"""
data = b"data"
broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins")
await broker.start()
client_pub = MQTTClient()
await client_pub.connect("mqtt:https://127.0.0.1/")
assert client_pub.session.inflight_out_count == 0
fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_1))
assert len(client_pub._handler._puback_waiters) == 0
while len(client_pub._handler._puback_waiters) == 0 or fut.done():
await asyncio.sleep(0)
assert len(client_pub._handler._puback_waiters) == 1
assert client_pub.session.inflight_out_count == 1
fut.cancel()
await asyncio.wait([fut])
assert len(client_pub._handler._puback_waiters) == 0
assert client_pub.session.inflight_out_count == 0
await client_pub.disconnect()
await broker.shutdown()


@pytest.mark.asyncio
async def test_cancel_publish_qos2_pubrec():
"""
Tests that timeouts on published messages will clean up in flight messages
"""
data = b"data"
broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins")
await broker.start()
client_pub = MQTTClient()
await client_pub.connect("mqtt:https://127.0.0.1/")
assert client_pub.session.inflight_out_count == 0
fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_2))
assert len(client_pub._handler._pubrec_waiters) == 0
while (
len(client_pub._handler._pubrec_waiters) == 0 or fut.done() or fut.cancelled()
):
await asyncio.sleep(0)
assert len(client_pub._handler._pubrec_waiters) == 1
assert client_pub.session.inflight_out_count == 1
fut.cancel()
await asyncio.sleep(1)
await asyncio.wait([fut])
assert len(client_pub._handler._pubrec_waiters) == 0
assert client_pub.session.inflight_out_count == 0
await client_pub.disconnect()
await broker.shutdown()


@pytest.mark.asyncio
async def test_cancel_publish_qos2_pubcomp():
"""
Tests that timeouts on published messages will clean up in flight messages
"""
data = b"data"
broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins")
await broker.start()
client_pub = MQTTClient()
await client_pub.connect("mqtt:https://127.0.0.1/")
assert client_pub.session.inflight_out_count == 0
fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_2))
assert len(client_pub._handler._pubcomp_waiters) == 0
while len(client_pub._handler._pubcomp_waiters) == 0 or fut.done():
await asyncio.sleep(0)
assert len(client_pub._handler._pubcomp_waiters) == 1
fut.cancel()
await asyncio.wait([fut])
assert len(client_pub._handler._pubcomp_waiters) == 0
assert client_pub.session.inflight_out_count == 0
await client_pub.disconnect()
await broker.shutdown()

0 comments on commit d0eb64d

Please sign in to comment.