Skip to content

Commit

Permalink
feat: add some typing (#284)
Browse files Browse the repository at this point in the history
  • Loading branch information
kalombos authored Aug 19, 2024
1 parent 6ef315c commit cb7cc1e
Show file tree
Hide file tree
Showing 14 changed files with 113 additions and 93 deletions.
49 changes: 28 additions & 21 deletions peewee_async/aio_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import peewee
from peewee import PREFETCH_TYPE

from .databases import AioDatabase
from .result_wrappers import fetch_models
from .utils import CursorProtocol
from typing_extensions import Self
from typing import Tuple, List, Any, cast


async def aio_prefetch(sq, *subqueries, prefetch_type):
async def aio_prefetch(sq, *subqueries, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE) -> List[Any]:
"""Asynchronous version of `prefetch()`.
See also:
Expand Down Expand Up @@ -42,10 +46,10 @@ async def aio_prefetch(sq, *subqueries, prefetch_type):

class AioQueryMixin:
@peewee.database_required
async def aio_execute(self, database):
async def aio_execute(self, database: AioDatabase) -> Any:
return await database.aio_execute(self)

async def fetch_results(self, cursor: CursorProtocol):
async def fetch_results(self, cursor: CursorProtocol) -> List[Any]:
return await fetch_models(cursor, self)


Expand Down Expand Up @@ -116,7 +120,7 @@ async def aio_get(self, database=None):
(clone.model, sql, params))

@peewee.database_required
async def aio_count(self, database, clear_limit=False):
async def aio_count(self, database, clear_limit=False) -> int:
"""
Async version of **peewee.SelectBase.count**
Expand All @@ -133,7 +137,10 @@ async def aio_count(self, database, clear_limit=False):
clone = clone.select(peewee.SQL('1'))
except AttributeError:
pass
return await AioSelect([clone], [peewee.fn.COUNT(peewee.SQL('1'))]).aio_scalar(database)
return cast(
int,
await AioSelect([clone], [peewee.fn.COUNT(peewee.SQL('1'))]).aio_scalar(database)
)

@peewee.database_required
async def aio_exists(self, database):
Expand Down Expand Up @@ -164,14 +171,14 @@ def except_(self, rhs):
return AioModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs)
__sub__ = except_

def aio_prefetch(self, *subqueries, **kwargs):
def aio_prefetch(self, *subqueries, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE):
"""
Async version of **peewee.ModelSelect.prefetch**
See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#ModelSelect.prefetch
"""
return aio_prefetch(self, *subqueries, **kwargs)
return aio_prefetch(self, *subqueries, prefetch_type=prefetch_type)


class AioSelect(AioSelectMixin, peewee.Select):
Expand Down Expand Up @@ -207,39 +214,39 @@ class User(peewee_async.AioModel):
"""

@classmethod
def select(cls, *fields):
def select(cls, *fields) -> AioModelSelect:
is_default = not fields
if not fields:
fields = cls._meta.sorted_fields
return AioModelSelect(cls, fields, is_default=is_default)

@classmethod
def update(cls, __data=None, **update):
def update(cls, __data=None, **update) -> AioModelUpdate:
return AioModelUpdate(cls, cls._normalize_data(__data, update))

@classmethod
def insert(cls, __data=None, **insert):
def insert(cls, __data=None, **insert) -> AioModelInsert:
return AioModelInsert(cls, cls._normalize_data(__data, insert))

@classmethod
def insert_many(cls, rows, fields=None):
def insert_many(cls, rows, fields=None) -> AioModelInsert:
return AioModelInsert(cls, insert=rows, columns=fields)

@classmethod
def insert_from(cls, query, fields):
def insert_from(cls, query, fields) -> AioModelInsert:
columns = [getattr(cls, field) if isinstance(field, str)
else field for field in fields]
return AioModelInsert(cls, insert=query, columns=columns)

@classmethod
def raw(cls, sql, *params):
def raw(cls, sql, *params) -> AioModelRaw:
return AioModelRaw(cls, sql, params)

@classmethod
def delete(cls):
def delete(cls) -> AioModelDelete:
return AioModelDelete(cls)

async def aio_delete_instance(self, recursive=False, delete_nullable=False):
async def aio_delete_instance(self, recursive: bool = False, delete_nullable: bool = False) -> int:
"""
Async version of **peewee.Model.delete_instance**
Expand All @@ -254,9 +261,9 @@ async def aio_delete_instance(self, recursive=False, delete_nullable=False):
await model.update(**{fk.name: None}).where(query).aio_execute()
else:
await model.delete().where(query).aio_execute()
return await type(self).delete().where(self._pk_expr()).aio_execute()
return cast(int, await type(self).delete().where(self._pk_expr()).aio_execute())

