diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8d0befeee0..7b1f19d67e 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -2,8 +2,7 @@ name: Test on: pull_request: push: - branches: - - master + env: PY_COLORS: 1 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 4e2a2db260..5579f08524 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -3,8 +3,6 @@ run-name: Lint code on: pull_request: push: - branches: - - master env: PYTHON_VERSION: 3.8 diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 566c116314..92d73aa8fc 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -20,11 +20,12 @@ import os import re import sqlite3 +import sys import threading import time from abc import ABC from collections import defaultdict -from sqlite3 import Connection +from sqlite3 import Connection, sqlite_version_info from types import TracebackType from typing import ( Any, @@ -73,6 +74,16 @@ class DBAccessError(Exception): """ +class DBCustomFunctionError(Exception): + """A sqlite function registered by beets failed.""" + + def __init__(self): + super().__init__( + "beets defined SQLite function failed; " + "see the other errors above for details" + ) + + class FormattedMapping(Mapping[str, str]): """A `dict`-like formatted view of a model. @@ -970,6 +981,12 @@ def __exit__( self._mutated = False self.db._db_lock.release() + if ( + isinstance(exc_value, sqlite3.OperationalError) + and exc_value.args[0] == "user-defined function raised exception" + ): + raise DBCustomFunctionError() + def query(self, statement: str, subvals: Sequence = ()) -> List: """Execute an SQL statement with substitution values and return a list of rows from the database. @@ -1028,6 +1045,10 @@ def __init__(self, path, timeout: float = 5.0): "sqlite3 must be compiled with multi-threading support" ) + # Print tracebacks for exceptions in user defined functions + # See also `self.add_functions` and `DBCustomFunctionError`. + sqlite3.enable_callback_tracebacks(True) + self.path = path self.timeout = timeout @@ -1123,9 +1144,14 @@ def bytelower(bytestring: Optional[AnyStr]) -> Optional[AnyStr]: return bytestring - conn.create_function("regexp", 2, regexp) - conn.create_function("unidecode", 1, unidecode) - conn.create_function("bytelower", 1, bytelower) + deterministic = {} + if sys.version_info >= (3, 8) and sqlite_version_info >= (3, 8, 3): + # Let sqlite make extra optimizations + deterministic["deterministic"] = True + + conn.create_function("regexp", 2, regexp, **deterministic) + conn.create_function("unidecode", 1, unidecode, **deterministic) + conn.create_function("bytelower", 1, bytelower, **deterministic) def _close(self): """Close the all connections to the underlying SQLite database