Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed GROUP BY issue with expressions #73

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 129 additions & 1 deletion sql_server/pyodbc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

import django
from django.db.models.aggregates import Avg, Count, StdDev, Variance
from django.db.models.expressions import Ref, Subquery, Value
from django.db.models.expressions import Col, Ref, Subquery, Value
from django.db.models.functions import (
Chr, ConcatPair, Greatest, Least, Length, LPad, Repeat, RPad, StrIndex, Substr, Trim
)
from django.db.models.sql import compiler
from django.db.transaction import TransactionManagementError
from django.db.utils import NotSupportedError
from django.utils.hashable import make_hashable


def _as_sql_agv(self, compiler, connection):
Expand Down Expand Up @@ -148,8 +149,135 @@ def _cursor_iter(cursor, sentinel, col_count, itersize):
compiler.cursor_iter = _cursor_iter


def _flatten_expressions_only(node):
"""
Recursively yield this expression and all subexpressions, in
depth-first order.
Copied from django.db.models.expressions.BaseExpression
and added hasattr to prevent some issues.
"""
yield node
for expr in node.get_source_expressions():
if expr and hasattr(expr, 'flatten'):
yield from _flatten_expressions_only(expr)


class SQLCompiler(compiler.SQLCompiler):

def get_group_by(self, select, order_by):
"""
Return a list of 2-tuples of form (sql, params).
The logic of what exactly the GROUP BY clause contains is hard
to describe in other words than "if it passes the test suite,
then it is correct".
"""
# Some examples:
# SomeModel.objects.annotate(Count('somecol'))
# GROUP BY: all fields of the model
#
# SomeModel.objects.values('name').annotate(Count('somecol'))
# GROUP BY: name
#
# SomeModel.objects.annotate(Count('somecol')).values('name')
# GROUP BY: all cols of the model
#
# SomeModel.objects.values('name', 'pk').annotate(Count('somecol')).values('pk')
# GROUP BY: name, pk
#
# SomeModel.objects.values('name').annotate(Count('somecol')).values('pk')
# GROUP BY: name, pk
#
# In fact, the self.query.group_by is the minimal set to GROUP BY. It
# can't be ever restricted to a smaller set, but additional columns in
# HAVING, ORDER BY, and SELECT clauses are added to it. Unfortunately
# the end result is that it is impossible to force the query to have
# a chosen GROUP BY clause - you can almost do this by using the form:
# .values(*wanted_cols).annotate(AnAggregate())
# but any later annotations, extra selects, values calls that
# refer some column outside of the wanted_cols, order_by, or even
# filter calls can alter the GROUP BY clause.

# The query.group_by is either None (no GROUP BY at all), True
# (group by select fields), or a list of expressions to be added
# to the group by.
if self.query.group_by is None:
return []

expressions = []
if self.query.group_by is not True:
# If the group by is set to a list (by .values() call most likely),
# then we need to add everything in it to the GROUP BY clause.
# Backwards compatibility hack for setting query.group_by. Remove
# when we have public API way of forcing the GROUP BY clause.
# Converts string references to expressions.
for expr in self.query.group_by:
if not hasattr(expr, 'as_sql'):
expressions.append(self.query.resolve_ref(expr))
else:
expressions.append(expr)
# Note that even if the group_by is set, it is only the minimal
# set to group by. So, we need to add cols in select, order_by, and
# having into the select in any case.
if django.VERSION >= (3, 0, 0):
ref_sources = {
expr.source for expr in expressions if isinstance(expr, Ref)
}
for expr, _, _ in select:
cols = expr.get_group_by_cols()

if django.VERSION >= (3, 0, 0):
# Skip members of the select clause that are already included
# by reference.
if expr in ref_sources:
continue

for col in cols:
expressions.append(col)
# for MSSQL, the stored procedure used by pyodbc doesn't allow
# queries in the format:
# SELECT [AGGREGATE_FUNCTION](*) FROM (SELECT id, F(a) FROM ... GROUP BY id, F(a)) subquery
# the accepted format is:
# SELECT [AGGREGATE_FUNCTION](*) FROM (SELECT id, F(a) FROM ... GROUP BY id, F(a), a) subquery
# Therefore we add the referenced columns in the get_group_by function
for sub_expr in _flatten_expressions_only(expr):
if isinstance(sub_expr, Col):
expr_cols = sub_expr.get_group_by_cols()
for expr_col in expr_cols:
expressions.append(expr_col)
for expr, (sql, params, is_ref) in order_by:
if django.VERSION >= (3, 0, 0):
# Skip References to the select clause, as all expressions in the
# select clause are already part of the group by.
if not is_ref:
expressions.extend(expr.get_group_by_cols())
else:
# Skip References to the select clause, as all expressions in the
# select clause are already part of the group by.
if not expr.contains_aggregate and not is_ref:
expressions.extend(expr.get_source_expressions())
having_group_by = self.having.get_group_by_cols() if self.having else ()
for expr in having_group_by:
expressions.append(expr)
result = []
seen = set()
expressions = self.collapse_group_by(expressions, having_group_by)

for expr in expressions:
sql, params = self.compile(expr)
if django.VERSION >= (3, 0, 0):
sql, params = expr.select_format(self, sql, params)
else:
if isinstance(expr, Subquery) and not sql.startswith('('):
# Subquery expression from HAVING clause may not contain
# wrapping () because they could be removed when a subquery is
# the "rhs" in an expression (see Subquery._prepare()).
sql = '(%s)' % sql
params_hash = make_hashable(params)
if (sql, params_hash) not in seen:
result.append((sql, params))
seen.add((sql, params_hash))
return result

def as_sql(self, with_limits=True, with_col_aliases=False):
"""
Create the SQL for this query. Return the SQL string and list of
Expand Down
13 changes: 12 additions & 1 deletion testapp/tests/test_expressions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from unittest import skipUnless

from django import VERSION
from django.db.models import IntegerField
from django.db.models import IntegerField, CharField
from django.db.models.expressions import Case, Exists, OuterRef, Subquery, Value, When
from django.test import TestCase
from django.db.models.functions import Concat

from ..models import Author, Comment, Post

Expand Down Expand Up @@ -32,6 +33,16 @@ def test_with_count(self):
post_exists=Exists(Post.objects.all())
).filter(post_exists=True).count()

def test_simple_concat_with_count(self):
Post.objects.annotate(
display_name=Concat('title', Value(' test'))
).count()

def test_concat_with_count(self):
Post.objects.annotate(
display_name=Concat('title', Value(', '), 'author', output_field=CharField())
).count()

@skipUnless(DJANGO3, "Django 3 specific tests")
def test_with_case_when(self):
author = Author.objects.annotate(
Expand Down