Skip to content

Commit

Permalink
chore: Refactor Async(Postgresql/MySQL)Connection (#213)
Browse files Browse the repository at this point in the history
  • Loading branch information
akerlay authored Apr 8, 2024
1 parent 9e3dda8 commit c93271e
Showing 1 changed file with 50 additions and 69 deletions.
119 changes: 50 additions & 69 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Copyright (c) 2014, Alexey Kinëv <[email protected]>
"""
import abc
import asyncio
import contextlib
import functools
Expand Down Expand Up @@ -712,7 +713,7 @@ async def connect_async(self, loop=None, timeout=None):
timeout=timeout,
**self.connect_params_async
)
await conn.connect()
await conn.create()
self._async_conn = conn

async def cursor_async(self):
Expand All @@ -734,7 +735,7 @@ async def close_async(self):
if self._async_conn:
conn = self._async_conn
self._async_conn = None
await conn.close()
await conn.terminate()

async def push_transaction_async(self):
"""Increment async transaction depth.
Expand Down Expand Up @@ -851,19 +852,14 @@ async def aio_execute(self, query):
return (await coroutine(query))


##############
# PostgreSQL #
##############


class AsyncPostgresqlConnection:
class AioPool(metaclass=abc.ABCMeta):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
self.pool = None
self.loop = loop
self.database = database
self.timeout = timeout or aiopg.DEFAULT_TIMEOUT
self.timeout = timeout
self.connect_params = kwargs

async def acquire(self):
Expand All @@ -876,24 +872,20 @@ def release(self, conn):
"""
self.pool.release(conn)

async def connect(self):
@abc.abstractmethod
async def create(self):
"""Create connection pool asynchronously.
"""
self.pool = await aiopg.create_pool(
loop=self.loop,
timeout=self.timeout,
database=self.database,
**self.connect_params)
raise NotImplementedError

async def close(self):
async def terminate(self):
"""Terminate all pool connections.
"""
self.pool.terminate()
await self.pool.wait_closed()

async def cursor(self, conn=None, *args, **kwargs):
"""Get a cursor for the specified transaction connection
or acquire from the pool.
"""Get cursor for connection from pool.
"""
in_transaction = conn is not None
if not conn:
Expand All @@ -914,10 +906,44 @@ async def release_cursor(self, cursor, in_transaction=False):
the connection is also released back to the pool.
"""
conn = cursor.connection
cursor.close()
await self.close_cursor(cursor)
if not in_transaction:
self.release(conn)

@abc.abstractmethod
async def close_cursor(self, cursor):
raise NotImplementedError



##############
# PostgreSQL #
##############


class AioPostgresqlPool(AioPool):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
super().__init__(
database=database,
loop=loop,
timeout=timeout or aiopg.DEFAULT_TIMEOUT,
**kwargs,
)

async def create(self):
"""Create connection pool asynchronously.
"""
self.pool = await aiopg.create_pool(
loop=self.loop,
timeout=self.timeout,
database=self.database,
**self.connect_params)

async def close_cursor(self, cursor):
cursor.close()


class AsyncPostgresqlMixin(AsyncDatabase):
"""Mixin for `peewee.PostgresqlDatabase` providing extra methods
Expand All @@ -926,7 +952,7 @@ class AsyncPostgresqlMixin(AsyncDatabase):
if psycopg2:
Error = psycopg2.Error

def init_async(self, conn_cls=AsyncPostgresqlConnection,
def init_async(self, conn_cls=AioPostgresqlPool,
enable_json=False, enable_hstore=False):
if not aiopg:
raise Exception("Error, aiopg is not installed!")
Expand Down Expand Up @@ -1027,27 +1053,11 @@ def use_speedups(self, value):
#########


class AsyncMySQLConnection:
class AioMysqlPool(AioPool):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
self.pool = None
self.loop = loop
self.database = database
self.timeout = timeout
self.connect_params = kwargs

async def acquire(self):
"""Acquire connection from pool.
"""
return (await self.pool.acquire())

def release(self, conn):
"""Release connection to pool.
"""
self.pool.release(conn)

async def connect(self):
async def create(self):
"""Create connection pool asynchronously.
"""
self.pool = await aiomysql.create_pool(
Expand All @@ -1056,37 +1066,8 @@ async def connect(self):
connect_timeout=self.timeout,
**self.connect_params)

async def close(self):
"""Terminate all pool connections.
"""
self.pool.terminate()
await self.pool.wait_closed()

async def cursor(self, conn=None, *args, **kwargs):
"""Get cursor for connection from pool.
"""
in_transaction = conn is not None
if not conn:
conn = await self.acquire()
try:
cursor = await conn.cursor(*args, **kwargs)
except:
if not in_transaction:
self.release(conn)
raise
cursor.release = functools.partial(
self.release_cursor, cursor,
in_transaction=in_transaction)
return cursor

async def release_cursor(self, cursor, in_transaction=False):
"""Release cursor coroutine. Unless in transaction,
the connection is also released back to the pool.
"""
conn = cursor.connection
async def close_cursor(self, cursor):
await cursor.close()
if not in_transaction:
self.release(conn)


class MySQLDatabase(AsyncDatabase, peewee.MySQLDatabase):
Expand All @@ -1108,7 +1089,7 @@ def init(self, database, **kwargs):
raise Exception("Error, aiomysql is not installed!")
self.min_connections = 1
self.max_connections = 1
self._async_conn_cls = kwargs.pop('async_conn', AsyncMySQLConnection)
self._async_conn_cls = kwargs.pop('async_conn', AioMysqlPool)
super().init(database, **kwargs)

@property
Expand Down

0 comments on commit c93271e

Please sign in to comment.