Skip to content

Commit

Permalink
Use qualified column names in ON CONFLICT clause. Fixes #1512.
Browse files Browse the repository at this point in the history
  • Loading branch information
coleifer committed Feb 22, 2018
1 parent f40ec03 commit e642ab9
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 4 deletions.
10 changes: 9 additions & 1 deletion peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -1396,6 +1396,12 @@ def clone(self):
Tuple = lambda *a: EnclosedNodeList(a)


class QualifiedNames(WrappedNode):
def __sql__(self, ctx):
with ctx.scope_column():
return ctx.sql(self.node)


class OnConflict(Node):
def __init__(self, action=None, update=None, preserve=None, where=None,
conflict_target=None):
Expand Down Expand Up @@ -3062,14 +3068,16 @@ def conflict_update(self, on_conflict):
if not isinstance(v, Node):
converter = k.db_value if isinstance(k, Field) else None
v = Value(v, converter=converter, unpack=False)
else:
v = QualifiedNames(v)
updates.append(NodeList((ensure_entity(k), SQL('='), v)))

parts = [SQL('ON CONFLICT'),
target,
SQL('DO UPDATE SET'),
CommaNodeList(updates)]
if on_conflict._where:
parts.extend((SQL('WHERE'), on_conflict._where))
parts.extend((SQL('WHERE'), QualifiedNames(on_conflict._where)))

return NodeList(parts)

Expand Down
52 changes: 52 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1902,6 +1902,11 @@ def test_replace(self):
('nuggie', 'dog', '123')])


class OCTest(TestModel):
a = CharField(unique=True)
b = IntegerField(default=0)


@skip_case_unless(IS_POSTGRESQL)
class TestUpsertPostgresql(OnConflictTestCase):
def test_update(self):
Expand Down Expand Up @@ -1929,6 +1934,53 @@ def test_update(self):
('zaizee', 'cat', '124'),
('foo', 'baze', '125.1')])

@requires_models(OCTest)
def test_update_atomic(self):
query = OCTest.insert(a='foo', b=1).on_conflict(
conflict_target=(OCTest.a,),
update={OCTest.b: OCTest.b + 2})
self.assertSQL(query, (
'INSERT INTO "octest" ("a", "b") VALUES (?, ?) '
'ON CONFLICT ("a") DO UPDATE SET "b" = ("octest"."b" + ?) '
'RETURNING "id"'), ['foo', 1, 2])

rowid1 = query.execute()
rowid2 = query.clone().execute()
self.assertEqual(rowid1, rowid2)

obj = OCTest.get()
self.assertEqual(obj.a, 'foo')
self.assertEqual(obj.b, 3)

@requires_models(OCTest)
def test_update_where_clause(self):
query = OCTest.insert(a='foo', b=1).on_conflict(
conflict_target=(OCTest.a,),
update={OCTest.b: OCTest.b + 2},
where=(OCTest.b < 3))
self.assertSQL(query, (
'INSERT INTO "octest" ("a", "b") VALUES (?, ?) '
'ON CONFLICT ("a") DO UPDATE SET "b" = ("octest"."b" + ?) '
'WHERE ("octest"."b" < ?) '
'RETURNING "id"'), ['foo', 1, 2, 3])

rowid1 = query.execute()
rowid2 = query.clone().execute()
self.assertEqual(rowid1, rowid2)

obj = OCTest.get()
self.assertEqual(obj.a, 'foo')
self.assertEqual(obj.b, 3)

rowid3 = query.clone().execute()
self.assertEqual(rowid1, rowid2)

# Because we didn't satisfy the WHERE clause, the value in "b" is
# not incremented again.
obj = OCTest.get()
self.assertEqual(obj.a, 'foo')
self.assertEqual(obj.b, 3)


class TestJoinSubquery(ModelTestCase):
requires = [Person, Relationship]
Expand Down
8 changes: 5 additions & 3 deletions tests/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,8 @@ def test_update(self):
self.assertSQL(query, (
'INSERT INTO "person" ("dob", "name") VALUES (?, ?) '
'ON CONFLICT ("name") DO '
'UPDATE SET "dob" = EXCLUDED."dob", "name" = ("name" || ?)'),
'UPDATE SET "dob" = EXCLUDED."dob", '
'"name" = ("person"."name" || ?)'),
[dob, 'huey', '-x'])

query = (Person
Expand All @@ -870,8 +871,9 @@ def test_update(self):
self.assertSQL(query, (
'INSERT INTO "person" ("name") VALUES (?) '
'ON CONFLICT ("name") DO '
'UPDATE SET "dob" = EXCLUDED."dob", "name" = ("name" || ?) '
'WHERE ("name" != ?)'), ['huey', '-x', 'zaizee'])
'UPDATE SET "dob" = EXCLUDED."dob", '
'"name" = ("person"."name" || ?) '
'WHERE ("person"."name" != ?)'), ['huey', '-x', 'zaizee'])


#Person = Table('person', ['id', 'name', 'dob'])
Expand Down

0 comments on commit e642ab9

Please sign in to comment.