diff --git a/CHANGELOG.md b/CHANGELOG.md index d3f649ef9..817c83719 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,15 @@ releases, visit GitHub: https://github.com/coleifer/peewee/releases +## 3.0.20 + +* Include `schema` (if specified) when checking for table-existence. +* Correct placement of ORDER BY / LIMIT clauses in compound select queries. +* Fix bug in back-reference lookups when using `filter()` API. +* Fix bug in SQL generation for ON CONFLICT queries with Postgres, #1512. + +[View commits](https://github.com/coleifer/peewee/compare/3.0.19...3.0.20) + ## 3.0.19 * Support for more types of mappings in `insert_many()`, refs #1495. diff --git a/peewee.py b/peewee.py index 2254302ec..9aab0bb8b 100644 --- a/peewee.py +++ b/peewee.py @@ -54,7 +54,7 @@ mysql = None -__version__ = '3.0.19' +__version__ = '3.0.20' __all__ = [ 'AsIs', 'AutoField', @@ -1396,6 +1396,12 @@ def clone(self): Tuple = lambda *a: EnclosedNodeList(a) +class QualifiedNames(WrappedNode): + def __sql__(self, ctx): + with ctx.scope_column(): + return ctx.sql(self.node) + + class OnConflict(Node): def __init__(self, action=None, update=None, preserve=None, where=None, conflict_target=None): @@ -1752,8 +1758,9 @@ def __sql__(self, ctx): subquery=False): ctx.sql(self.rhs) - # Apply ORDER BY, LIMIT, OFFSET. - self._apply_ordering(ctx) + # Apply ORDER BY, LIMIT, OFFSET. + self._apply_ordering(ctx) + return self.apply_alias(ctx) @@ -2442,6 +2449,9 @@ def close(self): if self.deferred: raise Exception('Error, database must be initialized before ' 'opening a connection.') + if self.in_transaction(): + raise OperationalError('Attempting to close database while ' + 'transaction is open.') is_open = not self._state.closed try: if is_open: @@ -3058,6 +3068,8 @@ def conflict_update(self, on_conflict): if not isinstance(v, Node): converter = k.db_value if isinstance(k, Field) else None v = Value(v, converter=converter, unpack=False) + else: + v = QualifiedNames(v) updates.append(NodeList((ensure_entity(k), SQL('='), v))) parts = [SQL('ON CONFLICT'), @@ -3065,7 +3077,7 @@ def conflict_update(self, on_conflict): SQL('DO UPDATE SET'), CommaNodeList(updates)] if on_conflict._where: - parts.extend((SQL('WHERE'), on_conflict._where)) + parts.extend((SQL('WHERE'), QualifiedNames(on_conflict._where))) return NodeList(parts) @@ -3515,12 +3527,13 @@ def __set__(self, instance, obj): class BackrefAccessor(object): def __init__(self, field): self.field = field - self.model = field.model + self.model = field.rel_model + self.rel_model = field.model def __get__(self, instance, instance_type=None): if instance is not None: dest = self.field.rel_field.name - return (self.model + return (self.rel_model .select() .where(self.field == getattr(instance, dest))) return self @@ -5266,7 +5279,7 @@ def bind_ctx(cls, database, bind_refs=True, bind_backrefs=True): @classmethod def table_exists(cls): - return cls._meta.database.table_exists(cls._meta.table) + return cls._meta.database.table_exists(cls._meta.table, cls._meta.schema) @classmethod def create_table(cls, safe=True, **options): @@ -5721,8 +5734,8 @@ def filter(self, *args, **kwargs): if isinstance(field, ForeignKeyField): lm, rm = field.model, field.rel_model field_obj = field - elif isinstance(field, ReverseRelationDescriptor): - lm, rm = field.field.rel_model, field.rel_model + elif isinstance(field, BackrefAccessor): + lm, rm = field.model, field.rel_model field_obj = field.field query = query.ensure_join(lm, rm, field_obj) return query.where(dq_node) diff --git a/tests/models.py b/tests/models.py index a4e1323e6..aa006d4c6 100644 --- a/tests/models.py +++ b/tests/models.py @@ -690,6 +690,19 @@ def test_from_multi_table(self): self.assertEqual([t['username'] for t in query], ['huey', 'huey', 'huey']) + @requires_models(User, Tweet) + def test_filtering(self): + self.add_tweets(self.add_user('huey'), 'meow', 'hiss', 'purr') + self.add_tweets(self.add_user('mickey'), 'woof', 'wheeze') + + query = Tweet.filter(user__username='huey').order_by(Tweet.content) + self.assertEqual([row.content for row in query], + ['hiss', 'meow', 'purr']) + + query = User.filter(tweets__content__ilike='w%') + self.assertEqual([user.username for user in query], + ['mickey', 'mickey']) + def test_deferred_fk(self): class Note(TestModel): foo = DeferredForeignKey('Foo', backref='notes') @@ -1538,8 +1551,10 @@ class Admin(BaseUser): self.assertEqual(BaseUser.account.backref, 'baseuser_set') self.assertEqual(User.account.backref, 'user_set') self.assertEqual(Admin.account.backref, 'admin_set') - self.assertTrue(Account.user_set.model is User) - self.assertTrue(Account.admin_set.model is Admin) + self.assertTrue(Account.user_set.model is Account) + self.assertTrue(Account.admin_set.model is Account) + self.assertTrue(Account.user_set.rel_model is User) + self.assertTrue(Account.admin_set.rel_model is Admin) self.assertSQL(Account._schema._create_table(), ( 'CREATE TABLE IF NOT EXISTS "account" (' @@ -1573,9 +1588,12 @@ class Photo(BasePost): pass self.assertEqual(BasePost.category.backref, 'baseposts') self.assertEqual(Note.category.backref, 'notes') self.assertEqual(Photo.category.backref, 'photos') - self.assertTrue(Category.baseposts.model is BasePost) - self.assertTrue(Category.notes.model is Note) - self.assertTrue(Category.photos.model is Photo) + self.assertTrue(Category.baseposts.rel_model is BasePost) + self.assertTrue(Category.baseposts.model is Category) + self.assertTrue(Category.notes.rel_model is Note) + self.assertTrue(Category.notes.model is Category) + self.assertTrue(Category.photos.rel_model is Photo) + self.assertTrue(Category.photos.model is Category) class BaseItem(TestModel): category = ForeignKeyField(Category, backref='items') @@ -1585,9 +1603,11 @@ class ItemB(BaseItem): pass self.assertEqual(BaseItem.category.backref, 'items') self.assertEqual(ItemA.category.backref, 'itema_set') self.assertEqual(ItemB.category.backref, 'itemb_set') - self.assertTrue(Category.items.model is BaseItem) - self.assertTrue(Category.itema_set.model is ItemA) - self.assertTrue(Category.itemb_set.model is ItemB) + self.assertTrue(Category.items.rel_model is BaseItem) + self.assertTrue(Category.itema_set.rel_model is ItemA) + self.assertTrue(Category.itema_set.model is Category) + self.assertTrue(Category.itemb_set.rel_model is ItemB) + self.assertTrue(Category.itemb_set.model is Category) class TestMetaInheritance(BaseTestCase): @@ -1882,6 +1902,11 @@ def test_replace(self): ('nuggie', 'dog', '123')]) +class OCTest(TestModel): + a = CharField(unique=True) + b = IntegerField(default=0) + + @skip_case_unless(IS_POSTGRESQL) class TestUpsertPostgresql(OnConflictTestCase): def test_update(self): @@ -1909,6 +1934,53 @@ def test_update(self): ('zaizee', 'cat', '124'), ('foo', 'baze', '125.1')]) + @requires_models(OCTest) + def test_update_atomic(self): + query = OCTest.insert(a='foo', b=1).on_conflict( + conflict_target=(OCTest.a,), + update={OCTest.b: OCTest.b + 2}) + self.assertSQL(query, ( + 'INSERT INTO "octest" ("a", "b") VALUES (?, ?) ' + 'ON CONFLICT ("a") DO UPDATE SET "b" = ("octest"."b" + ?) ' + 'RETURNING "id"'), ['foo', 1, 2]) + + rowid1 = query.execute() + rowid2 = query.clone().execute() + self.assertEqual(rowid1, rowid2) + + obj = OCTest.get() + self.assertEqual(obj.a, 'foo') + self.assertEqual(obj.b, 3) + + @requires_models(OCTest) + def test_update_where_clause(self): + query = OCTest.insert(a='foo', b=1).on_conflict( + conflict_target=(OCTest.a,), + update={OCTest.b: OCTest.b + 2}, + where=(OCTest.b < 3)) + self.assertSQL(query, ( + 'INSERT INTO "octest" ("a", "b") VALUES (?, ?) ' + 'ON CONFLICT ("a") DO UPDATE SET "b" = ("octest"."b" + ?) ' + 'WHERE ("octest"."b" < ?) ' + 'RETURNING "id"'), ['foo', 1, 2, 3]) + + rowid1 = query.execute() + rowid2 = query.clone().execute() + self.assertEqual(rowid1, rowid2) + + obj = OCTest.get() + self.assertEqual(obj.a, 'foo') + self.assertEqual(obj.b, 3) + + rowid3 = query.clone().execute() + self.assertEqual(rowid1, rowid2) + + # Because we didn't satisfy the WHERE clause, the value in "b" is + # not incremented again. + obj = OCTest.get() + self.assertEqual(obj.a, 'foo') + self.assertEqual(obj.b, 3) + class TestJoinSubquery(ModelTestCase): requires = [Person, Relationship] @@ -1987,3 +2059,30 @@ def test_get_or_create_self_referential_fk2(self): self.assertEqual(child_db.user.username, 'huey') self.assertEqual(child_db.parent.name, 'parent') self.assertEqual(child_db.name, 'child') + + +class TestCountUnionRegression(ModelTestCase): + @requires_models(User) + @skip_if(IS_MYSQL) + def test_count_union(self): + with self.database.atomic(): + for i in range(5): + User.create(username='user-%d' % i) + + lhs = User.select() + rhs = User.select() + query = (lhs & rhs) + self.assertSQL(query, ( + 'SELECT "t1"."id", "t1"."username" FROM "users" AS "t1" ' + 'INTERSECT ' + 'SELECT "t2"."id", "t2"."username" FROM "users" AS "t2"'), []) + + self.assertEqual(query.count(), 5) + + query = query.limit(3) + self.assertSQL(query, ( + 'SELECT "t1"."id", "t1"."username" FROM "users" AS "t1" ' + 'INTERSECT ' + 'SELECT "t2"."id", "t2"."username" FROM "users" AS "t2" ' + 'LIMIT 3'), []) + self.assertEqual(query.count(), 3) diff --git a/tests/sql.py b/tests/sql.py index 7fc36622d..a10b2653a 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -847,7 +847,8 @@ def test_update(self): self.assertSQL(query, ( 'INSERT INTO "person" ("dob", "name") VALUES (?, ?) ' 'ON CONFLICT ("name") DO ' - 'UPDATE SET "dob" = EXCLUDED."dob", "name" = ("name" || ?)'), + 'UPDATE SET "dob" = EXCLUDED."dob", ' + '"name" = ("person"."name" || ?)'), [dob, 'huey', '-x']) query = (Person @@ -870,8 +871,9 @@ def test_update(self): self.assertSQL(query, ( 'INSERT INTO "person" ("name") VALUES (?) ' 'ON CONFLICT ("name") DO ' - 'UPDATE SET "dob" = EXCLUDED."dob", "name" = ("name" || ?) ' - 'WHERE ("name" != ?)'), ['huey', '-x', 'zaizee']) + 'UPDATE SET "dob" = EXCLUDED."dob", ' + '"name" = ("person"."name" || ?) ' + 'WHERE ("person"."name" != ?)'), ['huey', '-x', 'zaizee']) #Person = Table('person', ['id', 'name', 'dob']) diff --git a/tests/transactions.py b/tests/transactions.py index 82b60a490..5805c4a9b 100644 --- a/tests/transactions.py +++ b/tests/transactions.py @@ -157,3 +157,7 @@ def also_fails(): with db.manual_commit(): self.assertRaises(ValueError, also_fails) + + def test_closing_db_in_transaction(self): + with db.atomic(): + self.assertRaises(OperationalError, db.close)