diff --git a/CHANGELOG.md b/CHANGELOG.md index fff7a70a7..98d0d63e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,26 @@ https://github.com/coleifer/peewee/releases ## master -[View commits](https://github.com/coleifer/peewee/compare/3.9.2...master) +[View commits](https://github.com/coleifer/peewee/compare/3.9.3...master) + +## 3.9.3 + +* Added cross-database support for `NULLS FIRST/LAST` when specifying the + ordering for a query. Previously this was only supported for Postgres. Peewee + will now generate an equivalent `CASE` statement for Sqlite and MySQL. +* Added [EXCLUDED](http://docs.peewee-orm.com/en/latest/peewee/api.html#EXCLUDED) + helper for referring to the `EXCLUDED` namespace used with `INSERT...ON CONFLICT` + queries, when referencing values in the conflicting row data. +* Added helper method to the model `Metadata` class for setting the table name + at run-time. Setting the `Model._meta.table_name` directly may have appeared + to work in some situations, but could lead to subtle bugs. The new API is + `Model._meta.set_table_name()`. +* Enhanced helpers for working with Peewee interactively, [see doc](http://docs.peewee-orm.com/en/latest/peewee/interactive.html). +* Fix cache invalidation bug in `DataSet` that was originally reported on the + sqlite-web project. +* New example script implementing a [hexastore](https://github.com/coleifer/peewee/blob/master/examples/hexastore.py). + +[View commits](https://github.com/coleifer/peewee/compare/3.9.2...3.9.3) ## 3.9.1 and 3.9.2 diff --git a/docs/peewee/api.rst b/docs/peewee/api.rst index 6c0582a30..ddd04fef6 100644 --- a/docs/peewee/api.rst +++ b/docs/peewee/api.rst @@ -1246,12 +1246,16 @@ Query-builder Create a ``CAST`` expression. - .. py:method:: asc() + .. py:method:: asc([collation=None[, nulls=None]]) + :param str collation: Collation name to use for sorting. + :param str nulls: Sort nulls (FIRST or LAST). :returns: an ascending :py:class:`Ordering` object for the column. - .. py:method:: desc() + .. py:method:: desc([collation=None[, nulls=None]]) + :param str collation: Collation name to use for sorting. + :param str nulls: Sort nulls (FIRST or LAST). :returns: an descending :py:class:`Ordering` object for the column. .. py:method:: __invert__() @@ -1326,6 +1330,10 @@ Query-builder Represent ordering by a column-like object. + Postgresql supports a non-standard clause ("NULLS FIRST/LAST"). Peewee will + automatically use an equivalent ``CASE`` statement for databases that do + not support this (Sqlite / MySQL). + .. py:method:: collate([collation=None]) :param str collation: Collation name to use for sorting. @@ -1661,7 +1669,7 @@ Query-builder :param str action: Action to take when resolving conflict. :param update: A dictionary mapping column to new value. - :param preserve: A list of columns whose values should be preserved from the original INSERT. + :param preserve: A list of columns whose values should be preserved from the original INSERT. See also :py:class:`EXCLUDED`. :param where: Expression to restrict the conflict resolution. :param conflict_target: Column(s) that comprise the constraint. :param conflict_where: Expressions needed to match the constraint target if it is a partial index (index with a WHERE clause). @@ -1705,6 +1713,44 @@ Query-builder conflict resolution. Currently only supported by Postgres. +.. py:class:: EXCLUDED + + Helper object that exposes the ``EXCLUDED`` namespace that is used with + ``INSERT ... ON CONFLICT`` to reference values in the conflicting data. + This is a "magic" helper, such that one uses it by accessing attributes on + it that correspond to a particular column. + + Example: + + .. code-block:: python + + class KV(Model): + key = CharField(unique=True) + value = IntegerField() + + # Create one row. + KV.create(key='k1', value=1) + + # Demonstrate usage of EXCLUDED. + # Here we will attempt to insert a new value for a given key. If that + # key already exists, then we will update its value with the *sum* of its + # original value and the value we attempted to insert -- provided that + # the new value is larger than the original value. + query = (KV.insert(key='k1', value=10) + .on_conflict(conflict_target=[KV.key], + update={KV.value: KV.value + EXCLUDED.value}, + where=(EXCLUDED.value > KV.value))) + + # Executing the above query will result in the following data being + # present in the "kv" table: + # (key='k1', value=11) + query.execute() + + # If we attempted to execute the query *again*, then nothing would be + # updated, as the new value (10) is now less than the value in the + # original row (11). + + .. py:class:: BaseQuery() The parent class from which all other query classes are derived. While you @@ -2431,7 +2477,7 @@ Query-builder Specify the parameters for an :py:class:`OnConflict` clause to use for conflict resolution. - Example: + Examples: .. code-block:: python @@ -2456,6 +2502,36 @@ Query-builder .execute()) return userid + Example using the special :py:class:`EXCLUDED` namespace: + + .. code-block:: python + + class KV(Model): + key = CharField(unique=True) + value = IntegerField() + + # Create one row. + KV.create(key='k1', value=1) + + # Demonstrate usage of EXCLUDED. + # Here we will attempt to insert a new value for a given key. If that + # key already exists, then we will update its value with the *sum* of its + # original value and the value we attempted to insert -- provided that + # the new value is larger than the original value. + query = (KV.insert(key='k1', value=10) + .on_conflict(conflict_target=[KV.key], + update={KV.value: KV.value + EXCLUDED.value}, + where=(EXCLUDED.value > KV.value))) + + # Executing the above query will result in the following data being + # present in the "kv" table: + # (key='k1', value=11) + query.execute() + + # If we attempted to execute the query *again*, then nothing would be + # updated, as the new value (10) is now less than the value in the + # original row (11). + .. py:class:: Delete() @@ -3068,6 +3144,33 @@ Fields ``Husband.wife`` is automatically resolved and turned into a regular :py:class:`ForeignKeyField`. + .. warning:: + :py:class:`DeferredForeignKey` references are resolved when model + classes are declared and created. This means that if you declare a + :py:class:`DeferredForeignKey` to a model class that has already been + imported and created, the deferred foreign key instance will never be + resolved. For example: + + .. code-block:: python + + class User(Model): + username = TextField() + + class Tweet(Model): + # This will never actually be resolved, because the User + # model has already been declared. + user = DeferredForeignKey('user', backref='tweets') + content = TextField() + + In cases like these you should use the regular + :py:class:`ForeignKeyField` *or* you can manually resolve deferred + foreign keys like so: + + .. code-block:: python + + # Tweet.user will be resolved into a ForeignKeyField: + DeferredForeignKey.resolve(User) + .. py:class:: ManyToManyField(model[, backref=None[, through_model=None[, on_delete=None[, on_update=None]]]]) :param Model model: Model to create relationship with. @@ -3499,6 +3602,27 @@ Model Traverse the model graph and return a list of 3-tuples, consisting of ``(foreign key field, model class, is_backref)``. + .. py:method:: set_database(database) + + :param Database database: database object to bind Model to. + + Bind the model class to the given :py:class:`Database` instance. + + .. warning:: + This API should not need to be used. Instead, to change a + :py:class:`Model` database at run-time, use one of the following: + + * :py:meth:`Model.bind` + * :py:meth:`Model.bind_ctx` (bind for scope of a context manager). + * :py:meth:`Database.bind` + * :py:meth:`Database.bind_ctx` + + .. py:method:: set_table_name(table_name) + + :param str table_name: table name to bind Model to. + + Bind the model class to the given table name at run-time. + .. py:class:: SubclassAwareMetadata diff --git a/docs/peewee/models.rst b/docs/peewee/models.rst index b371712f4..fa7527e69 100644 --- a/docs/peewee/models.rst +++ b/docs/peewee/models.rst @@ -110,7 +110,8 @@ an example: In the above example, because none of the fields are initialized with ``primary_key=True``, an auto-incrementing primary key will automatically be -created and named "id". +created and named "id". Peewee uses :py:class:`AutoField` to signify an +auto-incrementing integer primary key, which implies ``primary_key=True``. There is one special type of field, :py:class:`ForeignKeyField`, which allows you to represent foreign-key relationships between models in an intuitive way: @@ -149,11 +150,11 @@ Field types table ===================== ================= ================= ================= Field Type Sqlite Postgresql MySQL ===================== ================= ================= ================= +``AutoField`` integer serial integer +``BigAutoField`` integer bigserial bigint ``IntegerField`` integer integer integer ``BigIntegerField`` integer bigint bigint ``SmallIntegerField`` integer smallint smallint -``AutoField`` integer serial integer -``BigAutoField`` integer bigserial bigint ``IdentityField`` not supported int identity not supported ``FloatField`` real real real ``DoubleField`` real double precision double precision @@ -1010,8 +1011,76 @@ You can also implement ``CHECK`` constraints at the table level: .. _non_integer_primary_keys: -Non-integer Primary Keys, Composite Keys and other Tricks ---------------------------------------------------------- +Primary Keys, Composite Keys and other Tricks +--------------------------------------------- + +The :py:class:`AutoField` is used to identify an auto-incrementing integer +primary key. If you do not specify a primary key, Peewee will automatically +create an auto-incrementing primary key named "id". + +To specify an auto-incrementing ID using a different field name, you can write: + +.. code-block:: python + + class Event(Model): + event_id = AutoField() # Event.event_id will be auto-incrementing PK. + name = CharField() + timestamp = DateTimeField(default=datetime.datetime.now) + metadata = BlobField() + +You can identify a different field as the primary key, in which case an "id" +column will not be created. In this example we will use a person's email +address as the primary key: + +.. code-block:: python + + class Person(Model): + email = CharField(primary_key=True) + name = TextField() + dob = DateField() + +.. warning:: + I frequently see people write the following, expecting an auto-incrementing + integer primary key: + + .. code-block:: python + + class MyModel(Model): + id = IntegerField(primary_key=True) + + Peewee understands the above model declaration as a model with an integer + primary key, but the value of that ID is determined by the application. To + create an auto-incrementing integer primary key, you would instead write: + + .. code-block:: python + + class MyModel(Model): + id = AutoField() # primary_key=True is implied. + +Composite primary keys can be declared using :py:class:`CompositeKey`. Note +that doing this may cause issues with :py:class:`ForeignKeyField`, as Peewee +does not support the concept of a "composite foreign-key". As such, I've found +it only advisable to use composite primary keys in a handful of situations, +such as trivial many-to-many junction tables: + +.. code-block:: python + + class Image(Model): + filename = TextField() + mimetype = CharField() + + class Tag(Model): + label = CharField() + + class ImageTag(Model): # Many-to-many relationship. + image = ForeignKeyField(Image) + tag = ForeignKeyField(Tag) + + class Meta: + primary_key = CompositeKey('image', 'tag') + +In the extremely rare case you wish to declare a model with *no* primary key, +you can specify ``primary_key = False`` in the model ``Meta`` options. Non-integer primary keys ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/peewee/playhouse.rst b/docs/peewee/playhouse.rst index 0566a1251..bfca0d80b 100644 --- a/docs/peewee/playhouse.rst +++ b/docs/peewee/playhouse.rst @@ -3072,6 +3072,25 @@ including :ref:`dataset` and :ref:`pwiz`. Generate models for the tables in the given database. For an example of how to use this function, see the section :ref:`interactive`. + Example: + + .. code-block:: pycon + + >>> from peewee import * + >>> from playhouse.reflection import generate_models + >>> db = PostgresqlDatabase('my_app') + >>> models = generate_models(db) + >>> list(models.keys()) + ['account', 'customer', 'order', 'orderitem', 'product'] + + >>> globals().update(models) # Inject models into namespace. + >>> for cust in customer.select(): # Query using generated model. + ... print(cust.name) + ... + + Huey Kitty + Mickey Dog + .. py:function:: print_model(model) :param Model model: model class to print @@ -3081,6 +3100,34 @@ including :ref:`dataset` and :ref:`pwiz`. interactive use. Currently this prints the table name, and all fields along with their data-types. The :ref:`interactive` section contains an example. + Example output: + + .. code-block:: pycon + + >>> from playhouse.reflection import print_model + >>> print_model(User) + user + id AUTO PK + email TEXT + name TEXT + dob DATE + + index(es) + email UNIQUE + + >>> print_model(Tweet) + tweet + id AUTO PK + user INT FK: User.id + title TEXT + content TEXT + timestamp DATETIME + is_published BOOL + + index(es) + user_id + is_published, timestamp + .. py:function:: print_table_sql(model) :param Model model: model to print @@ -3088,7 +3135,32 @@ including :ref:`dataset` and :ref:`pwiz`. Prints the SQL ``CREATE TABLE`` for the given model class, which may be useful for debugging or interactive use. See the :ref:`interactive` section - for example usage. + for example usage. Note that indexes and constraints are not included in + the output of this function. + + Example output: + + .. code-block:: pycon + + >>> from playhouse.reflection import print_table_sql + >>> print_table_sql(User) + CREATE TABLE IF NOT EXISTS "user" ( + "id" INTEGER NOT NULL PRIMARY KEY, + "email" TEXT NOT NULL, + "name" TEXT NOT NULL, + "dob" DATE NOT NULL + ) + + >>> print_table_sql(Tweet) + CREATE TABLE IF NOT EXISTS "tweet" ( + "id" INTEGER NOT NULL PRIMARY KEY, + "user_id" INTEGER NOT NULL, + "title" TEXT NOT NULL, + "content" TEXT NOT NULL, + "timestamp" DATETIME NOT NULL, + "is_published" INTEGER NOT NULL, + FOREIGN KEY ("user_id") REFERENCES "user" ("id") + ) .. py:class:: Introspector(metadata[, schema=None]) diff --git a/docs/peewee/querying.rst b/docs/peewee/querying.rst index 0ddca0e28..f7d12e3fe 100644 --- a/docs/peewee/querying.rst +++ b/docs/peewee/querying.rst @@ -433,6 +433,39 @@ column will be updated, and no duplicate rows will be created. The main difference between MySQL and Postgresql/SQLite is that Postgresql and SQLite require that you specify a ``conflict_target``. +Here is a more advanced (if contrived) example using the :py:class:`EXCLUDED` +namespace. The :py:class:`EXCLUDED` helper allows us to reference values in the +conflicting data. For our example, we'll assume a simple table mapping a unique +key (string) to a value (integer): + +.. code-block:: python + + class KV(Model): + key = CharField(unique=True) + value = IntegerField() + + # Create one row. + KV.create(key='k1', value=1) + + # Demonstrate usage of EXCLUDED. + # Here we will attempt to insert a new value for a given key. If that + # key already exists, then we will update its value with the *sum* of its + # original value and the value we attempted to insert -- provided that + # the new value is larger than the original value. + query = (KV.insert(key='k1', value=10) + .on_conflict(conflict_target=[KV.key], + update={KV.value: KV.value + EXCLUDED.value}, + where=(EXCLUDED.value > KV.value))) + + # Executing the above query will result in the following data being + # present in the "kv" table: + # (key='k1', value=11) + query.execute() + + # If we attempted to execute the query *again*, then nothing would be + # updated, as the new value (10) is now less than the value in the + # original row (11). + For more information, see :py:meth:`Insert.on_conflict` and :py:class:`OnConflict`. diff --git a/examples/hexastore.py b/examples/hexastore.py new file mode 100644 index 000000000..824a306b2 --- /dev/null +++ b/examples/hexastore.py @@ -0,0 +1,168 @@ +try: + from functools import reduce +except ImportError: + pass +import operator + +from peewee import * + + +class Hexastore(object): + def __init__(self, database=':memory:', **options): + if isinstance(database, str): + self.db = SqliteDatabase(database, **options) + elif isinstance(database, Database): + self.db = database + else: + raise ValueError('Expected database filename or a Database ' + 'instance. Got: %s' % repr(database)) + + self.v = _VariableFactory() + self.G = self.get_model() + + def get_model(self): + class Graph(Model): + subj = TextField() + pred = TextField() + obj = TextField() + class Meta: + database = self.db + indexes = ( + (('pred', 'obj'), False), + (('obj', 'subj'), False), + ) + primary_key = CompositeKey('subj', 'pred', 'obj') + + self.db.create_tables([Graph]) + return Graph + + def store(self, s, p, o): + self.G.create(subj=s, pred=p, obj=o) + + def store_many(self, items): + fields = [self.G.subj, self.G.pred, self.G.obj] + self.G.insert_many(items, fields=fields).execute() + + def delete(self, s, p, o): + return (self.G.delete() + .where(self.G.subj == s, self.G.pred == p, self.G.obj == o) + .execute()) + + def query(self, s=None, p=None, o=None): + fields = (self.G.subj, self.G.pred, self.G.obj) + expressions = [(f == v) for f, v in zip(fields, (s, p, o)) + if v is not None] + return self.G.select().where(*expressions) + + def search(self, *conditions): + accum = [] + binds = {} + variables = set() + fields = {'s': 'subj', 'p': 'pred', 'o': 'obj'} + + for i, condition in enumerate(conditions): + if isinstance(condition, dict): + condition = (condition['s'], condition['p'], condition['o']) + + GA = self.G.alias('g%s' % i) + for part, val in zip('spo', condition): + if isinstance(val, Variable): + binds.setdefault(val, []) + binds[val].append(getattr(GA, fields[part])) + variables.add(val) + else: + accum.append(getattr(GA, fields[part]) == val) + + selection = [] + sources = set() + + for var, fields in binds.items(): + selection.append(fields[0].alias(var.name)) + pairwise = [(fields[i - 1] == fields[i]) + for i in range(1, len(fields))] + if pairwise: + accum.append(reduce(operator.and_, pairwise)) + sources.update([field.source for field in fields]) + + return (self.G + .select(*selection) + .from_(*list(sources)) + .where(*accum) + .dicts()) + + +class _VariableFactory(object): + def __getattr__(self, name): + return Variable(name) + __call__ = __getattr__ + +class Variable(object): + __slots__ = ('name',) + + def __init__(self, name): + self.name = name + + def __hash__(self): + return hash(self.name) + + def __repr__(self): + return '' % self.name + + +if __name__ == '__main__': + h = Hexastore() + + data = ( + ('charlie', 'likes', 'beanie'), + ('charlie', 'likes', 'huey'), + ('charlie', 'likes', 'mickey'), + ('charlie', 'likes', 'scout'), + ('charlie', 'likes', 'zaizee'), + + ('huey', 'likes', 'charlie'), + ('huey', 'likes', 'scout'), + ('huey', 'likes', 'zaizee'), + + ('mickey', 'likes', 'beanie'), + ('mickey', 'likes', 'charlie'), + ('mickey', 'likes', 'scout'), + + ('zaizee', 'likes', 'beanie'), + ('zaizee', 'likes', 'charlie'), + ('zaizee', 'likes', 'scout'), + + ('charlie', 'lives', 'topeka'), + ('beanie', 'lives', 'heaven'), + ('huey', 'lives', 'topeka'), + ('mickey', 'lives', 'topeka'), + ('scout', 'lives', 'heaven'), + ('zaizee', 'lives', 'lawrence'), + ) + h.store_many(data) + print('added %s items to store' % len(data)) + + print('\nwho lives in topeka?') + for obj in h.query(p='lives', o='topeka'): + print(obj.subj) + + print('\nmy friends in heaven?') + X = h.v.x + results = h.search(('charlie', 'likes', X), + (X, 'lives', 'heaven')) + for result in results: + print(result['x']) + + print('\nmutual friends?') + X = h.v.x + Y = h.v.y + results = h.search((X, 'likes', Y), (Y, 'likes', X)) + for result in results: + print(result['x'], ' <-> ', result['y']) + + print('\nliked by both charlie, huey and mickey?') + X = h.v.x + results = h.search(('charlie', 'likes', X), + ('huey', 'likes', X), + ('mickey', 'likes', X)) + for result in results: + print(result['x']) diff --git a/peewee.py b/peewee.py index 2fb246d50..c1ea50205 100644 --- a/peewee.py +++ b/peewee.py @@ -61,7 +61,7 @@ mysql = None -__version__ = '3.9.2' +__version__ = '3.9.3' __all__ = [ 'AsIs', 'AutoField', @@ -94,6 +94,7 @@ 'DoesNotExist', 'DoubleField', 'DQ', + 'EXCLUDED', 'Field', 'FixedCharField', 'FloatField', @@ -455,6 +456,8 @@ def savepoint(self): class AliasManager(object): + __slots__ = ('_counter', '_current_index', '_mapping') + def __init__(self): # A list of dictionaries containing mappings at various depths. self._counter = 0 @@ -1313,15 +1316,30 @@ def __init__(self, node, direction, collation=None, nulls=None): self.direction = direction self.collation = collation self.nulls = nulls + if nulls and nulls.lower() not in ('first', 'last'): + raise ValueError('Ordering nulls= parameter must be "first" or ' + '"last", got: %s' % nulls) def collate(self, collation=None): return Ordering(self.node, self.direction, collation) + def _null_ordering_case(self, nulls): + if nulls.lower() == 'last': + ifnull, notnull = 1, 0 + elif nulls.lower() == 'first': + ifnull, notnull = 0, 1 + else: + raise ValueError('unsupported value for nulls= ordering.') + return Case(None, ((self.node.is_null(), ifnull),), notnull) + def __sql__(self, ctx): + if self.nulls and not ctx.state.nulls_ordering: + ctx.sql(self._null_ordering_case(self.nulls)).literal(', ') + ctx.sql(self.node).literal(' %s' % self.direction) if self.collation: ctx.literal(' COLLATE %s' % self.collation) - if self.nulls: + if self.nulls and ctx.state.nulls_ordering: ctx.literal(' NULLS %s' % self.nulls) return ctx @@ -1594,6 +1612,27 @@ def EnclosedNodeList(nodes): return NodeList(nodes, ', ', True) +class _Namespace(Node): + __slots__ = ('_name',) + def __init__(self, name): + self._name = name + def __getattr__(self, attr): + return NamespaceAttribute(self, attr) + __getitem__ = __getattr__ + +class NamespaceAttribute(ColumnBase): + def __init__(self, namespace, attribute): + self._namespace = namespace + self._attribute = attribute + + def __sql__(self, ctx): + return (ctx + .literal(self._namespace._name + '.') + .sql(Entity(self._attribute))) + +EXCLUDED = _Namespace('EXCLUDED') + + class DQ(ColumnBase): def __init__(self, **query): super(DQ, self).__init__() @@ -2700,6 +2739,7 @@ class Database(_callable_context_manager): for_update = False index_schema_prefix = False limit_max = None + nulls_ordering = False returning_clause = False safe_create_index = True safe_drop_index = True @@ -2859,6 +2899,7 @@ def get_context_options(self): 'for_update': self.for_update, 'index_schema_prefix': self.index_schema_prefix, 'limit_max': self.limit_max, + 'nulls_ordering': self.nulls_ordering, } def get_sql_context(self, **context_options): @@ -3410,6 +3451,7 @@ class PostgresqlDatabase(Database): commit_select = True compound_select_parentheses = CSQ_PARENTHESES_ALWAYS for_update = True + nulls_ordering = True returning_clause = True safe_create_index = False sequences = True @@ -4769,7 +4811,9 @@ def __init__(self, rel_model_name, **kwargs): self.field_kwargs = kwargs self.rel_model_name = rel_model_name.lower() DeferredForeignKey._unresolved.add(self) - super(DeferredForeignKey, self).__init__() + super(DeferredForeignKey, self).__init__( + column_name=kwargs.get('column_name'), + null=kwargs.get('null')) __hash__ = object.__hash__ @@ -4782,7 +4826,8 @@ def set_model(self, rel_model): @staticmethod def resolve(model_cls): - unresolved = list(DeferredForeignKey._unresolved) + unresolved = sorted(DeferredForeignKey._unresolved, + key=operator.attrgetter('_order')) for dr in unresolved: if dr.rel_model_name == model_cls.__name__.lower(): dr.set_model(model_cls) @@ -4912,12 +4957,11 @@ class Meta: ((lhs._meta.name, rhs._meta.name), True),) + params = {'on_delete': self._on_delete, 'on_update': self._on_update} attrs = { - lhs._meta.name: ForeignKeyField(lhs, on_delete=self._on_delete, - on_update=self._on_update), - rhs._meta.name: ForeignKeyField(rhs, on_delete=self._on_delete, - on_update=self._on_update)} - attrs['Meta'] = Meta + lhs._meta.name: ForeignKeyField(lhs, **params), + rhs._meta.name: ForeignKeyField(rhs, **params), + 'Meta': Meta} klass_name = '%s%sThrough' % (lhs.__name__, rhs.__name__) return type(klass_name, (Model,), attrs) @@ -5512,6 +5556,10 @@ def set_database(self, database): self.model._schema._database = database del self.table + def set_table_name(self, table_name): + self.table_name = table_name + del self.table + class SubclassAwareMetadata(Metadata): models = [] diff --git a/playhouse/dataset.py b/playhouse/dataset.py index 029151ba4..27f8189bb 100644 --- a/playhouse/dataset.py +++ b/playhouse/dataset.py @@ -98,6 +98,7 @@ def update_cache(self, table=None): dependencies.extend(self.get_table_dependencies(table)) else: dependencies = None # Update all tables. + self._models = {} updated = self._introspector.generate_models( skip_invalid=True, table_names=dependencies, diff --git a/playhouse/reflection.py b/playhouse/reflection.py index dd6de09b6..8f83a8ede 100644 --- a/playhouse/reflection.py +++ b/playhouse/reflection.py @@ -748,9 +748,13 @@ def print_model(model, indexes=True, inline_indexes=False): field.rel_field.name)) print(''.join(parts)) - if indexes and model._meta.indexes: + if indexes: + index_list = model._meta.fields_to_index() + if not index_list: + return + print('\nindex(es)') - for index in model._meta.fields_to_index(): + for index in index_list: parts = [' '] ctx = model._meta.database.get_sql_context() with ctx.scope_values(param='%s', quote='""'): diff --git a/tests/base.py b/tests/base.py index 1cf28d1c4..6ae205b6e 100644 --- a/tests/base.py +++ b/tests/base.py @@ -228,22 +228,17 @@ def requires_models(*models): def decorator(method): @wraps(method) def inner(self): - _db_mapping = {} - for model in models: - _db_mapping[model] = model._meta.database - model._meta.set_database(self.database) - self.database.drop_tables(models, safe=True) - self.database.create_tables(models) + with self.database.bind_ctx(models, False, False): + self.database.drop_tables(models, safe=True) + self.database.create_tables(models) - try: - method(self) - finally: try: - self.database.drop_tables(models) - except: - pass - for model in models: - model._meta.set_database(_db_mapping[model]) + method(self) + finally: + try: + self.database.drop_tables(models) + except: + pass return inner return decorator diff --git a/tests/database.py b/tests/database.py index bbdf65ae4..8372ea8f0 100644 --- a/tests/database.py +++ b/tests/database.py @@ -147,6 +147,24 @@ def test_connection_state(self): conn = self.database.connection() self.assertFalse(self.database.is_closed()) + def test_db_context_manager(self): + self.database.close() + self.assertTrue(self.database.is_closed()) + + with self.database: + self.assertFalse(self.database.is_closed()) + + self.assertTrue(self.database.is_closed()) + self.database.connect() + self.assertFalse(self.database.is_closed()) + + # Enter context with an already-open db. + with self.database: + self.assertFalse(self.database.is_closed()) + + # Closed after exit. + self.assertTrue(self.database.is_closed()) + def test_connection_initialization(self): state = {'count': 0} class TestDatabase(SqliteDatabase): diff --git a/tests/dataset.py b/tests/dataset.py index e4131bcd9..39952166d 100644 --- a/tests/dataset.py +++ b/tests/dataset.py @@ -125,6 +125,15 @@ def test_update_cache(self): self.assertEqual(sorted(Foo.columns), ['data', 'id']) self.assertTrue('foo' in self.dataset._models) + self.dataset._models['foo'].drop_table() + self.dataset.update_cache() + self.assertTrue('foo' not in self.database.get_tables()) + + # This will create the table again. + Foo = self.dataset['foo'] + self.assertTrue('foo' in self.database.get_tables()) + self.assertEqual(Foo.columns, ['id']) + def assertQuery(self, query, expected, sort_key='id'): key = operator.itemgetter(sort_key) self.assertEqual( diff --git a/tests/fields.py b/tests/fields.py index 7720e7b1e..c8d4fe909 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -371,6 +371,68 @@ def test_deferred_foreign_key(self): self.assertEqual(m2_db.m1.name, 'm1') +class TestDeferredForeignKeyResolution(ModelTestCase): + def test_unresolved_deferred_fk(self): + class Photo(Model): + album = DeferredForeignKey('Album', column_name='id_album') + class Meta: + database = get_in_memory_db() + self.assertSQL(Photo.select(), ( + 'SELECT "t1"."id", "t1"."id_album" FROM "photo" AS "t1"'), []) + + def test_deferred_foreign_key_resolution(self): + class Base(Model): + class Meta: + database = get_in_memory_db() + + class Photo(Base): + album = DeferredForeignKey('Album', column_name='id_album', + null=False, backref='pictures') + alt_album = DeferredForeignKey('Album', column_name='id_Alt_album', + field='alt_id', backref='alt_pix', + null=True) + + class Album(Base): + name = TextField() + alt_id = IntegerField(column_name='_Alt_id') + + self.assertTrue(Photo.album.rel_model is Album) + self.assertTrue(Photo.album.rel_field is Album.id) + self.assertEqual(Photo.album.column_name, 'id_album') + self.assertFalse(Photo.album.null) + + self.assertTrue(Photo.alt_album.rel_model is Album) + self.assertTrue(Photo.alt_album.rel_field is Album.alt_id) + self.assertEqual(Photo.alt_album.column_name, 'id_Alt_album') + self.assertTrue(Photo.alt_album.null) + + self.assertSQL(Photo._schema._create_table(), ( + 'CREATE TABLE IF NOT EXISTS "photo" (' + '"id" INTEGER NOT NULL PRIMARY KEY, ' + '"id_album" INTEGER NOT NULL, ' + '"id_Alt_album" INTEGER)'), []) + + self.assertSQL(Photo._schema._create_foreign_key(Photo.album), ( + 'ALTER TABLE "photo" ADD CONSTRAINT "fk_photo_id_album_refs_album"' + ' FOREIGN KEY ("id_album") REFERENCES "album" ("id")')) + self.assertSQL(Photo._schema._create_foreign_key(Photo.alt_album), ( + 'ALTER TABLE "photo" ADD CONSTRAINT ' + '"fk_photo_id_Alt_album_refs_album"' + ' FOREIGN KEY ("id_Alt_album") REFERENCES "album" ("_Alt_id")')) + + self.assertSQL(Photo.select(), ( + 'SELECT "t1"."id", "t1"."id_album", "t1"."id_Alt_album" ' + 'FROM "photo" AS "t1"'), []) + + a = Album(id=3, alt_id=4) + self.assertSQL(a.pictures, ( + 'SELECT "t1"."id", "t1"."id_album", "t1"."id_Alt_album" ' + 'FROM "photo" AS "t1" WHERE ("t1"."id_album" = ?)'), [3]) + self.assertSQL(a.alt_pix, ( + 'SELECT "t1"."id", "t1"."id_album", "t1"."id_Alt_album" ' + 'FROM "photo" AS "t1" WHERE ("t1"."id_Alt_album" = ?)'), [4]) + + class Composite(TestModel): first = CharField() last = CharField() @@ -488,6 +550,9 @@ def test_bit_field(self): query = Bits.select().where(Bits.is_favorite).order_by(Bits.id) self.assertEqual([x.id for x in query], [b2.id, b3.id]) + query = Bits.select().where(~Bits.is_favorite).order_by(Bits.id) + self.assertEqual([x.id for x in query], [b1.id]) + # "&" operator does bitwise and for BitField. query = Bits.select().where((Bits.flags & 1) == 1).order_by(Bits.id) self.assertEqual([x.id for x in query], [b1.id, b3.id]) diff --git a/tests/kv.py b/tests/kv.py index 534118e0a..f2ccf6a68 100644 --- a/tests/kv.py +++ b/tests/kv.py @@ -2,9 +2,7 @@ from playhouse.kv import KeyValue from .base import DatabaseTestCase -from .base import IS_POSTGRESQL from .base import db -from .base import skip_if class TestKeyValue(DatabaseTestCase): diff --git a/tests/models.py b/tests/models.py index 50600b545..202457fdb 100644 --- a/tests/models.py +++ b/tests/models.py @@ -2969,6 +2969,11 @@ def requires_upsert(m): return skip_unless(IS_SQLITE_24 or IS_POSTGRESQL, 'requires upsert')(m) +class KV(TestModel): + key = CharField(unique=True) + value = IntegerField() + + class PGOnConflictTests(OnConflictTests): @requires_upsert def test_update(self): @@ -3065,6 +3070,38 @@ def test_update_where_clause(self): self.assertEqual(obj.a, 'foo') self.assertEqual(obj.b, 3) + @requires_upsert + @requires_models(Emp) # Has unique on first/last, unique on empno. + def test_conflict_update_excluded(self): + e1 = Emp.create(first='huey', last='c', empno='10') + e2 = Emp.create(first='zaizee', last='c', empno='20') + + res = (Emp.insert(first='huey', last='c', empno='30') + .on_conflict(conflict_target=(Emp.first, Emp.last), + update={Emp.empno: Emp.empno + EXCLUDED.empno}, + where=(EXCLUDED.empno != Emp.empno)) + .execute()) + + data = sorted(Emp.select(Emp.first, Emp.last, Emp.empno).tuples()) + self.assertEqual(data, [('huey', 'c', '1030'), ('zaizee', 'c', '20')]) + + @requires_upsert + @requires_models(KV) + def test_conflict_update_excluded2(self): + KV.create(key='k1', value=1) + + query = (KV.insert(key='k1', value=10) + .on_conflict(conflict_target=[KV.key], + update={KV.value: KV.value + EXCLUDED.value}, + where=(EXCLUDED.value > KV.value))) + query.execute() + self.assertEqual(KV.select(KV.key, KV.value).tuples()[:], [('k1', 11)]) + + # Running it again will have no effect this time, since the new value + # (10) is not greater than the pre-existing row value (11). + query.execute() + self.assertEqual(KV.select(KV.key, KV.value).tuples()[:], [('k1', 11)]) + @requires_upsert @requires_models(UKVP) def test_conflict_target_constraint_where(self): @@ -3932,3 +3969,37 @@ def test_execute_query(self): query = User.select().order_by(User.username.desc()) cursor = self.database.execute(query) self.assertEqual([row[1] for row in cursor], ['zaizee', 'huey']) + + +class Datum(TestModel): + key = TextField() + value = IntegerField(null=True) + +class TestNullOrdering(ModelTestCase): + requires = [Datum] + + def test_null_ordering(self): + values = [('k1', 1), ('ka', None), ('k2', 2), ('kb', None)] + Datum.insert_many(values, fields=[Datum.key, Datum.value]).execute() + + def assertOrder(ordering, expected): + query = Datum.select().order_by(*ordering) + self.assertEqual([d.key for d in query], expected) + + # Ascending order. + nulls_last = (Datum.value.asc(nulls='last'), Datum.key) + assertOrder(nulls_last, ['k1', 'k2', 'ka', 'kb']) + + nulls_first = (Datum.value.asc(nulls='first'), Datum.key) + assertOrder(nulls_first, ['ka', 'kb', 'k1', 'k2']) + + # Descending order. + nulls_last = (Datum.value.desc(nulls='last'), Datum.key) + assertOrder(nulls_last, ['k2', 'k1', 'ka', 'kb']) + + nulls_first = (Datum.value.desc(nulls='first'), Datum.key) + assertOrder(nulls_first, ['ka', 'kb', 'k2', 'k1']) + + # Invalid values. + self.assertRaises(ValueError, Datum.value.desc, nulls='bar') + self.assertRaises(ValueError, Datum.value.asc, nulls='foo') diff --git a/tests/pool.py b/tests/pool.py index 163b27a96..886381fda 100644 --- a/tests/pool.py +++ b/tests/pool.py @@ -8,8 +8,13 @@ from peewee import _transaction from playhouse.pool import * +from .base import BACKEND from .base import BaseTestCase +from .base import IS_MYSQL +from .base import IS_POSTGRESQL +from .base import IS_SQLITE from .base import ModelTestCase +from .base import db_loader from .base_models import Register @@ -355,3 +360,122 @@ def test_bad_connection(self): pass self.database.close() self.database.connect() + + +class TestPooledDatabaseIntegration(ModelTestCase): + requires = [Register] + + def setUp(self): + params = {} + if IS_MYSQL: + db_class = PooledMySQLDatabase + elif IS_POSTGRESQL: + db_class = PooledPostgresqlDatabase + else: + db_class = PooledSqliteDatabase + params['check_same_thread'] = False + self.database = db_loader(BACKEND, db_class=db_class, **params) + super(TestPooledDatabaseIntegration, self).setUp() + + def assertConnections(self, expected): + available = len(self.database._connections) + in_use = len(self.database._in_use) + self.assertEqual(available + in_use, expected, + 'expected %s, got: %s available, %s in use' + % (expected, available, in_use)) + + def test_pooled_database_integration(self): + # Connection should be open from the setup method. + self.assertFalse(self.database.is_closed()) + self.assertConnections(1) + self.assertTrue(self.database.close()) + self.assertTrue(self.database.is_closed()) + self.assertConnections(1) + + signal = threading.Event() + def connect(): + self.assertTrue(self.database.is_closed()) + self.assertTrue(self.database.connect()) + self.assertFalse(self.database.is_closed()) + signal.wait() + self.assertTrue(self.database.close()) + self.assertTrue(self.database.is_closed()) + + # Open connections in 4 separate threads. + threads = [threading.Thread(target=connect) for _ in range(4)] + for t in threads: t.start() + + while len(self.database._in_use) < 4: + time.sleep(.005) + + # Close connections in all 4 threads. + signal.set() + for t in threads: t.join() + + # Verify that there are 4 connections available in the pool. + self.assertConnections(4) + self.assertEqual(len(self.database._connections), 4) # Available. + self.assertEqual(len(self.database._in_use), 0) + + # Verify state of the main thread, just a sanity check. + self.assertTrue(self.database.is_closed()) + + # Opening a connection will pull from the pool. + self.assertTrue(self.database.connect()) + self.assertFalse(self.database.connect(reuse_if_open=True)) + self.assertConnections(4) + self.assertEqual(len(self.database._in_use), 1) + + # Calling close_all() closes everything, including calling thread. + self.database.close_all() + self.assertConnections(0) + self.assertTrue(self.database.is_closed()) + + def test_pool_with_models(self): + self.database.close() + signal = threading.Event() + + def create_obj(i): + with self.database.connection_context(): + with self.database.atomic(): + Register.create(value=i) + signal.wait() + + # Create 4 objects, one in each thread. The INSERT will be wrapped in a + # transaction, and after COMMIT (but while the conn is still open), we + # will wait for the signal that all objects were created. This ensures + # that all our connections are open concurrently. + threads = [threading.Thread(target=create_obj, args=(i,)) + for i in range(4)] + for t in threads: t.start() + + # Explicitly connect, as the connection is required to verify that all + # the objects are present (and that its safe to set the signal). + self.assertTrue(self.database.connect()) + while Register.select().count() != 4: + time.sleep(0.005) + + # Signal threads that they can exit now and ensure all exited. + signal.set() + for t in threads: t.join() + + # Close connection from main thread as well. + self.database.close() + + self.assertConnections(5) + self.assertEqual(len(self.database._in_use), 0) + + # Cycle through the available connections, running a query on each, and + # then manually closing it. + for i in range(5): + self.assertTrue(self.database.is_closed()) + self.assertTrue(self.database.connect()) + + # Sanity check to verify objects are created. + query = Register.select().order_by(Register.value) + self.assertEqual([r.value for r in query], [0, 1, 2, 3]) + self.database.manual_close() + self.assertConnections(4 - i) + + self.assertConnections(0) + self.assertEqual(len(self.database._in_use), 0) diff --git a/tests/schema.py b/tests/schema.py index 9378ae6d8..d8edd55f0 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -1,6 +1,7 @@ from peewee import * from peewee import NodeList +from .base import BaseTestCase from .base import get_in_memory_db from .base import ModelDatabaseTestCase from .base import TestModel @@ -603,3 +604,20 @@ class Meta: '"id" INT GENERATED BY DEFAULT AS IDENTITY NOT NULL PRIMARY KEY, ' '"data" TEXT NOT NULL)'), ]) + + +class TestModelSetTableName(BaseTestCase): + def test_set_table_name(self): + class Foo(TestModel): + pass + + self.assertEqual(Foo._meta.table_name, 'foo') + self.assertEqual(Foo._meta.table.__name__, 'foo') + + # Writing the attribute directly does not update the cached Table name. + Foo._meta.table_name = 'foo2' + self.assertEqual(Foo._meta.table.__name__, 'foo') + + # Use the helper-method. + Foo._meta.set_table_name('foo3') + self.assertEqual(Foo._meta.table.__name__, 'foo3') diff --git a/tests/sql.py b/tests/sql.py index c0a268a5c..b094e859f 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -1543,6 +1543,19 @@ def test_conflict_resolution_required(self): with self.assertRaisesCtx(ValueError): self.database.get_sql_context().parse(query) + def test_conflict_update_excluded(self): + KV = Table('kv', ('key', 'value', 'extra'), _database=self.database) + + query = (KV.insert(key='k1', value='v1', extra=1) + .on_conflict(conflict_target=(KV.key, KV.value), + update={KV.extra: EXCLUDED.extra + 2}, + where=(EXCLUDED.extra < KV.extra))) + self.assertSQL(query, ( + 'INSERT INTO "kv" ("extra", "key", "value") VALUES (?, ?, ?) ' + 'ON CONFLICT ("key", "value") DO UPDATE ' + 'SET "extra" = (EXCLUDED."extra" + ?) ' + 'WHERE (EXCLUDED."extra" < "kv"."extra")'), [1, 'k1', 'v1', 2]) + def test_conflict_target_or_constraint(self): KV = Table('kv', ('key', 'value', 'extra'), _database=self.database)