diff --git a/sql_server/pyodbc/schema.py b/sql_server/pyodbc/schema.py index 9abcbd04..262d4c81 100644 --- a/sql_server/pyodbc/schema.py +++ b/sql_server/pyodbc/schema.py @@ -18,6 +18,8 @@ from django.db.transaction import TransactionManagementError from django.utils.encoding import force_str +from collections import defaultdict + class Statement(DjStatement): def __hash__(self): @@ -68,6 +70,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): sql_create_unique_null = "CREATE UNIQUE INDEX %(name)s ON %(table)s(%(columns)s) " \ "WHERE %(columns)s IS NOT NULL" + _deferred_unique_indexes = defaultdict(list) + def _alter_column_default_sql(self, model, old_field, new_field, drop=False): """ Hook to specialize column default alteration. @@ -236,6 +240,15 @@ def alter_db_table(self, model, old_db_table, new_db_table): return super().alter_db_table(model, old_db_table, new_db_table) + def _delete_deferred_unique_indexes_for_field(self, field): + deferred_statements = self._deferred_unique_indexes.get(str(field), []) + for stmt in deferred_statements: + if stmt in self.deferred_sql: + self.deferred_sql.remove(stmt) + + def _add_deferred_unique_index_for_field(self, field, statement): + self._deferred_unique_indexes[str(field)].append(statement) + def _alter_field(self, model, old_field, new_field, old_type, new_type, old_db_params, new_db_params, strict=False): """Actually perform a "physical" (non-ManyToMany) field update.""" @@ -449,6 +462,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, ) else: self.execute(self._create_unique_sql(model, [new_field.column])) + self._delete_deferred_unique_indexes_for_field(new_field) # Added an index? # constraint will no longer be used in lieu of an index. The following # lines from the truth table show all True cases; the rest are False: @@ -477,6 +491,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, ) else: self.execute(self._create_unique_sql(model, columns=[old_field.column])) + self._delete_deferred_unique_indexes_for_field(old_field) else: for fields in model._meta.unique_together: columns = [model._meta.get_field(field).column for field in fields] @@ -644,9 +659,11 @@ def add_field(self, model, field): not field.many_to_many and field.null and field.unique): definition = definition.replace(' UNIQUE', '') - self.deferred_sql.append(self._create_index_sql( + statement = self._create_index_sql( model, [field], sql=self.sql_create_unique_null, suffix="_uniq" - )) + ) + self.deferred_sql.append(statement) + self._add_deferred_unique_index_for_field(field, statement) # Check constraints can go on the column SQL here db_params = field.db_parameters(connection=self.connection) @@ -750,9 +767,11 @@ def create_model(self, model): not field.many_to_many and field.null and field.unique): definition = definition.replace(' UNIQUE', '') - self.deferred_sql.append(self._create_index_sql( + statement = self._create_index_sql( model, [field], sql=self.sql_create_unique_null, suffix="_uniq" - )) + ) + self.deferred_sql.append(statement) + self._add_deferred_unique_index_for_field(field, statement) # Check constraints can go on the column SQL here db_params = field.db_parameters(connection=self.connection) diff --git a/testapp/migrations/0008_test_alter_nullable_in_unique_field.py b/testapp/migrations/0008_test_alter_nullable_in_unique_field.py new file mode 100644 index 00000000..921a426d --- /dev/null +++ b/testapp/migrations/0008_test_alter_nullable_in_unique_field.py @@ -0,0 +1,24 @@ +# Generated by Django 3.0.4 on 2020-04-20 14:59 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('testapp', '0007_test_remove_onetoone_field_part2'), + ] + + operations = [ + migrations.CreateModel( + name='TestAlterNullableInUniqueField', + fields=[ + ('a', models.CharField(max_length=50, null=True, unique=True)), + ], + ), + migrations.AlterField( + model_name='testalternullableinuniquefield', + name='a', + field=models.CharField(max_length=50, unique=True), + ) + ] diff --git a/testapp/models.py b/testapp/models.py index c87f797b..4357d91f 100644 --- a/testapp/models.py +++ b/testapp/models.py @@ -71,3 +71,11 @@ class TestRemoveOneToOneFieldModel(models.Model): # thats already is removed. # b = models.OneToOneField('self', on_delete=models.SET_NULL, null=True) a = models.CharField(max_length=50) + + +class TestAlterNullableInUniqueField(models.Model): + """ Model used to test a single migration that creates a field with unique=True and null=True and then alters + the field to set null=False. This is a common use case when you want to add a non-nullable unique field to a + pre-existing model. In order to make that work you need to first create the unique field as nullable, then + populate the field for every pre-existing instance, and then alter the field to set it to non-nullaable. """ + a = models.CharField(max_length=50, unique=True, null=True)