diff --git a/CHANGELOG.md b/CHANGELOG.md index 740ebe2fa..d79e410fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ https://github.com/coleifer/peewee/releases ## master * Add bitwise and other helper methods to `BigBitField`, #2802. +* Add `add_column_default` and `drop_column_default` migrator methods for + specifying a server-side default value, #2803. * The new `star` attribute was causing issues for users who had a field named star on their models. This attribute is now renamed to `__star__`. #2796. diff --git a/docs/peewee/playhouse.rst b/docs/peewee/playhouse.rst index 168d5c8e4..02023dbac 100644 --- a/docs/peewee/playhouse.rst +++ b/docs/peewee/playhouse.rst @@ -3116,6 +3116,31 @@ Adding or dropping table constraints: # Add a UNIQUE constraint on the first and last names. migrate(migrator.add_unique('person', 'first_name', 'last_name')) +Adding or dropping a database-level default value for a column: + +.. code-block:: python + + # Add a default value for a status column. + migrate(migrator.add_column_default( + 'entries', + 'status', + 'draft')) + + # Remove the default. + migrate(migrator.drop_column_default('entries', 'status')) + + # Use a function for the default value (does not work with Sqlite): + migrate(migrator.add_column_default( + 'entries', + 'timestamp', + fn.now())) + + # Or alternatively (works with Sqlite): + migrate(migrator.add_column_default( + 'entries', + 'timestamp', + 'now()')) + .. note:: Postgres users may need to set the search-path when using a non-standard schema. This can be done as follows: diff --git a/playhouse/migrate.py b/playhouse/migrate.py index ed00ea8e4..c04b5fd4b 100644 --- a/playhouse/migrate.py +++ b/playhouse/migrate.py @@ -389,6 +389,32 @@ def drop_not_null(self, table, column): ._alter_column(self.make_context(), table, column) .literal(' DROP NOT NULL')) + @operation + def add_column_default(self, table, column, default): + if default is None: + raise ValueError('`default` must be not None/NULL.') + if callable_(default): + default = default() + # Try to handle SQL functions and string literals, otherwise pass as a + # bound value. + if isinstance(default, str) and default.endswith((')', "'")): + default = SQL(default) + + return (self + ._alter_table(self.make_context(), table) + .literal(' ALTER COLUMN ') + .sql(Entity(column)) + .literal(' SET DEFAULT ') + .sql(default)) + + @operation + def drop_column_default(self, table, column): + return (self + ._alter_table(self.make_context(), table) + .literal(' ALTER COLUMN ') + .sql(Entity(column)) + .literal(' DROP DEFAULT')) + @operation def alter_column_type(self, table, column, field, cast=None): # ALTER TABLE