Skip to content

Commit

Permalink
Fix typing issues
Browse files Browse the repository at this point in the history
Following updates to various libraries used by Quart and mypy.
  • Loading branch information
pgjones committed Nov 14, 2024
1 parent e2e5642 commit a078901
Show file tree
Hide file tree
Showing 14 changed files with 75 additions and 68 deletions.
9 changes: 4 additions & 5 deletions src/quart/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

from aiofiles import open as async_open
from aiofiles.base import AiofilesContextManager
from aiofiles.threadpool.binary import AsyncBufferedReader
from flask.sansio.app import App
from flask.sansio.scaffold import setupmethod
from hypercorn.asyncio import serve
Expand Down Expand Up @@ -125,7 +124,7 @@
try:
from typing import ParamSpec
except ImportError:
from typing_extensions import ParamSpec # type: ignore
from typing_extensions import ParamSpec

# Python 3.14 deprecated asyncio.iscoroutinefunction, but suggested
# inspect.iscoroutinefunction does not work correctly in some Python
Expand Down Expand Up @@ -384,7 +383,7 @@ async def open_resource(
self,
path: FilePath,
mode: str = "rb",
) -> AiofilesContextManager[None, None, AsyncBufferedReader]:
) -> AiofilesContextManager:
"""Open a file for reading.
Use as
Expand All @@ -401,7 +400,7 @@ async def open_resource(

async def open_instance_resource(
self, path: FilePath, mode: str = "rb"
) -> AiofilesContextManager[None, None, AsyncBufferedReader]:
) -> AiofilesContextManager:
"""Open a file for reading.
Use as
Expand Down Expand Up @@ -1402,7 +1401,7 @@ async def make_response(self, result: ResponseReturnValue | HTTPException) -> Re
response.status_code = int(status)

if headers is not None:
response.headers.update(headers) # type: ignore[arg-type]
response.headers.update(headers)

return response

Expand Down
3 changes: 1 addition & 2 deletions src/quart/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from aiofiles import open as async_open
from aiofiles.base import AiofilesContextManager
from aiofiles.threadpool.binary import AsyncBufferedReader
from flask.sansio.app import App
from flask.sansio.blueprints import ( # noqa
Blueprint as SansioBlueprint,
Expand Down Expand Up @@ -101,7 +100,7 @@ async def open_resource(
self,
path: FilePath,
mode: str = "rb",
) -> AiofilesContextManager[None, None, AsyncBufferedReader]:
) -> AiofilesContextManager:
"""Open a file for reading.
Use as
Expand Down
6 changes: 3 additions & 3 deletions src/quart/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ def get_flashed_messages(
all messages will be popped, but only those matching the filter
returned. See :func:`~quart.helpers.flash` for message creation.
"""
flashes = request_ctx.flashes
flashes: list[str] = request_ctx.flashes
if flashes is None:
flashes = session.pop("_flashes") if "_flashes" in session else []
request_ctx.flashes = flashes
flashes = session.pop("_flashes", [])
request_ctx.flashes = flashes # type: ignore[assignment]
if category_filter:
flashes = [flash for flash in flashes if flash[0] in category_filter]
if not with_categories:
Expand Down
2 changes: 1 addition & 1 deletion src/quart/testing/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_client(self) -> TestClientProtocol:
return self.app.test_client()

async def startup(self) -> None:
scope: LifespanScope = {"type": "lifespan", "asgi": {"spec_version": "2.0"}}
scope: LifespanScope = {"type": "lifespan", "asgi": {"spec_version": "2.0"}, "state": {}}
self._task = asyncio.ensure_future(self.app(scope, self._asgi_receive, self._asgi_send))
await self._app_queue.put({"type": "lifespan.startup"})
await asyncio.wait_for(self._startup.wait(), timeout=self.startup_timeout)
Expand Down
10 changes: 5 additions & 5 deletions src/quart/wrappers/request.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
from typing import Any, AnyStr, Awaitable, Callable, Generator, NoReturn, overload
from typing import Any, Awaitable, Callable, Generator, NoReturn, overload

from hypercorn.typing import HTTPScope
from werkzeug.datastructures import CombinedMultiDict, Headers, iter_multi_items, MultiDict
Expand Down Expand Up @@ -184,7 +184,7 @@ async def stream(self) -> NoReturn:

@property
async def data(self) -> bytes:
return await self.get_data(as_text=False, parse_form_data=True)
return await self.get_data(as_text=False, parse_form_data=True) # type: ignore

