From 1ac01edef70d49dec21694909ddd9e56d5fc07eb Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Wed, 6 Dec 2023 13:10:16 -0500 Subject: [PATCH] Remove UART thread (#598) --- bellows/config/__init__.py | 1 - bellows/ezsp/__init__.py | 6 +- bellows/thread.py | 122 ---------------------- bellows/uart.py | 22 +--- bellows/zigbee/application.py | 9 +- tests/test_thread.py | 188 ---------------------------------- tests/test_uart.py | 112 -------------------- 7 files changed, 6 insertions(+), 454 deletions(-) delete mode 100644 bellows/thread.py delete mode 100644 tests/test_thread.py diff --git a/bellows/config/__init__.py b/bellows/config/__init__.py index 29ffe647..1357fcaf 100644 --- a/bellows/config/__init__.py +++ b/bellows/config/__init__.py @@ -30,7 +30,6 @@ vol.Optional(CONF_EZSP_POLICIES, default={}): vol.Schema( {vol.Optional(str): int} ), - vol.Optional(CONF_USE_THREAD, default=True): cv_boolean, } ) diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index 9bac445c..2402cd79 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -124,7 +124,7 @@ async def startup_reset(self) -> None: async def initialize(cls, zigpy_config: dict) -> EZSP: """Return initialized EZSP instance.""" ezsp = cls(zigpy_config[conf.CONF_DEVICE]) - await ezsp.connect(use_thread=zigpy_config[conf.CONF_USE_THREAD]) + await ezsp.connect() try: await ezsp.startup_reset() @@ -134,9 +134,9 @@ async def initialize(cls, zigpy_config: dict) -> EZSP: return ezsp - async def connect(self, *, use_thread: bool = True) -> None: + async def connect(self) -> None: assert self._gw is None - self._gw = await bellows.uart.connect(self._config, self, use_thread=use_thread) + self._gw = await bellows.uart.connect(self._config, self) self._protocol = v4.EZSPv4(self.handle_callback, self._gw) async def reset(self): diff --git a/bellows/thread.py b/bellows/thread.py deleted file mode 100644 index 6d8c1309..00000000 --- a/bellows/thread.py +++ /dev/null @@ -1,122 +0,0 @@ -import asyncio -from concurrent.futures import ThreadPoolExecutor -import functools -import logging -import sys - -LOGGER = logging.getLogger(__name__) - - -class EventLoopThread: - """Run a parallel event loop in a separate thread.""" - - def __init__(self): - self.loop = None - self.thread_complete = None - - def run_coroutine_threadsafe(self, coroutine): - current_loop = asyncio.get_event_loop() - future = asyncio.run_coroutine_threadsafe(coroutine, self.loop) - return asyncio.wrap_future(future, loop=current_loop) - - def _thread_main(self, init_task): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - try: - self.loop.run_until_complete(init_task) - self.loop.run_forever() - finally: - self.loop.close() - self.loop = None - - async def start(self): - current_loop = asyncio.get_event_loop() - if self.loop is not None and not self.loop.is_closed(): - return - - executor_opts = {"max_workers": 1} - if sys.version_info[:2] >= (3, 6): - executor_opts["thread_name_prefix"] = __name__ - executor = ThreadPoolExecutor(**executor_opts) - - thread_started_future = current_loop.create_future() - - async def init_task(): - current_loop.call_soon_threadsafe(thread_started_future.set_result, None) - - # Use current loop so current loop has a reference to the long-running thread - # as one of its tasks - thread_complete = current_loop.run_in_executor( - executor, self._thread_main, init_task() - ) - self.thread_complete = thread_complete - current_loop.call_soon(executor.shutdown, False) - await thread_started_future - return thread_complete - - def force_stop(self): - if self.loop is None: - return - - def cancel_tasks_and_stop_loop(): - tasks = asyncio.all_tasks(loop=self.loop) - - for task in tasks: - self.loop.call_soon_threadsafe(task.cancel) - - gather = asyncio.gather(*tasks, return_exceptions=True) - gather.add_done_callback( - lambda _: self.loop.call_soon_threadsafe(self.loop.stop) - ) - - self.loop.call_soon_threadsafe(cancel_tasks_and_stop_loop) - - -class ThreadsafeProxy: - """Proxy class which enforces threadsafe non-blocking calls - This class can be used to wrap an object to ensure any calls - using that object's methods are done on a particular event loop - """ - - def __init__(self, obj, obj_loop): - self._obj = obj - self._obj_loop = obj_loop - - def __getattr__(self, name): - func = getattr(self._obj, name) - if not callable(func): - raise TypeError( - "Can only use ThreadsafeProxy with callable attributes: {}.{}".format( - self._obj.__class__.__name__, name - ) - ) - - def func_wrapper(*args, **kwargs): - loop = self._obj_loop - curr_loop = asyncio.get_running_loop() - call = functools.partial(func, *args, **kwargs) - if loop == curr_loop: - return call() - if loop.is_closed(): - # Disconnected - LOGGER.warning("Attempted to use a closed event loop") - return - if asyncio.iscoroutinefunction(func): - future = asyncio.run_coroutine_threadsafe(call(), loop) - return asyncio.wrap_future(future, loop=curr_loop) - else: - - def check_result_wrapper(): - result = call() - if result is not None: - raise TypeError( - ( - "ThreadsafeProxy can only wrap functions with no return" - "value \nUse an async method to return values: {}.{}" - ).format(self._obj.__class__.__name__, name) - ) - - loop.call_soon_threadsafe(check_result_wrapper) - - return func_wrapper diff --git a/bellows/uart.py b/bellows/uart.py index 73bad18b..2acbd1c8 100644 --- a/bellows/uart.py +++ b/bellows/uart.py @@ -11,7 +11,6 @@ import zigpy.config import zigpy.serial -from bellows.thread import EventLoopThread, ThreadsafeProxy import bellows.types as t LOGGER = logging.getLogger(__name__) @@ -364,7 +363,7 @@ def _unstuff(self, s): return out -async def _connect(config, application): +async def connect(config, application): loop = asyncio.get_event_loop() connection_future = loop.create_future() @@ -387,23 +386,4 @@ async def _connect(config, application): await connection_future - thread_safe_protocol = ThreadsafeProxy(protocol, loop) - return thread_safe_protocol, connection_done_future - - -async def connect(config, application, use_thread=True): - if use_thread: - application = ThreadsafeProxy(application, asyncio.get_event_loop()) - thread = EventLoopThread() - await thread.start() - try: - protocol, connection_done = await thread.run_coroutine_threadsafe( - _connect(config, application) - ) - except Exception: - thread.force_stop() - raise - connection_done.add_done_callback(lambda _: thread.force_stop()) - else: - protocol, _ = await _connect(config, application) return protocol diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index 7c85e8ec..0b147283 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -24,12 +24,7 @@ import zigpy.zdo.types as zdo_t import bellows -from bellows.config import ( - CONF_EZSP_CONFIG, - CONF_EZSP_POLICIES, - CONF_USE_THREAD, - CONFIG_SCHEMA, -) +from bellows.config import CONF_EZSP_CONFIG, CONF_EZSP_POLICIES, CONFIG_SCHEMA from bellows.exception import ControllerError, EzspError, StackAlreadyRunning import bellows.ezsp from bellows.ezsp.v8.types.named import EmberDeviceUpdate @@ -138,7 +133,7 @@ async def _get_board_info(self) -> tuple[str, str, str] | tuple[None, None, None async def connect(self) -> None: ezsp = bellows.ezsp.EZSP(self.config[zigpy.config.CONF_DEVICE]) - await ezsp.connect(use_thread=self.config[CONF_USE_THREAD]) + await ezsp.connect() try: await ezsp.startup_reset() diff --git a/tests/test_thread.py b/tests/test_thread.py deleted file mode 100644 index a7c37723..00000000 --- a/tests/test_thread.py +++ /dev/null @@ -1,188 +0,0 @@ -import asyncio -import sys -import threading -from unittest import mock - -if sys.version_info[:2] < (3, 11): - from async_timeout import timeout as asyncio_timeout # pragma: no cover -else: - from asyncio import timeout as asyncio_timeout # pragma: no cover - -import pytest - -from bellows.thread import EventLoopThread, ThreadsafeProxy - - -async def test_thread_start(monkeypatch): - current_loop = asyncio.get_event_loop() - loopmock = mock.MagicMock() - - monkeypatch.setattr(asyncio, "new_event_loop", lambda: loopmock) - monkeypatch.setattr(asyncio, "set_event_loop", lambda loop: None) - - def mockrun(task): - future = asyncio.run_coroutine_threadsafe(task, loop=current_loop) - return future.result(1) - - loopmock.run_until_complete.side_effect = mockrun - thread = EventLoopThread() - thread_complete = await thread.start() - await thread_complete - - assert loopmock.run_until_complete.call_count == 1 - assert loopmock.run_forever.call_count == 1 - assert loopmock.close.call_count == 1 - - -class ExceptionCollector: - def __init__(self): - self.exceptions = [] - - def __call__(self, thread_loop, context): - exc = context.get("exception") or Exception(context["message"]) - self.exceptions.append(exc) - - -@pytest.fixture -async def thread(): - thread = EventLoopThread() - await thread.start() - thread.loop.call_soon_threadsafe( - thread.loop.set_exception_handler, ExceptionCollector() - ) - yield thread - thread.force_stop() - if thread.thread_complete is not None: - async with asyncio_timeout(1): - await thread.thread_complete - [t.join(1) for t in threading.enumerate() if "bellows" in t.name] - threads = [t for t in threading.enumerate() if "bellows" in t.name] - assert len(threads) == 0 - - -async def yield_other_thread(thread): - await thread.run_coroutine_threadsafe(asyncio.sleep(0)) - - exception_collector = thread.loop.get_exception_handler() - if exception_collector.exceptions: - raise exception_collector.exceptions[0] - - -async def test_thread_loop(thread): - async def test_coroutine(): - return mock.sentinel.result - - future = asyncio.run_coroutine_threadsafe(test_coroutine(), loop=thread.loop) - result = await asyncio.wrap_future(future, loop=asyncio.get_event_loop()) - assert result is mock.sentinel.result - - -async def test_thread_double_start(thread): - previous_loop = thread.loop - await thread.start() - if sys.version_info[:2] >= (3, 6): - threads = [t for t in threading.enumerate() if "bellows" in t.name] - assert len(threads) == 1 - assert thread.loop is previous_loop - - -async def test_thread_already_stopped(thread): - thread.force_stop() - thread.force_stop() - - -async def test_thread_run_coroutine_threadsafe(thread): - inner_loop = None - - async def test_coroutine(): - nonlocal inner_loop - inner_loop = asyncio.get_event_loop() - return mock.sentinel.result - - result = await thread.run_coroutine_threadsafe(test_coroutine()) - assert result is mock.sentinel.result - assert inner_loop is thread.loop - - -async def test_proxy_callback(thread): - obj = mock.MagicMock() - proxy = ThreadsafeProxy(obj, thread.loop) - obj.test.return_value = None - proxy.test() - await yield_other_thread(thread) - assert obj.test.call_count == 1 - - -async def test_proxy_async(thread): - obj = mock.MagicMock() - proxy = ThreadsafeProxy(obj, thread.loop) - call_count = 0 - - async def magic(): - nonlocal thread, call_count - assert asyncio.get_event_loop() == thread.loop - call_count += 1 - return mock.sentinel.result - - obj.test = magic - result = await proxy.test() - - assert call_count == 1 - assert result == mock.sentinel.result - - -async def test_proxy_bad_function(thread): - obj = mock.MagicMock() - proxy = ThreadsafeProxy(obj, thread.loop) - obj.test.return_value = mock.sentinel.value - - with pytest.raises(TypeError): - proxy.test() - await yield_other_thread(thread) - - -async def test_proxy_not_function(): - loop = asyncio.get_event_loop() - obj = mock.MagicMock() - proxy = ThreadsafeProxy(obj, loop) - obj.test = mock.sentinel.value - with pytest.raises(TypeError): - proxy.test - - -async def test_proxy_no_thread(): - loop = asyncio.get_event_loop() - obj = mock.MagicMock() - proxy = ThreadsafeProxy(obj, loop) - proxy.test() - assert obj.test.call_count == 1 - - -async def test_proxy_loop_closed(): - loop = asyncio.new_event_loop() - obj = mock.MagicMock() - proxy = ThreadsafeProxy(obj, loop) - loop.close() - proxy.test() - assert obj.test.call_count == 0 - - -async def test_thread_task_cancellation_after_stop(thread): - loop = asyncio.get_event_loop() - obj = mock.MagicMock() - - async def wait_forever(): - return await thread.loop.create_future() - - obj.wait_forever = wait_forever - - # Stop the thread while we're waiting - loop.call_later(0.1, thread.force_stop) - - proxy = ThreadsafeProxy(obj, thread.loop) - - # The cancellation should propagate to the outer event loop - with pytest.raises(asyncio.CancelledError): - # This will stall forever without the patch - async with asyncio_timeout(1): - await proxy.wait_forever() diff --git a/tests/test_uart.py b/tests/test_uart.py index 73f43e46..97d3b4ae 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -1,5 +1,4 @@ import asyncio -import threading import pytest import serial_asyncio @@ -30,113 +29,9 @@ async def mockconnect(loop, protocol_factory, **kwargs): } ), appmock, - use_thread=False, ) - - threads = [t for t in threading.enumerate() if "bellows" in t.name] - assert len(threads) == 0 - gw.close() - - -async def test_connect_threaded(monkeypatch): - appmock = MagicMock() - transport = MagicMock() - - async def mockconnect(loop, protocol_factory, **kwargs): - protocol = protocol_factory() - loop.call_soon(protocol.connection_made, transport) - return None, protocol - - monkeypatch.setattr(serial_asyncio, "create_serial_connection", mockconnect) - - def on_transport_close(): - gw.connection_lost(None) - - transport.close.side_effect = on_transport_close - gw = await uart.connect( - conf.SCHEMA_DEVICE( - {conf.CONF_DEVICE_PATH: "/dev/serial", conf.CONF_DEVICE_BAUDRATE: 115200} - ), - appmock, - ) - - # Need to close to release thread gw.close() - # Ensure all threads are cleaned up - [t.join(1) for t in threading.enumerate() if "bellows" in t.name] - threads = [t for t in threading.enumerate() if "bellows" in t.name] - assert len(threads) == 0 - - -async def test_connect_threaded_failure(monkeypatch): - appmock = MagicMock() - transport = MagicMock() - - mockconnect = AsyncMock() - mockconnect.side_effect = OSError - - monkeypatch.setattr(serial_asyncio, "create_serial_connection", mockconnect) - - def on_transport_close(): - gw.connection_lost(None) - - transport.close.side_effect = on_transport_close - with pytest.raises(OSError): - gw = await uart.connect( - conf.SCHEMA_DEVICE( - { - conf.CONF_DEVICE_PATH: "/dev/serial", - conf.CONF_DEVICE_BAUDRATE: 115200, - } - ), - appmock, - ) - - # Ensure all threads are cleaned up - [t.join(1) for t in threading.enumerate() if "bellows" in t.name] - threads = [t for t in threading.enumerate() if "bellows" in t.name] - assert len(threads) == 0 - - -async def test_connect_threaded_failure_cancellation_propagation(monkeypatch): - appmock = MagicMock() - - async def mock_connect(loop, protocol_factory, *args, **kwargs): - protocol = protocol_factory() - transport = AsyncMock() - - protocol.connection_made(transport) - - return transport, protocol - - with patch("bellows.uart.zigpy.serial.create_serial_connection", mock_connect): - gw = await uart.connect( - conf.SCHEMA_DEVICE( - { - conf.CONF_DEVICE_PATH: "/dev/serial", - conf.CONF_DEVICE_BAUDRATE: 115200, - } - ), - appmock, - use_thread=True, - ) - - # Begin waiting for the startup reset - wait_for_reset = gw.wait_for_startup_reset() - - # But lose connection halfway through - asyncio.get_running_loop().call_later(0.1, gw.connection_lost, RuntimeError()) - - # Cancellation should propagate to the outer loop - with pytest.raises(RuntimeError): - await wait_for_reset - - # Ensure all threads are cleaned up - [t.join(1) for t in threading.enumerate() if "bellows" in t.name] - threads = [t for t in threading.enumerate() if "bellows" in t.name] - assert len(threads) == 0 - @pytest.fixture def gw(): @@ -383,7 +278,6 @@ def on_transport_close(): {conf.CONF_DEVICE_PATH: "/dev/serial", conf.CONF_DEVICE_BAUDRATE: 115200} ), app, - use_thread=False, # required until #484 is merged ) asyncio.get_running_loop().call_later(0.1, gw.connection_lost, ValueError()) @@ -391,14 +285,8 @@ def on_transport_close(): with pytest.raises(ValueError): await gw.reset() - # Need to close to release thread gw.close() - # Ensure all threads are cleaned up - [t.join(1) for t in threading.enumerate() if "bellows" in t.name] - threads = [t for t in threading.enumerate() if "bellows" in t.name] - assert len(threads) == 0 - async def test_wait_for_startup_reset(gw): loop = asyncio.get_running_loop()