-
Notifications
You must be signed in to change notification settings - Fork 98
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: Refactor Async(Postgresql/MySQL)Connection (#213)
- Loading branch information
Showing
1 changed file
with
50 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
Copyright (c) 2014, Alexey Kinëv <[email protected]> | ||
""" | ||
import abc | ||
import asyncio | ||
import contextlib | ||
import functools | ||
|
@@ -712,7 +713,7 @@ async def connect_async(self, loop=None, timeout=None): | |
timeout=timeout, | ||
**self.connect_params_async | ||
) | ||
await conn.connect() | ||
await conn.create() | ||
self._async_conn = conn | ||
|
||
async def cursor_async(self): | ||
|
@@ -734,7 +735,7 @@ async def close_async(self): | |
if self._async_conn: | ||
conn = self._async_conn | ||
self._async_conn = None | ||
await conn.close() | ||
await conn.terminate() | ||
|
||
async def push_transaction_async(self): | ||
"""Increment async transaction depth. | ||
|
@@ -851,19 +852,14 @@ async def aio_execute(self, query): | |
return (await coroutine(query)) | ||
|
||
|
||
############## | ||
# PostgreSQL # | ||
############## | ||
|
||
|
||
class AsyncPostgresqlConnection: | ||
class AioPool(metaclass=abc.ABCMeta): | ||
"""Asynchronous database connection pool. | ||
""" | ||
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs): | ||
self.pool = None | ||
self.loop = loop | ||
self.database = database | ||
self.timeout = timeout or aiopg.DEFAULT_TIMEOUT | ||
self.timeout = timeout | ||
self.connect_params = kwargs | ||
|
||
async def acquire(self): | ||
|
@@ -876,24 +872,20 @@ def release(self, conn): | |
""" | ||
self.pool.release(conn) | ||
|
||
async def connect(self): | ||
@abc.abstractmethod | ||
async def create(self): | ||
"""Create connection pool asynchronously. | ||
""" | ||
self.pool = await aiopg.create_pool( | ||
loop=self.loop, | ||
timeout=self.timeout, | ||
database=self.database, | ||
**self.connect_params) | ||
raise NotImplementedError | ||
|
||
async def close(self): | ||
async def terminate(self): | ||
"""Terminate all pool connections. | ||
""" | ||
self.pool.terminate() | ||
await self.pool.wait_closed() | ||
|
||
async def cursor(self, conn=None, *args, **kwargs): | ||
"""Get a cursor for the specified transaction connection | ||
or acquire from the pool. | ||
"""Get cursor for connection from pool. | ||
""" | ||
in_transaction = conn is not None | ||
if not conn: | ||
|
@@ -914,10 +906,44 @@ async def release_cursor(self, cursor, in_transaction=False): | |
the connection is also released back to the pool. | ||
""" | ||
conn = cursor.connection | ||
cursor.close() | ||
await self.close_cursor(cursor) | ||
if not in_transaction: | ||
self.release(conn) | ||
|
||
@abc.abstractmethod | ||
async def close_cursor(self, cursor): | ||
raise NotImplementedError | ||
|
||
|
||
|
||
############## | ||
# PostgreSQL # | ||
############## | ||
|
||
|
||
class AioPostgresqlPool(AioPool): | ||
"""Asynchronous database connection pool. | ||
""" | ||
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs): | ||
super().__init__( | ||
database=database, | ||
loop=loop, | ||
timeout=timeout or aiopg.DEFAULT_TIMEOUT, | ||
**kwargs, | ||
) | ||
|
||
async def create(self): | ||
"""Create connection pool asynchronously. | ||
""" | ||
self.pool = await aiopg.create_pool( | ||
loop=self.loop, | ||
timeout=self.timeout, | ||
database=self.database, | ||
**self.connect_params) | ||
|
||
async def close_cursor(self, cursor): | ||
cursor.close() | ||
|
||
|
||
class AsyncPostgresqlMixin(AsyncDatabase): | ||
"""Mixin for `peewee.PostgresqlDatabase` providing extra methods | ||
|
@@ -926,7 +952,7 @@ class AsyncPostgresqlMixin(AsyncDatabase): | |
if psycopg2: | ||
Error = psycopg2.Error | ||
|
||
def init_async(self, conn_cls=AsyncPostgresqlConnection, | ||
def init_async(self, conn_cls=AioPostgresqlPool, | ||
enable_json=False, enable_hstore=False): | ||
if not aiopg: | ||
raise Exception("Error, aiopg is not installed!") | ||
|
@@ -1027,27 +1053,11 @@ def use_speedups(self, value): | |
######### | ||
|
||
|
||
class AsyncMySQLConnection: | ||
class AioMysqlPool(AioPool): | ||
"""Asynchronous database connection pool. | ||
""" | ||
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs): | ||
self.pool = None | ||
self.loop = loop | ||
self.database = database | ||
self.timeout = timeout | ||
self.connect_params = kwargs | ||
|
||
async def acquire(self): | ||
"""Acquire connection from pool. | ||
""" | ||
return (await self.pool.acquire()) | ||
|
||
def release(self, conn): | ||
"""Release connection to pool. | ||
""" | ||
self.pool.release(conn) | ||
|
||
async def connect(self): | ||
async def create(self): | ||
"""Create connection pool asynchronously. | ||
""" | ||
self.pool = await aiomysql.create_pool( | ||
|
@@ -1056,37 +1066,8 @@ async def connect(self): | |
connect_timeout=self.timeout, | ||
**self.connect_params) | ||
|
||
async def close(self): | ||
"""Terminate all pool connections. | ||
""" | ||
self.pool.terminate() | ||
await self.pool.wait_closed() | ||
|
||
async def cursor(self, conn=None, *args, **kwargs): | ||
"""Get cursor for connection from pool. | ||
""" | ||
in_transaction = conn is not None | ||
if not conn: | ||
conn = await self.acquire() | ||
try: | ||
cursor = await conn.cursor(*args, **kwargs) | ||
except: | ||
if not in_transaction: | ||
self.release(conn) | ||
raise | ||
cursor.release = functools.partial( | ||
self.release_cursor, cursor, | ||
in_transaction=in_transaction) | ||
return cursor | ||
|
||
async def release_cursor(self, cursor, in_transaction=False): | ||
"""Release cursor coroutine. Unless in transaction, | ||
the connection is also released back to the pool. | ||
""" | ||
conn = cursor.connection | ||
async def close_cursor(self, cursor): | ||
await cursor.close() | ||
if not in_transaction: | ||
self.release(conn) | ||
|
||
|
||
class MySQLDatabase(AsyncDatabase, peewee.MySQLDatabase): | ||
|
@@ -1108,7 +1089,7 @@ def init(self, database, **kwargs): | |
raise Exception("Error, aiomysql is not installed!") | ||
self.min_connections = 1 | ||
self.max_connections = 1 | ||
self._async_conn_cls = kwargs.pop('async_conn', AsyncMySQLConnection) | ||
self._async_conn_cls = kwargs.pop('async_conn', AioMysqlPool) | ||
super().init(database, **kwargs) | ||
|
||
@property | ||
|