From 4ee0ba8260bd7401e339d43d5eae9d8e0258623b Mon Sep 17 00:00:00 2001 From: David <2297074+dwasyl@users.noreply.github.com> Date: Wed, 13 Nov 2019 12:51:00 -0700 Subject: [PATCH] Apply fix for count of group by query --- sql_server/pyodbc/base.py | 56 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/sql_server/pyodbc/base.py b/sql_server/pyodbc/base.py index 8e1a5385..4cf3063d 100644 --- a/sql_server/pyodbc/base.py +++ b/sql_server/pyodbc/base.py @@ -4,6 +4,9 @@ import os import re import time +import datetime +from decimal import Decimal +from uuid import UUID from django.core.exceptions import ImproperlyConfigured from django import VERSION @@ -484,6 +487,36 @@ def __init__(self, cursor, connection): self.last_sql = '' self.last_params = () + def _pytype_to_sqltype(self, typ, value): + if value is None: + return 'INT' + elif isinstance(value, str): + length = len(value) + if length == 0: + return 'NVARCHAR' + return 'NVARCHAR(%s)' % len(value) + elif typ == int: + if value < 0x7FFFFFFF and value > -0x7FFFFFFF: + return 'INT' + else: + return 'BIGINT' + elif typ == float: + return 'FLOAT' + elif typ == bool: + return 'BIT' + elif isinstance(value, Decimal): + return 'NUMERIC' + elif isinstance(value, datetime.date): + return 'DATE' + elif isinstance(value, datetime.time): + return 'TIME' + elif isinstance(value, datetime.datetime): + return 'TIMESTAMP' + elif isinstance(value, UUID): + return 'uniqueidentifier' + else: + raise NotImplementedError('not support type %s (%s)' % (type(value), repr(value))) + def close(self): if self.active: self.active = False @@ -527,8 +560,31 @@ def format_params(self, params): return tuple(fp) + def _fix_for_params(self, query, params, unify_by_values=False): + if params is None: + params = [] + query = query + elif unify_by_values and len(params) > 0: + params = [(param, type(param)) for param in params] + params_dict = {param: '@arg%d' % i for i, param in enumerate(set(params))} + args = [params_dict[param] for param in params] + + variables = [] + params = [] + for key, value in params_dict.items(): + datatype = self._pytype_to_sqltype(key[1], key[0]) + variables.append("%s %s = %%s " % (value, datatype)) + params.append(key[0]) + query = ('DECLARE %s \n' % ','.join(variables)) + (query % tuple(args)) + params = tuple(params) + return query, params + def execute(self, sql, params=None): self.last_sql = sql + if 'GROUP BY' in sql: + sql, params = self._fix_for_params(sql, params, unify_by_values=True) + # print ('sql is %s' % sql) + # print ('params is %s' % repr(params)) sql = self.format_sql(sql, params) params = self.format_params(params) self.last_params = params