From 91c6f6a4ea883545c91dcc1b2f5bfb85eb9c832d Mon Sep 17 00:00:00 2001
From: Damien Baldy
Date: Sun, 20 Sep 2020 21:31:10 +0200
Subject: [PATCH] fixed GROUP BY issue with expressions

---
 sql_server/pyodbc/compiler.py     | 130 +++++++++++++++++++++++++++++-
 testapp/tests/test_expressions.py |  13 ++-
 2 files changed, 141 insertions(+), 2 deletions(-)

diff --git a/sql_server/pyodbc/compiler.py b/sql_server/pyodbc/compiler.py
index 86fecb3e..1799447d 100644
--- a/sql_server/pyodbc/compiler.py
+++ b/sql_server/pyodbc/compiler.py
@@ -3,13 +3,14 @@
 
 import django
 from django.db.models.aggregates import Avg, Count, StdDev, Variance
-from django.db.models.expressions import Ref, Subquery, Value
+from django.db.models.expressions import Col, Ref, Subquery, Value
 from django.db.models.functions import (
     Chr, ConcatPair, Greatest, Least, Length, LPad, Repeat, RPad, StrIndex, Substr, Trim
 )
 from django.db.models.sql import compiler
 from django.db.transaction import TransactionManagementError
 from django.db.utils import NotSupportedError
+from django.utils.hashable import make_hashable
 
 
 def _as_sql_agv(self, compiler, connection):
@@ -148,8 +149,135 @@ def _cursor_iter(cursor, sentinel, col_count, itersize):
 compiler.cursor_iter = _cursor_iter
 
 
+def _flatten_expressions_only(node):
+    """
+    Recursively yield ``node`` and its subexpressions in depth-first order.
+    Copied from django.db.models.expressions.BaseExpression, with an added
+    hasattr(expr, 'flatten') guard: source expressions that are falsy or
+    have no ``flatten`` attribute are skipped rather than yielded.
+    """
+    yield node
+    for expr in node.get_source_expressions():
+        if expr and hasattr(expr, 'flatten'):
+            yield from _flatten_expressions_only(expr)
+
+
 class SQLCompiler(compiler.SQLCompiler):
 