async def aio_save(self, force_insert=False, only=None):
async def aio_save(self, force_insert: bool = False, only=None) -> int:
"""
Async version of **peewee.Model.save**
Expand Down Expand Up @@ -306,7 +313,7 @@ async def aio_save(self, force_insert=False, only=None):
return rows

@classmethod
async def aio_get(cls, *query, **filters):
async def aio_get(cls, *query, **filters) -> Self:
"""Async version of **peewee.Model.get**
See also:
Expand All @@ -323,7 +330,7 @@ async def aio_get(cls, *query, **filters):
return await sq.aio_get()

@classmethod
async def aio_get_or_none(cls, *query, **filters):
async def aio_get_or_none(cls, *query, **filters) -> Self | None:
"""
Async version of **peewee.Model.get_or_none**
Expand All @@ -336,7 +343,7 @@ async def aio_get_or_none(cls, *query, **filters):
return None

@classmethod
async def aio_create(cls, **query):
async def aio_create(cls, **query) -> "Self":
"""
Async version of **peewee.Model.create**
Expand All @@ -348,7 +355,7 @@ async def aio_create(cls, **query):
return inst

@classmethod
async def aio_get_or_create(cls, **kwargs):
async def aio_get_or_create(cls, **kwargs) -> Tuple[Self, bool]:
"""
Async version of **peewee.Model.get_or_create**
Expand Down
2 changes: 1 addition & 1 deletion peewee_async/result_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def close(self) -> None:
pass


async def fetch_models(cursor: CursorProtocol, query: BaseQuery):
async def fetch_models(cursor: CursorProtocol, query: BaseQuery) -> List[Any]:
rows = await cursor.fetchall()
sync_cursor = SyncCursorAdapter(rows, cursor.description)
_result_wrapper = query._get_cursor_wrapper(sync_cursor)
Expand Down
5 changes: 4 additions & 1 deletion peewee_async/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Protocol, Optional, Sequence, Set, AsyncContextManager
from typing import Any, Protocol, Optional, Sequence, Set, AsyncContextManager, List

try:
import aiopg
Expand All @@ -23,6 +23,9 @@ class CursorProtocol(Protocol):
async def fetchone(self) -> Any:
...

async def fetchall(self) -> List[Any]:
...

@property
def lastrowid(self) -> int:
...
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ disallow_any_generics = True
disallow_untyped_calls = True
disallow_untyped_defs = True
disallow_incomplete_defs = True
exclude = (venv|load-testing|examples)
exclude = (venv|load-testing|examples|docs)
7 changes: 4 additions & 3 deletions tests/aio_model/test_deleting.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import uuid

from peewee_async.databases import AioDatabase
from tests.conftest import dbs_all, dbs_postgres
from tests.models import TestModel
from tests.utils import model_has_fields


@dbs_all
async def test_delete__count(db):
async def test_delete__count(db: AioDatabase) -> None:
query = TestModel.insert_many([
{'text': "Test %s" % uuid.uuid4()},
{'text': "Test %s" % uuid.uuid4()},
Expand All @@ -19,7 +20,7 @@ async def test_delete__count(db):


@dbs_all
async def test_delete__by_condition(db):
async def test_delete__by_condition(db: AioDatabase) -> None:
expected_text = "text1"
deleted_text = "text2"
query = TestModel.insert_many([
Expand All @@ -36,7 +37,7 @@ async def test_delete__by_condition(db):


@dbs_postgres
async def test_delete__return_model(db):
async def test_delete__return_model(db: AioDatabase) -> None:
m = await TestModel.aio_create(text="text", data="data")

res = await TestModel.delete().returning(TestModel).aio_execute()
Expand Down
17 changes: 9 additions & 8 deletions tests/aio_model/test_inserting.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import uuid

from peewee_async.databases import AioDatabase
from tests.conftest import dbs_all, dbs_postgres
from tests.models import TestModel, UUIDTestModel
from tests.utils import model_has_fields


@dbs_all
async def test_insert_many(db):
async def test_insert_many(db: AioDatabase) -> None:
last_id = await TestModel.insert_many([
{'text': "Test %s" % uuid.uuid4()},
{'text': "Test %s" % uuid.uuid4()},
Expand All @@ -19,7 +20,7 @@ async def test_insert_many(db):


@dbs_all
async def test_insert__return_id(db):
async def test_insert__return_id(db: AioDatabase) -> None:
last_id = await TestModel.insert(text="Test %s" % uuid.uuid4()).aio_execute()

res = await TestModel.select().aio_execute()
Expand All @@ -28,7 +29,7 @@ async def test_insert__return_id(db):


@dbs_postgres
async def test_insert_on_conflict_ignore__last_id_is_none(db):
async def test_insert_on_conflict_ignore__last_id_is_none(db: AioDatabase) -> None:
query = TestModel.insert(text="text").on_conflict_ignore()
await query.aio_execute()

Expand All @@ -38,7 +39,7 @@ async def test_insert_on_conflict_ignore__last_id_is_none(db):


@dbs_postgres
async def test_insert_on_conflict_ignore__return_model(db):
async def test_insert_on_conflict_ignore__return_model(db: AioDatabase) -> None:
query = TestModel.insert(text="text", data="data").on_conflict_ignore().returning(TestModel)

res = await query.aio_execute()
Expand All @@ -55,7 +56,7 @@ async def test_insert_on_conflict_ignore__return_model(db):


@dbs_postgres
async def test_insert_on_conflict_ignore__inserted_once(db):
async def test_insert_on_conflict_ignore__inserted_once(db: AioDatabase) -> None:
query = TestModel.insert(text="text").on_conflict_ignore()
last_id = await query.aio_execute()

Expand All @@ -67,14 +68,14 @@ async def test_insert_on_conflict_ignore__inserted_once(db):


@dbs_postgres
async def test_insert__uuid_pk(db):
async def test_insert__uuid_pk(db: AioDatabase) -> None:
query = UUIDTestModel.insert(text="Test %s" % uuid.uuid4())
last_id = await query.aio_execute()
assert len(str(last_id)) == 36


@dbs_postgres
async def test_insert__return_model(db):
async def test_insert__return_model(db: AioDatabase) -> None:
text = "Test %s" % uuid.uuid4()
data = "data"
query = TestModel.insert(text=text, data=data).returning(TestModel)
Expand All @@ -88,7 +89,7 @@ async def test_insert__return_model(db):


@dbs_postgres
async def test_insert_many__return_model(db):
async def test_insert_many__return_model(db: AioDatabase) -> None:
texts = [f"text{n}" for n in range(2)]
query = TestModel.insert_many([
{"text": text} for text in texts
Expand Down
17 changes: 9 additions & 8 deletions tests/aio_model/test_selecting.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from peewee_async.aio_model import AioModelCompoundSelectQuery, AioModelRaw
from peewee_async.databases import AioDatabase
from tests.conftest import dbs_all
from tests.models import TestModel, TestModelAlpha, TestModelBeta


@dbs_all
async def test_select_w_join(db):
async def test_select_w_join(db: AioDatabase) -> None:
alpha = await TestModelAlpha.aio_create(text="Test 1")
beta = await TestModelBeta.aio_create(alpha_id=alpha.id, text="text")

Expand All @@ -18,7 +19,7 @@ async def test_select_w_join(db):


@dbs_all
async def test_raw_select(db):
async def test_raw_select(db: AioDatabase) -> None:
obj1 = await TestModel.aio_create(text="Test 1")
obj2 = await TestModel.aio_create(text="Test 2")
query = TestModel.raw(
Expand All @@ -30,23 +31,23 @@ async def test_raw_select(db):


@dbs_all
async def test_tuples(db):
async def test_tuples(db: AioDatabase) -> None:
obj = await TestModel.aio_create(text="Test 1")

result = await TestModel.select(TestModel.id, TestModel.text).tuples().aio_execute()
assert result[0] == (obj.id, obj.text)


@dbs_all
async def test_dicts(db):
async def test_dicts(db: AioDatabase) -> None:
obj = await TestModel.aio_create(text="Test 1")

result = await TestModel.select(TestModel.id, TestModel.text).dicts().aio_execute()
assert result[0] == {"id": obj.id, "text": obj.text}


@dbs_all
async def test_union_all(db):
async def test_union_all(db: AioDatabase) -> None:
obj1 = await TestModel.aio_create(text="1")
obj2 = await TestModel.aio_create(text="2")
query = (
Expand All @@ -59,7 +60,7 @@ async def test_union_all(db):


@dbs_all
async def test_union(db):
async def test_union(db: AioDatabase) -> None:
obj1 = await TestModel.aio_create(text="1")
obj2 = await TestModel.aio_create(text="2")
query = (
Expand All @@ -73,7 +74,7 @@ async def test_union(db):


@dbs_all
async def test_intersect(db):
async def test_intersect(db: AioDatabase) -> None:
await TestModel.aio_create(text="1")
await TestModel.aio_create(text="2")
await TestModel.aio_create(text="3")
Expand All @@ -90,7 +91,7 @@ async def test_intersect(db):


@dbs_all
async def test_except(db):
async def test_except(db: AioDatabase) -> None:
await TestModel.aio_create(text="1")
await TestModel.aio_create(text="2")
await TestModel.aio_create(text="3")
Expand Down
Loading

0 comments on commit cb7cc1e

Please sign in to comment.