@overload
async def get_data(
Expand All @@ -197,16 +197,16 @@ async def get_data(self, cache: bool, as_text: Literal[True], parse_form_data: b
@overload
async def get_data(
self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False
) -> AnyStr: ...
) -> str | bytes: ...

async def get_data(
self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False
) -> AnyStr:
) -> str | bytes:
"""Get the request body data.
Arguments:
cache: If False the body data will be cleared, resulting in any
subsequent calls returning an empty AnyStr and reducing
subsequent calls returning an empty str | bytes and reducing
memory usage.
as_text: If True the data is returned as a decoded string,
otherwise raw bytes are returned.
Expand Down
19 changes: 9 additions & 10 deletions src/quart/wrappers/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from types import TracebackType
from typing import (
Any,
AnyStr,
AsyncGenerator,
AsyncIterable,
AsyncIterator,
Expand All @@ -19,7 +18,7 @@

from aiofiles import open as async_open
from aiofiles.base import AiofilesContextManager
from aiofiles.threadpool.binary import AsyncBufferedIOBase, AsyncBufferedReader
from aiofiles.threadpool.binary import AsyncBufferedIOBase
from werkzeug.datastructures import ContentRange, Headers
from werkzeug.exceptions import RequestedRangeNotSatisfiable
from werkzeug.http import parse_etags
Expand Down Expand Up @@ -148,7 +147,7 @@ def __init__(self, file_path: str | PathLike, *, buffer_size: int | None = None)
if buffer_size is not None:
self.buffer_size = buffer_size
self.file: AsyncBufferedIOBase | None = None
self.file_manager: AiofilesContextManager[None, None, AsyncBufferedReader] = None
self.file_manager: AiofilesContextManager = None

async def __aenter__(self) -> FileBody:
self.file_manager = async_open(self.file_path, mode="rb")
Expand Down Expand Up @@ -262,7 +261,7 @@ class Response(SansIOResponse):

def __init__(
self,
response: ResponseBody | AnyStr | Iterable | None = None,
response: ResponseBody | str | bytes | Iterable | None = None,
status: int | None = None,
headers: dict | Headers | None = None,
mimetype: str | None = None,
Expand Down Expand Up @@ -296,7 +295,7 @@ def __init__(
elif isinstance(response, ResponseBody):
self.response = response
elif isinstance(response, (str, bytes)):
self.set_data(response) # type: ignore
self.set_data(response)
else:
self.response = self.iterable_body_class(response)

Expand All @@ -314,9 +313,9 @@ async def get_data(self, as_text: Literal[True]) -> str: ...
async def get_data(self, as_text: Literal[False]) -> bytes: ...

@overload
async def get_data(self, as_text: bool = True) -> AnyStr: ...
async def get_data(self, as_text: bool = True) -> str | bytes: ...

async def get_data(self, as_text: bool = False) -> AnyStr:
async def get_data(self, as_text: bool = False) -> str | bytes:
"""Return the body data."""
if self.implicit_sequence_conversion:
await self.make_sequence()
Expand All @@ -327,9 +326,9 @@ async def get_data(self, as_text: bool = False) -> AnyStr:
result += data.decode()
else:
result += data
return result # type: ignore
return result

def set_data(self, data: AnyStr) -> None:
def set_data(self, data: str | bytes) -> None:
"""Set the response data.
This will encode using the :attr:`charset`.
Expand All @@ -344,7 +343,7 @@ def set_data(self, data: AnyStr) -> None:

@property
async def data(self) -> bytes:
return await self.get_data()
return await self.get_data(as_text=False)

@data.setter
def data(self, value: bytes) -> None:
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def _http_scope() -> HTTPScope:
],
"client": ("127.0.0.1", 80),
"server": None,
"state": {}, # type: ignore[typeddict-item]
"extensions": {},
}

Expand All @@ -46,5 +47,6 @@ def _websocket_scope() -> WebsocketScope:
"client": ("127.0.0.1", 80),
"server": None,
"subprotocols": [],
"state": {}, # type: ignore[typeddict-item]
"extensions": {},
}
8 changes: 8 additions & 0 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ async def test_http_1_0_host_header(headers: list, expected: str) -> None:
"client": ("127.0.0.1", 80),
"server": None,
"extensions": {},
"state": {}, # type: ignore[typeddict-item]
}
connection = ASGIHTTPConnection(app, scope)
request = connection._create_request_from_scope(lambda: None) # type: ignore
Expand All @@ -57,6 +58,7 @@ async def test_http_completion() -> None:
"client": ("127.0.0.1", 80),
"server": None,
"extensions": {},
"state": {}, # type: ignore[typeddict-item]
}
connection = ASGIHTTPConnection(app, scope)

