diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 3b57e6e9..2351b7d4 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -15,6 +15,21 @@ jobs: image: amazon/dynamodb-local ports: - 8000:8000 + + postgresql: + image: postgres:latest + ports: + - 5433:5432 + env: + POSTGRES_PASSWORD: pwd + POSTGRES_USER: root + POSTGRES_DB: dummy + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 steps: - uses: actions/checkout@v4 - uses: supercharge/redis-github-action@1.5.0 diff --git a/docker-compose.yml b/docker-compose.yml index 7864e9d7..3463412d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -30,6 +30,17 @@ services: ports: - "11211:11211" + postgres: + image: postgres:latest + environment: + - POSTGRES_USER=root + - POSTGRES_PASSWORD=pwd + - POSTGRES_DB=dummy + ports: + - "5433:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + volumes: postgres_data: mongo_data: diff --git a/docs/api.rst b/docs/api.rst index b31eceb5..faae610d 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -20,4 +20,5 @@ Anything documented here is part of the public API that Flask-Session provides, .. autoclass:: flask_session.cachelib.CacheLibSessionInterface .. autoclass:: flask_session.mongodb.MongoDBSessionInterface .. autoclass:: flask_session.sqlalchemy.SqlAlchemySessionInterface -.. autoclass:: flask_session.dynamodb.DynamoDBSessionInterface \ No newline at end of file +.. autoclass:: flask_session.dynamodb.DynamoDBSessionInterface +.. 
autoclass:: flask_session.postgresql.PostgreSqlSessionInterface \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 1d3e2326..da444405 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,4 +88,5 @@ dev-dependencies = [ "boto3>=1.34.68", "mypy_boto3_dynamodb>=1.34.67", "pymemcache>=4.0.0", + "psycopg2-binary>=2", ] diff --git a/requirements/dev.txt b/requirements/dev.txt index 14205475..868d1b06 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -17,4 +17,5 @@ Flask-SQLAlchemy pymongo boto3 mypy_boto3_dynamodb +psycopg2-binary diff --git a/requirements/docs.in b/requirements/docs.in index 211cd708..b8088dd3 100644 --- a/requirements/docs.in +++ b/requirements/docs.in @@ -9,4 +9,5 @@ pymongo flask_sqlalchemy pymemcache boto3 -mypy_boto3_dynamodb \ No newline at end of file +mypy_boto3_dynamodb +psycopg2-binary \ No newline at end of file diff --git a/requirements/docs.txt b/requirements/docs.txt index 81adad9c..0d61012b 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -6,6 +6,8 @@ # alabaster==0.7.13 # via sphinx +async-timeout==4.0.3 + # via redis babel==2.12.1 # via sphinx beautifulsoup4==4.12.3 @@ -36,6 +38,8 @@ flask-sqlalchemy==3.1.1 # via -r requirements/docs.in furo==2024.1.29 # via -r requirements/docs.in +greenlet==3.0.3 + # via sqlalchemy idna==3.4 # via requests imagesize==1.4.1 @@ -58,6 +62,8 @@ mypy-boto3-dynamodb==1.34.67 # via -r requirements/docs.in packaging==23.1 # via sphinx +psycopg2-binary==2.9.9 + # via -r requirements/docs.in pygments==2.15.1 # via # furo diff --git a/src/flask_session/__init__.py b/src/flask_session/__init__.py index b95a0b05..f527beb5 100644 --- a/src/flask_session/__init__.py +++ b/src/flask_session/__init__.py @@ -100,9 +100,6 @@ def _get_interface(self, app): SESSION_SQLALCHEMY_BIND_KEY = config.get( "SESSION_SQLALCHEMY_BIND_KEY", Defaults.SESSION_SQLALCHEMY_BIND_KEY ) - SESSION_CLEANUP_N_REQUESTS = config.get( - "SESSION_CLEANUP_N_REQUESTS", 
Defaults.SESSION_CLEANUP_N_REQUESTS - ) # DynamoDB settings SESSION_DYNAMODB = config.get("SESSION_DYNAMODB", Defaults.SESSION_DYNAMODB) @@ -110,6 +107,22 @@ def _get_interface(self, app): "SESSION_DYNAMODB_TABLE", Defaults.SESSION_DYNAMODB_TABLE ) + # PostgreSQL settings + SESSION_POSTGRESQL = config.get( + "SESSION_POSTGRESQL", Defaults.SESSION_POSTGRESQL + ) + SESSION_POSTGRESQL_TABLE = config.get( + "SESSION_POSTGRESQL_TABLE", Defaults.SESSION_POSTGRESQL_TABLE + ) + SESSION_POSTGRESQL_SCHEMA = config.get( + "SESSION_POSTGRESQL_SCHEMA", Defaults.SESSION_POSTGRESQL_SCHEMA + ) + + # Shared settings + SESSION_CLEANUP_N_REQUESTS = config.get( + "SESSION_CLEANUP_N_REQUESTS", Defaults.SESSION_CLEANUP_N_REQUESTS + ) + common_params = { "app": app, "key_prefix": SESSION_KEY_PREFIX, @@ -180,6 +193,17 @@ def _get_interface(self, app): table_name=SESSION_DYNAMODB_TABLE, ) + elif SESSION_TYPE == "postgresql": + from .postgresql import PostgreSqlSessionInterface + + session_interface = PostgreSqlSessionInterface( + **common_params, + pool=SESSION_POSTGRESQL, + table=SESSION_POSTGRESQL_TABLE, + schema=SESSION_POSTGRESQL_SCHEMA, + cleanup_n_requests=SESSION_CLEANUP_N_REQUESTS, + ) + else: raise ValueError(f"Unrecognized value for SESSION_TYPE: {SESSION_TYPE}") diff --git a/src/flask_session/defaults.py b/src/flask_session/defaults.py index 7f890d6e..0467db46 100644 --- a/src/flask_session/defaults.py +++ b/src/flask_session/defaults.py @@ -43,3 +43,8 @@ class Defaults: # DynamoDB settings SESSION_DYNAMODB = None SESSION_DYNAMODB_TABLE = "Sessions" + + # PostgreSQL settings + SESSION_POSTGRESQL = None + SESSION_POSTGRESQL_TABLE = "flask_sessions" + SESSION_POSTGRESQL_SCHEMA = "public" diff --git a/src/flask_session/postgresql/__init__.py b/src/flask_session/postgresql/__init__.py new file mode 100644 index 00000000..0d51b3a8 --- /dev/null +++ b/src/flask_session/postgresql/__init__.py @@ -0,0 +1 @@ +from .postgresql import PostgreSqlSession, PostgreSqlSessionInterface # noqa: 
F401 diff --git a/src/flask_session/postgresql/_queries.py b/src/flask_session/postgresql/_queries.py new file mode 100644 index 00000000..23fb0101 --- /dev/null +++ b/src/flask_session/postgresql/_queries.py @@ -0,0 +1,84 @@ +from psycopg2 import sql + + +class Queries: + def __init__(self, schema: str, table: str) -> None: + """Class to hold all the queries used by the session interface. + + Args: + schema (str): The name of the schema to use for the session data. + table (str): The name of the table to use for the session data. + """ + self.schema = schema + self.table = table + + @property + def create_schema(self) -> str: + return sql.SQL("CREATE SCHEMA IF NOT EXISTS {schema};").format( + schema=sql.Identifier(self.schema) + ) + + @property + def create_table(self) -> str: + uq_idx = sql.Identifier(f"uq_{self.table}_session_id") + expiry_idx = sql.Identifier(f"{self.table}_expiry_idx") + return sql.SQL( + """CREATE TABLE IF NOT EXISTS {schema}.{table} ( + session_id VARCHAR(255) NOT NULL PRIMARY KEY, + created TIMESTAMP WITHOUT TIME ZONE DEFAULT (NOW() AT TIME ZONE 'utc'), + data BYTEA, + expiry TIMESTAMP WITHOUT TIME ZONE + ); + + --- Unique session_id + CREATE UNIQUE INDEX IF NOT EXISTS + {uq_idx} ON {schema}.{table} (session_id); + + --- Index for expiry timestamp + CREATE INDEX IF NOT EXISTS + {expiry_idx} ON {schema}.{table} (expiry);""" + ).format( + schema=sql.Identifier(self.schema), + table=sql.Identifier(self.table), + uq_idx=uq_idx, + expiry_idx=expiry_idx, + ) + + @property + def retrieve_session_data(self) -> str: + return sql.SQL( + """--- If the current sessions is expired, delete it + DELETE FROM {schema}.{table} + WHERE session_id = %(session_id)s AND expiry < NOW(); + --- Else retrieve it + SELECT data FROM {schema}.{table} WHERE session_id = %(session_id)s; + """ + ).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table)) + + @property + def upsert_session(self) -> str: + return sql.SQL( + """INSERT INTO {schema}.{table} 
(session_id, data, expiry) + VALUES (%(session_id)s, %(data)s, NOW() + %(ttl)s) + ON CONFLICT (session_id) + DO UPDATE SET data = %(data)s, expiry = NOW() + %(ttl)s; + """ + ).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table)) + + @property + def delete_expired_sessions(self) -> str: + return sql.SQL("DELETE FROM {schema}.{table} WHERE expiry < NOW();").format( + schema=sql.Identifier(self.schema), table=sql.Identifier(self.table) + ) + + @property + def delete_session(self) -> str: + return sql.SQL( + "DELETE FROM {schema}.{table} WHERE session_id = %(session_id)s;" + ).format(schema=sql.Identifier(self.schema), table=sql.Identifier(self.table)) + + @property + def drop_sessions_table(self) -> str: + return sql.SQL("DROP TABLE IF EXISTS {schema}.{table};").format( + schema=sql.Identifier(self.schema), table=sql.Identifier(self.table) + ) diff --git a/src/flask_session/postgresql/postgresql.py b/src/flask_session/postgresql/postgresql.py new file mode 100644 index 00000000..9ad198bb --- /dev/null +++ b/src/flask_session/postgresql/postgresql.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from contextlib import contextmanager +from datetime import timedelta as TimeDelta +from typing import Generator, Optional + +from flask import Flask +from itsdangerous import want_bytes +from psycopg2.extensions import connection as PsycoPg2Connection +from psycopg2.extensions import cursor as PsycoPg2Cursor +from psycopg2.pool import ThreadedConnectionPool + +from .._utils import retry_query +from ..base import ServerSideSession, ServerSideSessionInterface +from ..defaults import Defaults +from ._queries import Queries + + +class PostgreSqlSession(ServerSideSession): + pass + + +class PostgreSqlSessionInterface(ServerSideSessionInterface): + """A Session interface that uses PostgreSQL as a session storage. (`psycopg2` required) + + :param pool: A ``psycopg2.pool.ThreadedConnectionPool`` instance. 
+ :param key_prefix: A prefix that is added to all storage keys. + :param use_signer: Whether to sign the session id cookie or not. + :param permanent: Whether to use permanent session or not. + :param sid_length: The length of the generated session id in bytes. + :param serialization_format: The serialization format to use for the session data. + :param table: The table name you want to use. + :param schema: The db schema to use. + :param cleanup_n_requests: Delete expired sessions on average every N requests. + """ + + session_class = PostgreSqlSession + ttl = False + + def __init__( + self, + app: Flask, + pool: Optional[ThreadedConnectionPool] = Defaults.SESSION_POSTGRESQL, + key_prefix: str = Defaults.SESSION_KEY_PREFIX, + use_signer: bool = Defaults.SESSION_USE_SIGNER, + permanent: bool = Defaults.SESSION_PERMANENT, + sid_length: int = Defaults.SESSION_ID_LENGTH, + serialization_format: str = Defaults.SESSION_SERIALIZATION_FORMAT, + table: str = Defaults.SESSION_POSTGRESQL_TABLE, + schema: str = Defaults.SESSION_POSTGRESQL_SCHEMA, + cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS, + ) -> None: + if not isinstance(pool, ThreadedConnectionPool): + raise TypeError("No valid ThreadedConnectionPool instance provided.") + + self.pool = pool + + self._table = table + self._schema = schema + + self._queries = Queries(schema=self._schema, table=self._table) + + self._create_schema_and_table() + + super().__init__( + app, + key_prefix, + use_signer, + permanent, + sid_length, + serialization_format, + cleanup_n_requests, + ) + + @contextmanager + def _get_cursor( + self, conn: Optional[PsycoPg2Connection] = None + ) -> Generator[PsycoPg2Cursor, None, None]: + _conn: PsycoPg2Connection = conn or self.pool.getconn() + + assert isinstance(_conn, PsycoPg2Connection) + try: + with _conn, _conn.cursor() as cur: + yield cur + except Exception: + raise + finally: + self.pool.putconn(_conn) + + @retry_query(max_attempts=3) + def 
_create_schema_and_table(self) -> None: + with self._get_cursor() as cur: + cur.execute(self._queries.create_schema) + cur.execute(self._queries.create_table) + + def _delete_expired_sessions(self) -> None: + """Delete all expired sessions from the database.""" + with self._get_cursor() as cur: + cur.execute(self._queries.delete_expired_sessions) + + @retry_query(max_attempts=3) + def _delete_session(self, store_id: str) -> None: + with self._get_cursor() as cur: + cur.execute( + self._queries.delete_session, + dict(session_id=store_id), + ) + + @retry_query(max_attempts=3) + def _retrieve_session_data(self, store_id: str) -> Optional[dict]: + with self._get_cursor() as cur: + cur.execute( + self._queries.retrieve_session_data, + dict(session_id=store_id), + ) + session_data = cur.fetchone() + + if session_data is not None: + serialized_session_data = want_bytes(session_data[0]) + return self.serializer.decode(serialized_session_data) + return None + + @retry_query(max_attempts=3) + def _upsert_session( + self, session_lifetime: TimeDelta, session: ServerSideSession, store_id: str + ) -> None: + + serialized_session_data = self.serializer.encode(session) + + if session.sid is not None: + assert session.sid == store_id.removeprefix(self.key_prefix) + + with self._get_cursor() as cur: + cur.execute( + self._queries.upsert_session, + dict( + session_id=store_id, + data=serialized_session_data, + ttl=session_lifetime, + ), + ) + + def _drop_table(self) -> None: + with self._get_cursor() as cur: + cur.execute(self._queries.drop_sessions_table) diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py new file mode 100644 index 00000000..05a9e555 --- /dev/null +++ b/tests/test_postgresql.py @@ -0,0 +1,50 @@ +import json +from contextlib import contextmanager + +import flask +from itsdangerous import want_bytes +from psycopg2.pool import ThreadedConnectionPool + +from flask_session.postgresql import PostgreSqlSession + +TEST_DB = 
"postgresql://root:pwd@localhost:5433/dummy" + + +class TestPostgreSql: + """This requires package: psycopg2""" + + @contextmanager + def setup_postgresql(self, app_utils): + self.pool = ThreadedConnectionPool(1, 5, TEST_DB) + self.app = app_utils.create_app( + {"SESSION_TYPE": "postgresql", "SESSION_POSTGRESQL": self.pool} + ) + + yield + self.app.session_interface._drop_table() + + def retrieve_stored_session(self, key): + with self.app.session_interface._get_cursor() as cur: + cur.execute( + self.app.session_interface._queries.retrieve_session_data, + dict(session_id=key), + ) + + session_data = cur.fetchone() + if session_data is not None: + return want_bytes(session_data[0].tobytes()) + return None + + def test_postgresql(self, app_utils): + with self.setup_postgresql(app_utils), self.app.test_request_context(): + assert isinstance(flask.session, PostgreSqlSession) + app_utils.test_session(self.app) + + # Check if the session is stored in PostgreSQL + cookie = app_utils.test_session_with_cookie(self.app) + session_id = cookie.split(";")[0].split("=")[1] + byte_string = self.retrieve_stored_session(f"session:{session_id}") + stored_session = ( + json.loads(byte_string.decode("utf-8")) if byte_string else {} + ) + assert stored_session.get("value") == "44"