From 2fe6e862bc45faa6e667f0e5d2ad5d4c93efc60b Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 5 Feb 2024 15:45:01 -0500 Subject: [PATCH] Properly handle NAK frames and implement retries (#610) * Retry transmits * Handle retransmitted NCP frames * Use named constants * Drop log level * Use `ConnectionResetError` instead of `RuntimeError` * Implement dynamic ACK timeout * Fix _rec_seq in tests * Reduce UART concurrency to 1 * Add unit tests --------- Co-authored-by: David Mulcahey --- bellows/ezsp/__init__.py | 2 +- bellows/ezsp/protocol.py | 2 +- bellows/uart.py | 88 +++++++++++++++++++++++++++---- tests/test_uart.py | 111 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 192 insertions(+), 11 deletions(-) diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index fa4aab0b..1872a740 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -41,7 +41,7 @@ NETWORK_OPS_TIMEOUT = 10 NETWORK_COORDINATOR_STARTUP_RESET_WAIT = 1 -MAX_COMMAND_CONCURRENCY = 4 +MAX_COMMAND_CONCURRENCY = 1 class EZSP: diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index c7b08054..9518ba89 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -17,7 +17,7 @@ from bellows.typing import GatewayType LOGGER = logging.getLogger(__name__) -EZSP_CMD_TIMEOUT = 5 +EZSP_CMD_TIMEOUT = 6 # Sum of all ASH retry timeouts: 0.4 + 0.8 + 1.6 + 3.2 class ProtocolHandler(abc.ABC): diff --git a/bellows/uart.py b/bellows/uart.py index 73bad18b..bf068a3e 100644 --- a/bellows/uart.py +++ b/bellows/uart.py @@ -2,6 +2,7 @@ import binascii import logging import sys +import time if sys.version_info[:2] < (3, 11): from async_timeout import timeout as asyncio_timeout # pragma: no cover @@ -17,6 +18,12 @@ LOGGER = logging.getLogger(__name__) RESET_TIMEOUT = 5 +ASH_ACK_RETRIES = 4 + +ASH_RX_ACK_INIT = 1.6 +ASH_RX_ACK_MIN = 0.4 +ASH_RX_ACK_MAX = 3.2 + class Gateway(asyncio.Protocol): FLAG = b"\x7E" # Marks end of frame @@ -47,6 +54,7 @@ def __init__(self, application, connected_future=None, connection_done_future=No self._connection_done_future = connection_done_future self._send_task = None + self._ack_timeout = ASH_RX_ACK_INIT def connection_made(self, transport): """Callback when the uart is connected""" @@ -118,10 +126,18 @@ def data_frame_received(self, data): """Data frame receive handler""" LOGGER.debug("Data frame: %s", binascii.hexlify(data)) seq = (data[0] & 0b01110000) >> 4 - self._rec_seq = (seq + 1) % 8 - self.write(self._ack_frame()) - self._handle_ack(data[0]) - self._application.frame_received(self._randomize(data[1:-3])) + re_tx = (data[0] & 0b00001000) >> 3 + + if seq == self._rec_seq: + self._rec_seq = (seq + 1) % 8 + self.write(self._ack_frame()) + + self._handle_ack(data[0]) + self._application.frame_received(self._randomize(data[1:-3])) + elif re_tx: + self.write(self._ack_frame()) + else: + self.write(self._nak_frame()) def ack_frame_received(self, data): """Acknowledgement frame receive handler""" @@ -268,13 +284,67 @@ async def _send_loop(self): if item is self.Terminator: break data, seq = item - success = False - rxmit = 0 - while not success: + + for attempt in range(ASH_ACK_RETRIES + 1): self._pending = (seq, asyncio.get_event_loop().create_future()) + + send_time = time.monotonic() + rxmit = attempt > 0 self.write(self._data_frame(data, seq, rxmit)) - rxmit = 1 - success = await self._pending[1] + + try: + async with asyncio_timeout(self._ack_timeout): + success = await self._pending[1] + except asyncio.TimeoutError: + success = None + LOGGER.debug( + "Frame %s (seq %s) timed out on attempt %d, retrying", + data, + seq, + attempt, + ) + else: + if success: + break + + LOGGER.debug( + "Frame %s (seq %s) failed to transmit on attempt %d, retrying", + data, + seq, + attempt, + ) + finally: + delta = time.monotonic() - send_time + + if success is not None: + new_ack_timeout = max( + ASH_RX_ACK_MIN, + min( + ASH_RX_ACK_MAX, + (7 / 8) * self._ack_timeout + 0.5 * delta, + ), + ) + else: + new_ack_timeout = max( + ASH_RX_ACK_MIN, min(ASH_RX_ACK_MAX, 2 * self._ack_timeout) + ) + + if abs(self._ack_timeout - new_ack_timeout) > 0.01: + LOGGER.debug( + "Adjusting ACK timeout from %.2f to %.2f", + self._ack_timeout, + new_ack_timeout, + ) + + self._ack_timeout = new_ack_timeout + self._pending = (-1, None) + else: + self.connection_lost( + ConnectionResetError( + f"Failed to transmit ASH frame after {ASH_ACK_RETRIES} retries" + ) + ) + return def _handle_ack(self, control): """Handle an acknowledgement frame""" diff --git a/tests/test_uart.py b/tests/test_uart.py index 73f43e46..1dc5574c 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -194,6 +194,7 @@ def test_substitute_received(gw): def test_partial_data_received(gw): gw.write = MagicMock() + gw._rec_seq = 5 gw.data_received(b"\x54\x79\xa1\xb0") gw.data_received(b"\x50\xf2\x6e\x7e") assert gw.write.call_count == 1 @@ -209,6 +210,7 @@ def test_crc_error(gw): def test_crc_error_and_valid_frame(gw): gw.write = MagicMock() + gw._rec_seq = 5 gw.data_received( b"L\xa1\x8e\x03\xcd\x07\xb9Y\xfbG%\xae\xbd~\x54\x79\xa1\xb0\x50\xf2\x6e\x7e" ) @@ -218,6 +220,7 @@ def test_crc_error_and_valid_frame(gw): def test_data_frame_received(gw): gw.write = MagicMock() + gw._rec_seq = 5 gw.data_received(b"\x54\x79\xa1\xb0\x50\xf2\x6e\x7e") assert gw.write.call_count == 1 assert gw._application.frame_received.call_count == 1 @@ -416,3 +419,111 @@ async def test_wait_for_startup_reset_failure(gw): await asyncio.wait_for(gw.wait_for_startup_reset(), 0.01) assert gw._startup_reset_future is None + + +ASH_ACK_MIN = 0.01 + + +@patch("bellows.uart.ASH_RX_ACK_MIN", new=ASH_ACK_MIN * 2**0) +@patch("bellows.uart.ASH_RX_ACK_INIT", new=ASH_ACK_MIN * 2**2) +@patch("bellows.uart.ASH_RX_ACK_MAX", new=ASH_ACK_MIN * 2**3) +async def test_retry_success(): + app = MagicMock() + transport = MagicMock() + connected_future = asyncio.get_running_loop().create_future() + + gw = uart.Gateway(app, connected_future) + gw.connection_made(transport) + + old_timeout = gw._ack_timeout + gw.data(b"TX 1") + await asyncio.sleep(0) + + # Wait more than one ACK cycle to reply + assert len(transport.write.mock_calls) == 1 + await asyncio.sleep(ASH_ACK_MIN * 5) + + # The gateway has retried once by now + assert len(transport.write.mock_calls) == 2 + + gw.frame_received( + # ash.DataFrame(frm_num=0, re_tx=0, ack_num=1, ezsp_frame=b"RX 1").to_bytes() + bytes.fromhex("01107988654851") + ) + + # An ACK has been received and the pending frame has been acknowledged + await asyncio.sleep(0) + assert gw._pending == (-1, None) + + assert gw._ack_timeout > old_timeout + + gw.close() + + +@patch("bellows.uart.ASH_RX_ACK_MIN", new=ASH_ACK_MIN * 2**0) +@patch("bellows.uart.ASH_RX_ACK_INIT", new=ASH_ACK_MIN * 2**2) +@patch("bellows.uart.ASH_RX_ACK_MAX", new=ASH_ACK_MIN * 2**3) +async def test_retry_nak_then_success(): + app = MagicMock() + transport = MagicMock() + connected_future = asyncio.get_running_loop().create_future() + + gw = uart.Gateway(app, connected_future) + gw.connection_made(transport) + + old_timeout = gw._ack_timeout + gw.data(b"TX 1") + await asyncio.sleep(0) + assert len(transport.write.mock_calls) == 1 + + # Wait less than one ACK cycle so that we can NAK the frame during the RX window + await asyncio.sleep(ASH_ACK_MIN) + # NAK the frame + gw.frame_received( + # ash.NakFrame(res=0, ncp_ready=0, ack_num=0).to_bytes() + bytes.fromhex("a0541a") + ) + + # The gateway has retried once more, instantly + await asyncio.sleep(0) + assert len(transport.write.mock_calls) == 2 + + # Send a proper ACK + gw.frame_received( + # ash.AckFrame(res=0, ncp_ready=0, ack_num=1).to_bytes() + bytes.fromhex("816059") + ) + await asyncio.sleep(0) + assert gw._pending == (-1, None) + assert gw._ack_timeout < old_timeout + + gw.close() + + +@patch("bellows.uart.ASH_RX_ACK_MIN", new=ASH_ACK_MIN * 2**0) +@patch("bellows.uart.ASH_RX_ACK_INIT", new=ASH_ACK_MIN * 2**2) +@patch("bellows.uart.ASH_RX_ACK_MAX", new=ASH_ACK_MIN * 2**3) +async def test_retry_failure(): + app = MagicMock() + transport = MagicMock() + connected_future = asyncio.get_running_loop().create_future() + + gw = uart.Gateway(app, connected_future) + gw.connection_made(transport) + + old_timeout = gw._ack_timeout + gw.data(b"TX 1") + await asyncio.sleep(0) + + # Wait more than one ACK cycle to reply + assert len(transport.write.mock_calls) == 1 + await asyncio.sleep(ASH_ACK_MIN * 40) + + # The gateway has exhausted retries + assert len(transport.write.mock_calls) == 5 + + assert gw._pending == (-1, None) + assert gw._ack_timeout > old_timeout + assert gw._ack_timeout == ASH_ACK_MIN * 2**3 # max timeout + + gw.close()