diff --git a/bellows/__init__.py b/bellows/__init__.py index 0386d076..62768259 100644 --- a/bellows/__init__.py +++ b/bellows/__init__.py @@ -1,5 +1,5 @@ MAJOR_VERSION = 0 MINOR_VERSION = 35 -PATCH_VERSION = "2" +PATCH_VERSION = "3" __short_version__ = f"{MAJOR_VERSION}.{MINOR_VERSION}" __version__ = f"{__short_version__}.{PATCH_VERSION}" diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index 0e5a36a9..25378b2e 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -3,10 +3,12 @@ from __future__ import annotations import asyncio +import collections +import contextlib import functools import logging import sys -from typing import Any, Callable +from typing import Any, Callable, Generator import urllib.parse if sys.version_info[:2] < (3, 11): @@ -55,6 +57,42 @@ def __init__(self, device_config: dict): self._gw = None self._protocol = None + self._stack_status_listeners: collections.defaultdict[ + t.EmberStatus, list[asyncio.Future] + ] = collections.defaultdict(list) + + self.add_callback(self.stack_status_callback) + + def stack_status_callback(self, frame_name: str, args: list[Any]) -> None: + """Callback for `stackStatusHandler` messages.""" + if frame_name != "stackStatusHandler": + return + + status = args[0] + + for listener in self._stack_status_listeners[status]: + listener.set_result(status) + + @contextlib.contextmanager + def wait_for_stack_status(self, status: t.EmberStatus) -> Generator[asyncio.Future]: + """Waits for a `stackStatusHandler` to come in with the provided status.""" + listeners = self._stack_status_listeners[status] + + future = asyncio.get_running_loop().create_future() + + @future.add_done_callback + def maybe_remove(_): + with contextlib.suppress(ValueError): + listeners.remove(future) + + listeners.append(future) + + try: + yield future + finally: + with contextlib.suppress(ValueError): + listeners.remove(future) + @classmethod async def probe(cls, device_config: dict) -> bool | dict[str, int | str | bool]: """Probe port for the device presence.""" @@ -221,27 +259,17 @@ def cb(frame_name, response): 0, ) - async def leaveNetwork(self, timeout: float | int = NETWORK_OPS_TIMEOUT) -> list: + async def leaveNetwork(self, timeout: float | int = NETWORK_OPS_TIMEOUT) -> None: """Send leaveNetwork command and wait for stackStatusHandler frame.""" stack_status = asyncio.Future() - def cb(frame_name: str, response: list) -> None: - if ( - frame_name == "stackStatusHandler" - and response[0] == t.EmberStatus.NETWORK_DOWN - ): - stack_status.set_result(response) - - cb_id = self.add_callback(cb) - try: + with self.wait_for_stack_status(t.EmberStatus.NETWORK_DOWN) as stack_status: (status,) = await self._command("leaveNetwork") if status != t.EmberStatus.SUCCESS: raise EzspError(f"failed to leave network: {status.name}") async with asyncio_timeout(timeout): - return await stack_status - finally: - self.remove_callback(cb_id) + await stack_status def connection_lost(self, exc): """Lost serial connection.""" @@ -254,7 +282,7 @@ def connection_lost(self, exc): def enter_failed_state(self, error): """UART received error frame.""" - if self._callbacks: + if len(self._callbacks) > 1: LOGGER.error("NCP entered failed state. Requesting APP controller restart") self.close() self.handle_callback("_reset_controller_application", (error,)) @@ -269,28 +297,15 @@ def __getattr__(self, name: str) -> Callable: return functools.partial(self._command, name) - async def formNetwork(self, parameters): # noqa: N802 - fut = asyncio.Future() - - def cb(frame_name, response): - nonlocal fut - if frame_name == "stackStatusHandler": - fut.set_result(response) - - cb_id = self.add_callback(cb) - - try: + async def formNetwork(self, parameters: t.EmberNetworkParameters) -> None: + with self.wait_for_stack_status(t.EmberStatus.NETWORK_UP) as stack_status: v = await self._command("formNetwork", parameters) - if v[0] != self.types.EmberStatus.SUCCESS: - raise Exception(f"Failure forming network: {v}") - v = await fut - if v[0] != self.types.EmberStatus.NETWORK_UP: - raise Exception(f"Failure forming network: {v}") + if v[0] != self.types.EmberStatus.SUCCESS: + raise zigpy.exceptions.FormationFailure(f"Failure forming network: {v}") - return v - finally: - self.remove_callback(cb_id) + async with asyncio_timeout(NETWORK_OPS_TIMEOUT): + await stack_status def frame_received(self, data: bytes) -> None: """Handle a received EZSP frame diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index 007b8798..b5546949 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -15,7 +15,7 @@ import zigpy.config import zigpy.device import zigpy.endpoint -from zigpy.exceptions import FormationFailure, NetworkNotFormed +from zigpy.exceptions import NetworkNotFormed import zigpy.state import zigpy.types import zigpy.util @@ -54,6 +54,7 @@ EZSP_MULTICAST_NON_MEMBER_RADIUS = 3 MFG_ID_RESET_DELAY = 180 RESET_ATTEMPT_BACKOFF_TIME = 5 +NETWORK_UP_TIMEOUT_S = 10 WATCHDOG_WAKE_PERIOD = 10 IEEE_PREFIX_MFG_ID = { "04:CF:8C": 0x115F, # Xiaomi @@ -147,13 +148,18 @@ async def _ensure_network_running(self) -> bool: if state == self._ezsp.types.EmberNetworkStatus.JOINED_NETWORK: return False - (init_status,) = await self._ezsp.networkInit() - if init_status == t.EmberStatus.SUCCESS: - return True - elif init_status == t.EmberStatus.NOT_JOINED: - raise NetworkNotFormed("Node is not part of a network") - else: - raise ControllerError(f"Failed to initialize network: {init_status!r}") + with self._ezsp.wait_for_stack_status(t.EmberStatus.NETWORK_UP) as stack_status: + (init_status,) = await self._ezsp.networkInit() + + if init_status == t.EmberStatus.NOT_JOINED: + raise NetworkNotFormed("Node is not part of a network") + elif init_status != t.EmberStatus.SUCCESS: + raise ControllerError(f"Failed to initialize network: {init_status!r}") + + async with asyncio_timeout(NETWORK_UP_TIMEOUT_S): + await stack_status + + return True async def start_network(self): ezsp = self._ezsp @@ -387,10 +393,6 @@ async def write_network_info( (status,) = await ezsp.setInitialSecurityState(initial_security_state) assert status == t.EmberStatus.SUCCESS - # Clear the key table - (status,) = await ezsp.clearKeyTable() - assert status == t.EmberStatus.SUCCESS - # Write APS link keys for key in network_info.key_table: ember_key = util.zigpy_key_to_ezsp_key(key, ezsp) @@ -429,7 +431,6 @@ async def write_network_info( parameters.channels = t.Channels(network_info.channel_mask) await ezsp.formNetwork(parameters) - await ezsp.setValue(ezsp.types.EzspValueId.VALUE_STACK_TOKEN_WRITING, 1) async def reset_network_info(self): # The network must be running before we can leave it @@ -438,13 +439,11 @@ async def reset_network_info(self): except zigpy.exceptions.NetworkNotFormed: return - try: - (status,) = await self._ezsp.leaveNetwork() - except bellows.exception.EzspError: - pass - else: - if status != t.EmberStatus.NETWORK_DOWN: - raise FormationFailure("Couldn't leave network") + await self._ezsp.leaveNetwork() + + # Clear the key table + (status,) = await self._ezsp.clearKeyTable() + assert status == t.EmberStatus.SUCCESS async def disconnect(self): # TODO: how do you shut down the stack? @@ -685,14 +684,19 @@ async def energy_scan( all_results = {} for _ in range(count): - results = await self._ezsp.startScan( - t.EzspNetworkScanType.ENERGY_SCAN, - channels, - duration_exp, - ) + channels_to_scan = set(channels) + + # XXX: RCP firmware sometimes performs a partial scan and returns early + while channels_to_scan: + results = await self._ezsp.startScan( + t.EzspNetworkScanType.ENERGY_SCAN, + t.Channels.from_channel_list(channels_to_scan), + duration_exp, + ) - for channel, rssi in results: - all_results.setdefault(channel, []).append(rssi) + for channel, rssi in results: + all_results.setdefault(channel, []).append(rssi) + channels_to_scan.remove(channel) # Remap RSSI to Energy return { diff --git a/tests/test_application.py b/tests/test_application.py index 5ccd17f6..35dd617f 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -20,6 +20,7 @@ import bellows.uart as uart import bellows.zigbee.application import bellows.zigbee.device +from bellows.zigbee.util import map_rssi_to_energy from .async_mock import AsyncMock, MagicMock, PropertyMock, patch, sentinel @@ -35,21 +36,25 @@ @pytest.fixture def ezsp_mock(): """EZSP fixture""" - ezsp = MagicMock() - ezsp.ezsp_version = 7 - ezsp.setManufacturerCode = AsyncMock() - ezsp.set_source_route = AsyncMock(return_value=[t.EmberStatus.SUCCESS]) - ezsp.addTransientLinkKey = AsyncMock(return_value=[0]) - ezsp.readCounters = AsyncMock(return_value=[[0] * 10]) - ezsp.readAndClearCounters = AsyncMock(return_value=[[0] * 10]) - ezsp.setPolicy = AsyncMock(return_value=[0]) - ezsp.get_board_info = AsyncMock( + mock_ezsp = MagicMock(spec=ezsp.EZSP) + mock_ezsp.ezsp_version = 7 + mock_ezsp.setManufacturerCode = AsyncMock() + mock_ezsp.set_source_route = AsyncMock(return_value=[t.EmberStatus.SUCCESS]) + mock_ezsp.addTransientLinkKey = AsyncMock(return_value=[0]) + mock_ezsp.readCounters = AsyncMock(return_value=[[0] * 10]) + mock_ezsp.readAndClearCounters = AsyncMock(return_value=[[0] * 10]) + mock_ezsp.setPolicy = AsyncMock(return_value=[0]) + mock_ezsp.get_board_info = AsyncMock( return_value=("Mock Manufacturer", "Mock board", "Mock version") ) - type(ezsp).types = ezsp_t7 - type(ezsp).is_ezsp_running = PropertyMock(return_value=True) + mock_ezsp.wait_for_stack_status.return_value.__enter__ = AsyncMock( + return_value=t.EmberStatus.NETWORK_UP + ) + + type(mock_ezsp).types = ezsp_t7 + type(mock_ezsp).is_ezsp_running = PropertyMock(return_value=True) - return ezsp + return mock_ezsp @pytest.fixture @@ -559,6 +564,7 @@ def test_sequence(app): def test_permit_ncp(app): + app._ezsp.permitJoining = AsyncMock() app.permit_ncp(60) assert app._ezsp.permitJoining.call_count == 1 @@ -1474,6 +1480,8 @@ def test_handle_id_conflict(app, ieee): async def test_handle_no_such_device(app, ieee): """Test handling of an unknown device IEEE lookup.""" + app._ezsp.lookupEui64ByNodeId = AsyncMock() + p1 = patch.object( app._ezsp, "lookupEui64ByNodeId", @@ -1582,6 +1590,11 @@ async def test_set_mfg_id(ieee, expected_mfg_id, app, ezsp_mock): async def test_ensure_network_running_joined(app): ezsp = app._ezsp + + # Make initialization take two attempts + ezsp.networkInit = AsyncMock( + side_effect=[(t.EmberStatus.NETWORK_BUSY,), (t.EmberStatus.SUCCESS,)] + ) ezsp.networkState = AsyncMock( return_value=[ezsp.types.EmberNetworkStatus.JOINED_NETWORK] ) @@ -1733,5 +1746,29 @@ async def test_energy_scanning(app, scan_results): count=1, ) + assert len(app._ezsp.startScan.mock_calls) == 1 + assert set(results.keys()) == set(t.Channels.ALL_CHANNELS) assert all(0 <= v <= 255 for v in results.values()) + + +async def test_energy_scanning_partial(app): + app._ezsp.startScan = AsyncMock( + side_effect=[ + [(11, 11), (12, 12), (13, 13), (14, 14), (15, 15), (16, 16)], + [(17, 17)], + [], + [(18, 18), (19, 19), (20, 20)], + [(21, 21), (22, 22), (23, 23), (24, 24), (25, 25), (26, 26)], + ] + ) + + results = await app.energy_scan( + channels=t.Channels.ALL_CHANNELS, + duration_exp=2, + count=1, + ) + + assert len(app._ezsp.startScan.mock_calls) == 5 + assert set(results.keys()) == set(t.Channels.ALL_CHANNELS) + assert results == {c: map_rssi_to_energy(c) for c in range(11, 26 + 1)} diff --git a/tests/test_application_network_state.py b/tests/test_application_network_state.py index 108df3b7..4035f90b 100644 --- a/tests/test_application_network_state.py +++ b/tests/test_application_network_state.py @@ -372,23 +372,6 @@ def _mock_app_for_write(app, network_info, node_info, ezsp_ver=None): ezsp.can_write_custom_eui64 = AsyncMock(return_value=True) -async def test_write_network_info_failed_leave1(app, network_info, node_info): - _mock_app_for_write(app, network_info, node_info) - - app._ezsp.leaveNetwork.return_value = [t.EmberStatus.BAD_ARGUMENT] - - with pytest.raises(zigpy.exceptions.FormationFailure): - await app.write_network_info(network_info=network_info, node_info=node_info) - - -async def test_write_network_info_failed_leave2(app, network_info, node_info): - _mock_app_for_write(app, network_info, node_info) - - app._ezsp.leaveNetwork.side_effect = EzspError("failed to leave network") - - await app.write_network_info(network_info=network_info, node_info=node_info) - - @pytest.mark.parametrize("ezsp_ver", [4, 7]) async def test_write_network_info(app, network_info, node_info, ezsp_ver): _mock_app_for_write(app, network_info, node_info, ezsp_ver) diff --git a/tests/test_ezsp.py b/tests/test_ezsp.py index 5616332b..b0700547 100644 --- a/tests/test_ezsp.py +++ b/tests/test_ezsp.py @@ -1,5 +1,6 @@ import asyncio import functools +import sys import pytest @@ -7,6 +8,11 @@ from bellows.exception import EzspError import bellows.ezsp.v4.types as t +if sys.version_info[:2] < (3, 11): + from async_timeout import timeout as asyncio_timeout # pragma: no cover +else: + from asyncio import timeout as asyncio_timeout # pragma: no cover + from .async_mock import AsyncMock, MagicMock, call, patch, sentinel DEVICE_CONFIG = { @@ -48,7 +54,7 @@ async def test_reset(ezsp_f): assert ezsp_f._gw.reset.call_count == 1 assert ezsp_f.start_ezsp.call_count == 1 assert ezsp_f.stop_ezsp.call_count == 1 - assert len(ezsp_f._callbacks) == 0 + assert len(ezsp_f._callbacks) == 1 def test_close(ezsp_f): @@ -487,8 +493,7 @@ async def _mock_cmd(*args, **kwargs): with patch.object(ezsp_f, "_command", new_callable=AsyncMock) as cmd_mock: cmd_mock.side_effect = _mock_cmd - (status,) = await ezsp_f.leaveNetwork(timeout=0.01) - assert status == t.EmberStatus.NETWORK_DOWN + await ezsp_f.leaveNetwork(timeout=0.01) @pytest.mark.parametrize( @@ -571,3 +576,26 @@ async def wait_forever(*args, **kwargs): assert prot_handler_mock.await_count == 1 assert src_mock.call_count == 0 assert src_mock.await_count == 0 + + +async def test_wait_for_stack_status(ezsp_f): + assert not ezsp_f._stack_status_listeners[t.EmberStatus.NETWORK_DOWN] + + # Cancellation clears handlers + with ezsp_f.wait_for_stack_status(t.EmberStatus.NETWORK_DOWN) as stack_status: + with pytest.raises(asyncio.TimeoutError): + async with asyncio_timeout(0.1): + assert ezsp_f._stack_status_listeners[t.EmberStatus.NETWORK_DOWN] + await stack_status + + assert not ezsp_f._stack_status_listeners[t.EmberStatus.NETWORK_DOWN] + + # Receiving multiple also works + with ezsp_f.wait_for_stack_status(t.EmberStatus.NETWORK_DOWN) as stack_status: + ezsp_f.handle_callback("stackStatusHandler", [t.EmberStatus.NETWORK_UP]) + ezsp_f.handle_callback("stackStatusHandler", [t.EmberStatus.NETWORK_DOWN]) + ezsp_f.handle_callback("stackStatusHandler", [t.EmberStatus.NETWORK_DOWN]) + + await stack_status + + assert not ezsp_f._stack_status_listeners[t.EmberStatus.NETWORK_DOWN]