+    def get_group_by(self, select, order_by):
+        """
+        Return a list of 2-tuples of form (sql, params).
+        The logic of what exactly the GROUP BY clause contains is hard
+        to describe in other words than "if it passes the test suite,
+        then it is correct".
+        """
+        # Some examples:
+        #     SomeModel.objects.annotate(Count('somecol'))
+        #     GROUP BY: all fields of the model
+        #
+        #     SomeModel.objects.values('name').annotate(Count('somecol'))
+        #     GROUP BY: name
+        #
+        #     SomeModel.objects.annotate(Count('somecol')).values('name')
+        #     GROUP BY: all cols of the model
+        #
+        #     SomeModel.objects.values('name', 'pk').annotate(Count('somecol')).values('pk')
+        #     GROUP BY: name, pk
+        #
+        #     SomeModel.objects.values('name').annotate(Count('somecol')).values('pk')
+        #     GROUP BY: name, pk
+        #
+        # In fact, the self.query.group_by is the minimal set to GROUP BY. It
+        # can never be restricted to a smaller set, but additional columns in
+        # HAVING, ORDER BY, and SELECT clauses are added to it. Unfortunately
+        # the end result is that it is impossible to force the query to have
+        # a chosen GROUP BY clause - you can almost do this by using the form:
+        #     .values(*wanted_cols).annotate(AnAggregate())
+        # but any later annotations, extra selects, values calls that
+        # refer to some column outside of the wanted_cols, order_by, or even
+        # filter calls can alter the GROUP BY clause.
+
+        # The query.group_by is either None (no GROUP BY at all), True
+        # (group by select fields), or a list of expressions to be added
+        # to the group by.
+        if self.query.group_by is None:
+            return []
+
+        expressions = []
+        if self.query.group_by is not True:
+            # If the group by is set to a list (by .values() call most likely),
+            # then we need to add everything in it to the GROUP BY clause.
+            # Backwards compatibility hack for setting query.group_by. Remove
+            # when we have public API way of forcing the GROUP BY clause.
+            # Converts string references to expressions.
+            for expr in self.query.group_by:
+                if not hasattr(expr, 'as_sql'):
+                    expressions.append(self.query.resolve_ref(expr))
+                else:
+                    expressions.append(expr)
+        # Note that even if the group_by is set, it is only the minimal
+        # set to group by. So, we need to add cols in select, order_by, and
+        # having into the group by in any case.
+        if django.VERSION >= (3, 0, 0):
+            ref_sources = {
+                expr.source for expr in expressions if isinstance(expr, Ref)
+            }
+        for expr, _, _ in select:
+            cols = expr.get_group_by_cols()
+
+            if django.VERSION >= (3, 0, 0):
+                # Skip members of the select clause that are already included
+                # by reference.
+                if expr in ref_sources:
+                    continue
+
+            for col in cols:
+                expressions.append(col)
+            # For MSSQL, the stored procedure used by pyodbc doesn't allow
+            # queries in the format:
+            # SELECT [AGGREGATE_FUNCTION](*) FROM (SELECT id, F(a) FROM ... GROUP BY id, F(a)) subquery
+            # The accepted format is:
+            # SELECT [AGGREGATE_FUNCTION](*) FROM (SELECT id, F(a) FROM ... GROUP BY id, F(a), a) subquery
+            # Therefore we also group by every Col referenced inside the expression.
+            for sub_expr in _flatten_expressions_only(expr):
+                if isinstance(sub_expr, Col):
+                    expr_cols = sub_expr.get_group_by_cols()
+                    for expr_col in expr_cols:
+                        expressions.append(expr_col)
+        for expr, (sql, params, is_ref) in order_by:
+            if django.VERSION >= (3, 0, 0):
+                # Skip References to the select clause, as all expressions in the
+                # select clause are already part of the group by.
+                if not is_ref:
+                    expressions.extend(expr.get_group_by_cols())
+            else:
+                # Skip aggregates and References to the select clause; expressions
+                # in the select clause are already part of the group by.
+                if not expr.contains_aggregate and not is_ref:
+                    expressions.extend(expr.get_source_expressions())
+        having_group_by = self.having.get_group_by_cols() if self.having else ()
+        for expr in having_group_by:
+            expressions.append(expr)
+        result = []
+        seen = set()
+        expressions = self.collapse_group_by(expressions, having_group_by)
+
+        for expr in expressions:
+            sql, params = self.compile(expr)
+            if django.VERSION >= (3, 0, 0):
+                sql, params = expr.select_format(self, sql, params)
+            else:
+                if isinstance(expr, Subquery) and not sql.startswith('('):
+                    # Subquery expression from HAVING clause may not contain
+                    # wrapping () because they could be removed when a subquery is
+                    # the "rhs" in an expression (see Subquery._prepare()).
+                    sql = '(%s)' % sql
+            params_hash = make_hashable(params)
+            if (sql, params_hash) not in seen:
+                result.append((sql, params))
+                seen.add((sql, params_hash))
+        return result
+
     def as_sql(self, with_limits=True, with_col_aliases=False):
         """
         Create the SQL for this query. Return the SQL string and list of
diff --git a/testapp/tests/test_expressions.py b/testapp/tests/test_expressions.py
index 90623753..25667905 100644
--- a/testapp/tests/test_expressions.py
+++ b/testapp/tests/test_expressions.py
@@ -1,9 +1,10 @@
 from unittest import skipUnless
 
 from django import VERSION
-from django.db.models import IntegerField
+from django.db.models import IntegerField, CharField
 from django.db.models.expressions import Case, Exists, OuterRef, Subquery, Value, When
 from django.test import TestCase
+from django.db.models.functions import Concat
 
 from ..models import Author, Comment, Post
 
@@ -32,6 +33,16 @@ def test_with_count(self):
             post_exists=Exists(Post.objects.all())
         ).filter(post_exists=True).count()
 
+    def test_simple_concat_with_count(self):
+        Post.objects.annotate(
+            display_name=Concat('title', Value(' test'))
+        ).count()
+
+    def test_concat_with_count(self):
+        Post.objects.annotate(
+            display_name=Concat('title', Value(', '), 'author', output_field=CharField())
+        ).count()
+
     @skipUnless(DJANGO3, "Django 3 specific tests")
     def test_with_case_when(self):
         author = Author.objects.annotate(