diff --git a/docs/deployment.md b/docs/deployment.md index d69fcf88e..4d5819011 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -57,7 +57,11 @@ Options: --workers INTEGER Number of worker processes. Defaults to the $WEB_CONCURRENCY environment variable if available, or 1. Not valid with --reload. - --loop [auto|asyncio|uvloop] Event loop implementation. [default: auto] + --loop TEXT Event loop implementation. Can be one of + [auto|asyncio|uvloop] or an import string to + a function of type: (use_subprocess: bool) + -> Callable[[], asyncio.AbstractEventLoop]. + [default: auto] --http [auto|h11|httptools] HTTP protocol implementation. [default: auto] --ws [auto|none|websockets|wsproto] diff --git a/docs/index.md b/docs/index.md index bb6fc321a..20da6442b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -127,7 +127,11 @@ Options: --workers INTEGER Number of worker processes. Defaults to the $WEB_CONCURRENCY environment variable if available, or 1. Not valid with --reload. - --loop [auto|asyncio|uvloop] Event loop implementation. [default: auto] + --loop TEXT Event loop implementation. Can be one of + [auto|asyncio|uvloop] or an import string to + a function of type: (use_subprocess: bool) + -> Callable[[], asyncio.AbstractEventLoop]. + [default: auto] --http [auto|h11|httptools] HTTP protocol implementation. [default: auto] --ws [auto|none|websockets|wsproto] diff --git a/pyproject.toml b/pyproject.toml index 6dd4916db..7395272aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,7 +103,7 @@ filterwarnings = [ [tool.coverage.run] source_pkgs = ["uvicorn", "tests"] plugins = ["coverage_conditional_plugin"] -omit = ["uvicorn/workers.py", "uvicorn/__main__.py"] +omit = ["uvicorn/workers.py", "uvicorn/__main__.py", "uvicorn/_compat.py"] [tool.coverage.report] precision = 2 diff --git a/tests/custom_loop_utils.py b/tests/custom_loop_utils.py new file mode 100644 index 000000000..3a2db4a78 --- /dev/null +++ b/tests/custom_loop_utils.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +import asyncio +from asyncio import AbstractEventLoop + + +class CustomLoop(asyncio.SelectorEventLoop): + pass + + +def custom_loop_factory(use_subprocess: bool) -> type[AbstractEventLoop]: + return CustomLoop diff --git a/tests/test_auto_detection.py b/tests/test_auto_detection.py index 1f79b3786..ef86bf265 100644 --- a/tests/test_auto_detection.py +++ b/tests/test_auto_detection.py @@ -1,10 +1,11 @@ import asyncio +import contextlib import importlib import pytest from uvicorn.config import Config -from uvicorn.loops.auto import auto_loop_setup +from uvicorn.loops.auto import auto_loop_factory from uvicorn.main import ServerState from uvicorn.protocols.http.auto import AutoHTTPProtocol from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol @@ -33,10 +34,10 @@ async def app(scope, receive, send): def test_loop_auto(): - auto_loop_setup() - policy = asyncio.get_event_loop_policy() - assert isinstance(policy, asyncio.events.BaseDefaultEventLoopPolicy) - assert type(policy).__module__.startswith(expected_loop) + loop_factory = auto_loop_factory(use_subprocess=True) + with contextlib.closing(loop_factory()) as loop: + assert isinstance(loop, asyncio.AbstractEventLoop) + assert type(loop).__module__.startswith(expected_loop) @pytest.mark.anyio diff --git a/tests/test_compat.py b/tests/test_compat.py new file mode 100644 index 000000000..15af6a4eb --- /dev/null +++ b/tests/test_compat.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import asyncio +from asyncio import AbstractEventLoop + +import pytest + +from tests.custom_loop_utils import CustomLoop, custom_loop_factory +from tests.utils import get_asyncio_default_loop_per_os +from uvicorn._compat import asyncio_run + + +async def assert_event_loop(expected_loop_class: type[AbstractEventLoop]): + assert isinstance(asyncio.get_event_loop(), expected_loop_class) + + +def test_asyncio_run__default_loop_factory() -> None: + asyncio_run(assert_event_loop(get_asyncio_default_loop_per_os()), loop_factory=None) + + +def test_asyncio_run__custom_loop_factory() -> None: + asyncio_run(assert_event_loop(CustomLoop), loop_factory=custom_loop_factory(use_subprocess=False)) + + +def test_asyncio_run__passing_a_non_awaitable_callback_should_throw_error() -> None: + with pytest.raises(ValueError): + asyncio_run( + lambda: None, # type: ignore + loop_factory=custom_loop_factory(use_subprocess=False), + ) diff --git a/tests/test_config.py b/tests/test_config.py index e16cc5d56..fcbce4880 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -8,6 +8,7 @@ import socket import sys import typing +from contextlib import closing from pathlib import Path from typing import Any, Literal from unittest.mock import MagicMock @@ -16,7 +17,8 @@ import yaml from pytest_mock import MockerFixture -from tests.utils import as_cwd +from tests.custom_loop_utils import CustomLoop +from tests.utils import as_cwd, get_asyncio_default_loop_per_os from uvicorn._types import ( ASGIApplication, ASGIReceiveCallable, @@ -25,7 +27,7 @@ Scope, StartResponse, ) -from uvicorn.config import Config +from uvicorn.config import Config, LoopFactoryType from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware from uvicorn.middleware.wsgi import WSGIMiddleware from uvicorn.protocols.http.h11_impl import H11Protocol @@ -545,3 +547,48 @@ def test_warn_when_using_reload_and_workers(caplog: pytest.LogCaptureFixture) -> Config(app=asgi_app, reload=True, workers=2) assert len(caplog.records) == 1 assert '"workers" flag is ignored when reloading is enabled.' in caplog.records[0].message + + +@pytest.mark.parametrize( + ("loop_type", "expected_loop_factory"), + [ + ("none", None), + ("asyncio", get_asyncio_default_loop_per_os()), + ], +) +def test_get_loop_factory(loop_type: LoopFactoryType, expected_loop_factory: Any): + config = Config(app=asgi_app, loop=loop_type) + loop_factory = config.get_loop_factory() + if loop_factory is None: + assert expected_loop_factory is loop_factory + else: + loop = loop_factory() + with closing(loop): + assert loop is not None + assert isinstance(loop, expected_loop_factory) + + +def test_custom_loop__importable_custom_loop_setup_function() -> None: + config = Config(app=asgi_app, loop="tests.custom_loop_utils:custom_loop_factory") + config.load() + loop_factory = config.get_loop_factory() + assert loop_factory, "Loop factory should be set" + event_loop = loop_factory() + with closing(event_loop): + assert event_loop is not None + assert isinstance(event_loop, CustomLoop) + + +@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") +def test_custom_loop__not_importable_custom_loop_setup_function(caplog: pytest.LogCaptureFixture) -> None: + config = Config(app=asgi_app, loop="tests.test_config:non_existing_setup_function") + config.load() + with pytest.raises(SystemExit): + config.get_loop_factory() + error_messages = [ + record.message for record in caplog.records if record.name == "uvicorn.error" and record.levelname == "ERROR" + ] + assert ( + 'Error loading custom loop setup function. Attribute "non_existing_setup_function" not found in module "tests.test_config".' # noqa: E501 + == error_messages.pop(0) + ) diff --git a/tests/utils.py b/tests/utils.py index 56362f20f..8145a2bd2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ import asyncio import os import signal +import sys from collections.abc import AsyncIterator from contextlib import asynccontextmanager, contextmanager from pathlib import Path @@ -44,3 +45,11 @@ def as_cwd(path: Path): yield finally: os.chdir(prev_cwd) + + +def get_asyncio_default_loop_per_os() -> type[asyncio.AbstractEventLoop]: + """Get the default asyncio loop per OS.""" + if sys.platform == "win32": + return asyncio.ProactorEventLoop # type: ignore # pragma: nocover + else: + return asyncio.SelectorEventLoop # pragma: nocover diff --git a/uvicorn/_compat.py b/uvicorn/_compat.py new file mode 100644 index 000000000..e2650507a --- /dev/null +++ b/uvicorn/_compat.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import asyncio +import sys +from collections.abc import Callable, Coroutine +from typing import Any, TypeVar + +_T = TypeVar("_T") + +if sys.version_info >= (3, 12): + asyncio_run = asyncio.run +elif sys.version_info >= (3, 11): + + def asyncio_run( + main: Coroutine[Any, Any, _T], + *, + debug: bool = False, + loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None, + ) -> _T: + # asyncio.run from Python 3.12 + # https://docs.python.org/3/license.html#psf-license + with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner: + return runner.run(main) + +else: + # modified version of asyncio.run from Python 3.10 to add loop_factory kwarg + # https://docs.python.org/3/license.html#psf-license + def asyncio_run( + main: Coroutine[Any, Any, _T], + *, + debug: bool = False, + loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None, + ) -> _T: + try: + asyncio.get_running_loop() + except RuntimeError: + pass + else: + raise RuntimeError("asyncio.run() cannot be called from a running event loop") + + if not asyncio.iscoroutine(main): + raise ValueError(f"a coroutine was expected, got {main!r}") + + if loop_factory is None: + loop = asyncio.new_event_loop() + else: + loop = loop_factory() + try: + if loop_factory is None: + asyncio.set_event_loop(loop) + if debug is not None: + loop.set_debug(debug) + return loop.run_until_complete(main) + finally: + try: + _cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + if sys.version_info >= (3, 9): + loop.run_until_complete(loop.shutdown_default_executor()) + finally: + if loop_factory is None: + asyncio.set_event_loop(None) + loop.close() + + def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None: + to_cancel = asyncio.all_tasks(loop) + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) diff --git a/uvicorn/config.py b/uvicorn/config.py index 65dfe651e..2f3a58ef2 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -26,7 +26,7 @@ HTTPProtocolType = Literal["auto", "h11", "httptools"] WSProtocolType = Literal["auto", "none", "websockets", "wsproto"] LifespanType = Literal["auto", "on", "off"] -LoopSetupType = Literal["none", "auto", "asyncio", "uvloop"] +LoopFactoryType = Literal["none", "auto", "asyncio", "uvloop"] InterfaceType = Literal["auto", "asgi3", "asgi2", "wsgi"] LOG_LEVELS: dict[str, int] = { @@ -53,11 +53,11 @@ "on": "uvicorn.lifespan.on:LifespanOn", "off": "uvicorn.lifespan.off:LifespanOff", } -LOOP_SETUPS: dict[LoopSetupType, str | None] = { +LOOP_FACTORIES: dict[str, str | None] = { "none": None, - "auto": "uvicorn.loops.auto:auto_loop_setup", - "asyncio": "uvicorn.loops.asyncio:asyncio_setup", - "uvloop": "uvicorn.loops.uvloop:uvloop_setup", + "auto": "uvicorn.loops.auto:auto_loop_factory", + "asyncio": "uvicorn.loops.asyncio:asyncio_loop_factory", + "uvloop": "uvicorn.loops.uvloop:uvloop_loop_factory", } INTERFACES: list[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"] @@ -180,7 +180,7 @@ def __init__( port: int = 8000, uds: str | None = None, fd: int | None = None, - loop: LoopSetupType = "auto", + loop: str = "auto", http: type[asyncio.Protocol] | HTTPProtocolType = "auto", ws: type[asyncio.Protocol] | WSProtocolType = "auto", ws_max_size: int = 16 * 1024 * 1024, @@ -471,10 +471,18 @@ def load(self) -> None: self.loaded = True - def setup_event_loop(self) -> None: - loop_setup: Callable | None = import_from_string(LOOP_SETUPS[self.loop]) - if loop_setup is not None: - loop_setup(use_subprocess=self.use_subprocess) + def get_loop_factory(self) -> Callable[[], asyncio.AbstractEventLoop] | None: + if self.loop in LOOP_FACTORIES: + loop_factory: Callable | None = import_from_string(LOOP_FACTORIES[self.loop]) + else: + try: + loop_factory = import_from_string(self.loop) + except ImportFromStringError as exc: + logger.error("Error loading custom loop setup function. %s" % exc) + sys.exit(1) + if loop_factory is None: + return None + return loop_factory(use_subprocess=self.use_subprocess) def bind_socket(self) -> socket.socket: logger_args: list[str | int] diff --git a/uvicorn/loops/asyncio.py b/uvicorn/loops/asyncio.py index 1bead4a06..ad6121ee0 100644 --- a/uvicorn/loops/asyncio.py +++ b/uvicorn/loops/asyncio.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import asyncio -import logging import sys - -logger = logging.getLogger("uvicorn.error") +from collections.abc import Callable -def asyncio_setup(use_subprocess: bool = False) -> None: - if sys.platform == "win32" and use_subprocess: - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # pragma: full coverage +def asyncio_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]: + if sys.platform == "win32" and not use_subprocess: + return asyncio.ProactorEventLoop + return asyncio.SelectorEventLoop diff --git a/uvicorn/loops/auto.py b/uvicorn/loops/auto.py index 2285457bf..190839905 100644 --- a/uvicorn/loops/auto.py +++ b/uvicorn/loops/auto.py @@ -1,11 +1,17 @@ -def auto_loop_setup(use_subprocess: bool = False) -> None: +from __future__ import annotations + +import asyncio +from collections.abc import Callable + + +def auto_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]: try: import uvloop # noqa except ImportError: # pragma: no cover - from uvicorn.loops.asyncio import asyncio_setup as loop_setup + from uvicorn.loops.asyncio import asyncio_loop_factory as loop_factory - loop_setup(use_subprocess=use_subprocess) + return loop_factory(use_subprocess=use_subprocess) else: # pragma: no cover - from uvicorn.loops.uvloop import uvloop_setup + from uvicorn.loops.uvloop import uvloop_loop_factory - uvloop_setup(use_subprocess=use_subprocess) + return uvloop_loop_factory(use_subprocess=use_subprocess) diff --git a/uvicorn/loops/uvloop.py b/uvicorn/loops/uvloop.py index 0e2fd1eb0..c6692c58f 100644 --- a/uvicorn/loops/uvloop.py +++ b/uvicorn/loops/uvloop.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import asyncio +from collections.abc import Callable import uvloop -def uvloop_setup(use_subprocess: bool = False) -> None: - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +def uvloop_loop_factory(use_subprocess: bool = False) -> Callable[[], asyncio.AbstractEventLoop]: + return uvloop.new_event_loop diff --git a/uvicorn/main.py b/uvicorn/main.py index 96a10d538..75740cc89 100644 --- a/uvicorn/main.py +++ b/uvicorn/main.py @@ -19,14 +19,13 @@ LIFESPAN, LOG_LEVELS, LOGGING_CONFIG, - LOOP_SETUPS, + LOOP_FACTORIES, SSL_PROTOCOL_VERSION, WS_PROTOCOLS, Config, HTTPProtocolType, InterfaceType, LifespanType, - LoopSetupType, WSProtocolType, ) from uvicorn.server import Server, ServerState # noqa: F401 # Used to be defined here. @@ -36,7 +35,7 @@ HTTP_CHOICES = click.Choice(list(HTTP_PROTOCOLS.keys())) WS_CHOICES = click.Choice(list(WS_PROTOCOLS.keys())) LIFESPAN_CHOICES = click.Choice(list(LIFESPAN.keys())) -LOOP_CHOICES = click.Choice([key for key in LOOP_SETUPS.keys() if key != "none"]) +LOOP_CHOICES = [key for key in LOOP_FACTORIES.keys() if key != "none"] INTERFACE_CHOICES = click.Choice(INTERFACES) STARTUP_FAILURE = 3 @@ -117,9 +116,10 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No ) @click.option( "--loop", - type=LOOP_CHOICES, + type=str, default="auto", - help="Event loop implementation.", + help=f"Event loop implementation. Can be one of [{'|'.join(LOOP_CHOICES)}] " + f"or an import string to a function of type: (use_subprocess: bool) -> Callable[[], asyncio.AbstractEventLoop].", show_default=True, ) @click.option( @@ -366,7 +366,7 @@ def main( port: int, uds: str, fd: int, - loop: LoopSetupType, + loop: str, http: HTTPProtocolType, ws: WSProtocolType, ws_max_size: int, @@ -467,7 +467,7 @@ def run( port: int = 8000, uds: str | None = None, fd: int | None = None, - loop: LoopSetupType = "auto", + loop: str = "auto", http: type[asyncio.Protocol] | HTTPProtocolType = "auto", ws: type[asyncio.Protocol] | WSProtocolType = "auto", ws_max_size: int = 16777216, diff --git a/uvicorn/server.py b/uvicorn/server.py index f14026f16..10ea7b12b 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -16,6 +16,7 @@ import click +from uvicorn._compat import asyncio_run from uvicorn.config import Config if TYPE_CHECKING: @@ -61,8 +62,7 @@ def __init__(self, config: Config) -> None: self._captured_signals: list[int] = [] def run(self, sockets: list[socket.socket] | None = None) -> None: - self.config.setup_event_loop() - return asyncio.run(self.serve(sockets=sockets)) + return asyncio_run(self.serve(sockets=sockets), loop_factory=self.config.get_loop_factory()) async def serve(self, sockets: list[socket.socket] | None = None) -> None: with self.capture_signals(): diff --git a/uvicorn/workers.py b/uvicorn/workers.py index 061805b6c..25fa8533c 100644 --- a/uvicorn/workers.py +++ b/uvicorn/workers.py @@ -10,6 +10,7 @@ from gunicorn.arbiter import Arbiter from gunicorn.workers.base import Worker +from uvicorn._compat import asyncio_run from uvicorn.config import Config from uvicorn.main import Server @@ -70,10 +71,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.config = Config(**config_kwargs) - def init_process(self) -> None: - self.config.setup_event_loop() - super().init_process() - def init_signals(self) -> None: # Reset signals so Gunicorn doesn't swallow subprocess return codes # other signals are set up by Server.install_signal_handlers() @@ -104,7 +101,7 @@ async def _serve(self) -> None: sys.exit(Arbiter.WORKER_BOOT_ERROR) def run(self) -> None: - return asyncio.run(self._serve()) + return asyncio_run(self._serve(), loop_factory=self.config.get_loop_factory()) async def callback_notify(self) -> None: self.notify()