Expand Down Expand Up @@ -98,6 +100,7 @@ async def test_http_request_without_body(request_message: dict) -> None:
"client": ("127.0.0.1", 80),
"server": None,
"extensions": {},
"state": {}, # type: ignore[typeddict-item]
}
connection = ASGIHTTPConnection(app, scope)
request = connection._create_request_from_scope(lambda: None) # type: ignore
Expand Down Expand Up @@ -135,6 +138,7 @@ async def test_websocket_completion() -> None:
"server": None,
"subprotocols": [],
"extensions": {"websocket.http.response": {}},
"state": {}, # type: ignore[typeddict-item]
}
connection = ASGIWebsocketConnection(app, scope)

Expand Down Expand Up @@ -168,6 +172,7 @@ def test_http_path_from_absolute_target() -> None:
"client": ("127.0.0.1", 80),
"server": None,
"extensions": {},
"state": {}, # type: ignore[typeddict-item]
}
connection = ASGIHTTPConnection(app, scope)
request = connection._create_request_from_scope(lambda: None) # type: ignore
Expand All @@ -194,6 +199,7 @@ def test_http_path_with_root_path(path: str, expected: str) -> None:
"client": ("127.0.0.1", 80),
"server": None,
"extensions": {},
"state": {}, # type: ignore[typeddict-item]
}
connection = ASGIHTTPConnection(app, scope)
request = connection._create_request_from_scope(lambda: None) # type: ignore
Expand All @@ -216,6 +222,7 @@ def test_websocket_path_from_absolute_target() -> None:
"server": None,
"subprotocols": [],
"extensions": {"websocket.http.response": {}},
"state": {}, # type: ignore[typeddict-item]
}
connection = ASGIWebsocketConnection(app, scope)
websocket = connection._create_websocket_from_scope(lambda: None) # type: ignore
Expand All @@ -242,6 +249,7 @@ def test_websocket_path_with_root_path(path: str, expected: str) -> None:
"server": None,
"subprotocols": [],
"extensions": {"websocket.http.response": {}},
"state": {}, # type: ignore[typeddict-item]
}
connection = ASGIWebsocketConnection(app, scope)
websocket = connection._create_websocket_from_scope(lambda: None) # type: ignore
Expand Down
18 changes: 9 additions & 9 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,13 @@ async def test_index(path: str, app: Quart) -> None:
test_client = app.test_client()
response = await test_client.get(path)
assert response.status_code == 200
assert b"index" in (await response.get_data()) # type: ignore
assert b"index" in (await response.get_data())


async def test_iri(app: Quart) -> None:
test_client = app.test_client()
response = await test_client.get("/❤️")
assert "💔".encode() in (await response.get_data()) # type: ignore
assert "💔".encode() in (await response.get_data())


async def test_options(app: Quart) -> None:
Expand All @@ -107,35 +107,35 @@ async def test_json(app: Quart) -> None:
test_client = app.test_client()
response = await test_client.post("/json/", json={"value": "json"})
assert response.status_code == 200
assert b'{"value":"json"}\n' == (await response.get_data()) # type: ignore
assert b'{"value":"json"}\n' == (await response.get_data())


async def test_implicit_json(app: Quart) -> None:
test_client = app.test_client()
response = await test_client.post("/implicit_json/", json={"value": "json"})
assert response.status_code == 200
assert b'{"value":"json"}\n' == (await response.get_data()) # type: ignore
assert b'{"value":"json"}\n' == (await response.get_data())


async def test_implicit_json_list(app: Quart) -> None:
test_client = app.test_client()
response = await test_client.post("/implicit_json/", json=["a", 2])
assert response.status_code == 200
assert b'["a",2]\n' == (await response.get_data()) # type: ignore
assert b'["a",2]\n' == (await response.get_data())


async def test_werkzeug(app: Quart) -> None:
test_client = app.test_client()
response = await test_client.get("/werkzeug/")
assert response.status_code == 200
assert b"Hello" == (await response.get_data()) # type: ignore
assert b"Hello" == (await response.get_data())


async def test_generic_error(app: Quart) -> None:
test_client = app.test_client()
response = await test_client.get("/error/")
assert response.status_code == 409
assert b"Something Unique" in (await response.get_data()) # type: ignore
assert b"Something Unique" in (await response.get_data())


async def test_url_defaults(app: Quart) -> None:
Expand All @@ -151,7 +151,7 @@ async def test_not_found_error(app: Quart) -> None:
test_client = app.test_client()
response = await test_client.get("/not_found/")
assert response.status_code == 404
assert b"Not Found" in (await response.get_data()) # type: ignore
assert b"Not Found" in (await response.get_data())


async def test_make_response_str(app: Quart) -> None:
Expand Down Expand Up @@ -225,4 +225,4 @@ async def test_root_path(app: Quart) -> None:
async def test_stream(app: Quart) -> None:
test_client = app.test_client()
response = await test_client.get("/stream")
assert (await response.get_data()) == b"Hello World" # type: ignore
assert (await response.get_data()) == b"Hello World"
Loading

0 comments on commit a078901

Please sign in to comment.