diff --git a/CHANGELOG.md b/CHANGELOG.md index 740ebe2fa..d79e410fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ https://github.com/coleifer/peewee/releases ## master * Add bitwise and other helper methods to `BigBitField`, #2802. +* Add `add_column_default` and `drop_column_default` migrator methods for + specifying a server-side default value, #2803. * The new `star` attribute was causing issues for users who had a field named star on their models. This attribute is now renamed to `__star__`. #2796. diff --git a/docs/peewee/playhouse.rst b/docs/peewee/playhouse.rst index 168d5c8e4..02023dbac 100644 --- a/docs/peewee/playhouse.rst +++ b/docs/peewee/playhouse.rst @@ -3116,6 +3116,31 @@ Adding or dropping table constraints: # Add a UNIQUE constraint on the first and last names. migrate(migrator.add_unique('person', 'first_name', 'last_name')) +Adding or dropping a database-level default value for a column: + +.. code-block:: python + + # Add a default value for a status column. + migrate(migrator.add_column_default( + 'entries', + 'status', + 'draft')) + + # Remove the default. + migrate(migrator.drop_column_default('entries', 'status')) + + # Use a function for the default value (does not work with Sqlite): + migrate(migrator.add_column_default( + 'entries', + 'timestamp', + fn.now())) + + # Or alternatively (works with Sqlite): + migrate(migrator.add_column_default( + 'entries', + 'timestamp', + 'now()')) + .. note:: Postgres users may need to set the search-path when using a non-standard schema. This can be done as follows: diff --git a/playhouse/migrate.py b/playhouse/migrate.py index ed00ea8e4..c04b5fd4b 100644 --- a/playhouse/migrate.py +++ b/playhouse/migrate.py @@ -389,6 +389,32 @@ def drop_not_null(self, table, column): ._alter_column(self.make_context(), table, column) .literal(' DROP NOT NULL')) + @operation + def add_column_default(self, table, column, default): + if default is None: + raise ValueError('`default` must be not None/NULL.') + if callable_(default): + default = default() + # Try to handle SQL functions and string literals, otherwise pass as a + # bound value. + if isinstance(default, str) and default.endswith((')', "'")): + default = SQL(default) + + return (self + ._alter_table(self.make_context(), table) + .literal(' ALTER COLUMN ') + .sql(Entity(column)) + .literal(' SET DEFAULT ') + .sql(default)) + + @operation + def drop_column_default(self, table, column): + return (self + ._alter_table(self.make_context(), table) + .literal(' ALTER COLUMN ') + .sql(Entity(column)) + .literal(' DROP DEFAULT')) + @operation def alter_column_type(self, table, column, field, cast=None): # ALTER TABLE ALTER COLUMN @@ -866,6 +892,27 @@ def _drop_not_null(column_name, column_def): return column_def.replace('NOT NULL', '') return self._update_column(table, column, _drop_not_null) + @operation + def add_column_default(self, table, column, default): + if default is None: + raise ValueError('`default` must be not None/NULL.') + if callable_(default): + default = default() + if (isinstance(default, str) and not default.endswith((')', "'")) + and not default.isdigit()): + default = "'%s'" % default + def _add_default(column_name, column_def): + # Try to handle SQL functions and string literals, otherwise quote. + return column_def + ' DEFAULT %s' % default + return self._update_column(table, column, _add_default) + + @operation + def drop_column_default(self, table, column): + def _drop_default(column_name, column_def): + col = re.sub(r'DEFAULT\s+[\w"\'\(\)]+(\s|$)', '', column_def, re.I) + return col.strip() + return self._update_column(table, column, _drop_default) + @operation def alter_column_type(self, table, column, field, cast=None): if cast is not None: diff --git a/tests/migrations.py b/tests/migrations.py index d6f73e2be..b6ff81f58 100644 --- a/tests/migrations.py +++ b/tests/migrations.py @@ -330,6 +330,29 @@ class Meta: def test_rename_gh380_sqlite_legacy(self): self.test_rename_gh380(legacy=True) + def test_add_default_drop_default(self): + with self.database.transaction(): + migrate(self.migrator.add_column_default('person', 'first_name', + default='x')) + + p = Person.create(last_name='Last') + p_db = Person.get(Person.last_name == 'Last') + self.assertEqual(p_db.first_name, 'x') + + with self.database.transaction(): + migrate(self.migrator.drop_column_default('person', 'first_name')) + + if IS_MYSQL: + # MySQL, even though the column is NOT NULL, does not seem to be + # enforcing the constraint(?). + Person.create(last_name='Last2') + p_db = Person.get(Person.last_name == 'Last2') + self.assertEqual(p_db.first_name, '') + else: + with self.assertRaises(IntegrityError): + with self.database.transaction(): + Person.create(last_name='Last2') + def test_add_not_null(self): self._create_people()