Skip to content

Commit

Permalink
Initial support for sqlite jsonb (requires 3.45.0 or newer).
Browse files Browse the repository at this point in the history
Warning: this set of functions requires a good deal of care to use correctly.
  • Loading branch information
coleifer committed Mar 21, 2024
1 parent f8b225c commit adad27b
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 13 deletions.
67 changes: 59 additions & 8 deletions playhouse/sqlite_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __getitem__(self, idx):
item = '[%s]' % idx
else:
item = '.%s' % idx
return JSONPath(self._field, self._path + (item,))
return type(self)(self._field, self._path + (item,))

def append(self, value, as_json=None):
if as_json or isinstance(value, (list, dict)):
Expand Down Expand Up @@ -133,10 +133,41 @@ def __sql__(self, ctx):
return ctx.sql(fn.json_extract(self._field, self.path)
if self._path else self._field)

class JSONBPath(JSONPath):
def append(self, value, as_json=None):
if as_json or isinstance(value, (list, dict)):
value = fn.jsonb(self._field._json_dumps(value))
return fn.jsonb_set(self._field, self['#'].path, value)

def _json_operation(self, func, value, as_json=None):
if as_json or isinstance(value, (list, dict)):
value = fn.jsonb(self._field._json_dumps(value))
return func(self._field, self.path, value)

def insert(self, value, as_json=None):
return self._json_operation(fn.jsonb_insert, value, as_json)

def set(self, value, as_json=None):
return self._json_operation(fn.jsonb_set, value, as_json)

def replace(self, value, as_json=None):
return self._json_operation(fn.jsonb_replace, value, as_json)

def update(self, value):
return self.set(fn.jsonb_patch(self, self._field._json_dumps(value)))

def remove(self):
return fn.jsonb_remove(self._field, self.path)

def __sql__(self, ctx):
return ctx.sql(fn.jsonb_extract(self._field, self.path)
if self._path else self._field)


class JSONField(TextField):
field_type = 'JSON'
unpack = False
Path = JSONPath

def __init__(self, json_dumps=None, json_loads=None, **kwargs):
self._json_dumps = json_dumps or json.dumps
Expand Down Expand Up @@ -171,7 +202,7 @@ def inner(self, rhs):
__hash__ = Field.__hash__

def __getitem__(self, item):
return JSONPath(self)[item]
return self.Path(self)[item]

def extract(self, *paths):
paths = [Value(p, converter=False) for p in paths]
Expand All @@ -182,23 +213,23 @@ def extract_text(self, path):
return Expression(self, '->>', Value(path, converter=False))

def append(self, value, as_json=None):
return JSONPath(self).append(value, as_json)
return self.Path(self).append(value, as_json)

def insert(self, value, as_json=None):
return JSONPath(self).insert(value, as_json)
return self.Path(self).insert(value, as_json)

def set(self, value, as_json=None):
return JSONPath(self).set(value, as_json)
return self.Path(self).set(value, as_json)

def replace(self, value, as_json=None):
return JSONPath(self).replace(value, as_json)
return self.Path(self).replace(value, as_json)

def update(self, data):
return JSONPath(self).update(data)
return self.Path(self).update(data)

def remove(self, *paths):
if not paths:
return JSONPath(self).remove()
return self.Path(self).remove()
return fn.json_remove(self, *paths)

def json_type(self):
Expand Down Expand Up @@ -229,6 +260,26 @@ def tree(self):
return fn.json_tree(self)


class JSONBField(JSONField):
field_type = 'JSONB'
Path = JSONBPath

def db_value(self, value):
if value is not None:
if not isinstance(value, Node):
value = fn.jsonb(self._json_dumps(value))
return value

def extract(self, *paths):
paths = [Value(p, converter=False) for p in paths]
return fn.jsonb_extract(self, *paths)

def remove(self, *paths):
if not paths:
return self.Path(self).remove()
return fn.jsonb_remove(self, *paths)


class SearchField(Field):
def __init__(self, unindexed=False, column_name=None, **k):
if k:
Expand Down
62 changes: 58 additions & 4 deletions tests/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .sqlite_helpers import json_installed
from .sqlite_helpers import json_patch_installed
from .sqlite_helpers import json_text_installed
from .sqlite_helpers import jsonb_installed


database = SqliteExtDatabase(':memory:', c_extensions=False, timeout=100)
Expand Down Expand Up @@ -125,6 +126,10 @@ class KeyData(TestModel):
key = TextField()
data = JSONField()

class JBData(TestModel):
key = TextField()
data = JSONBField()


class Values(TestModel):
klass = IntegerField()
Expand Down Expand Up @@ -532,9 +537,11 @@ class TestJSONFieldFunctions(ModelTestCase):
('d', {'x1': {'y1': 'z1', 'y2': 'z2'}}),
('e', {'l1': [0, 1, 2], 'l2': [1, [3, 3], 7]}),
]
M = KeyData

def setUp(self):
super(TestJSONFieldFunctions, self).setUp()
KeyData = self.M
with self.database.atomic():
for key, data in self.test_data:
KeyData.create(key=key, data=data)
Expand All @@ -545,9 +552,11 @@ def assertRows(self, where, expected):
self.assertEqual([kd.key for kd in self.Q.where(where)], expected)

def assertData(self, key, expected):
KeyData = self.M
self.assertEqual(KeyData.get(KeyData.key == key).data, expected)

def test_json_group_functions(self):
KeyData = self.M
with self.database.atomic():
KeyData.delete().execute()
for i in range(10):
Expand Down Expand Up @@ -597,6 +606,7 @@ def test_json_group_functions(self):
self.assertEqual(query.scalar(), {'k0': 0, 'k1': 1, 'k2': 2, 'k3': 3})

def test_extract(self):
KeyData = self.M
self.assertRows((KeyData.data['k1'] == 'v1'), ['a', 'c'])
self.assertRows((KeyData.data['k2'] == 'v2'), ['b', 'c'])
self.assertRows((KeyData.data['x1']['y1'] == 'z1'), ['a', 'd'])
Expand All @@ -605,6 +615,7 @@ def test_extract(self):

@skip_unless(json_text_installed())
def test_extract_text_json(self):
KeyData = self.M
D = KeyData.data
self.assertRows((D.extract('$.k1') == 'v1'), ['a', 'c'])
self.assertRows((D.extract_text('$.k1') == 'v1'), ['a', 'c'])
Expand All @@ -618,6 +629,7 @@ def test_extract_text_json(self):
self.assertRows((D.extract_json('x1') == '{"y1":"z1"}'), ['a'])

def test_extract_multiple(self):
KeyData = self.M
query = KeyData.select(
KeyData.key,
KeyData.data.extract('$.k1', '$.k2').alias('keys'))
Expand All @@ -629,6 +641,7 @@ def test_extract_multiple(self):
('e', [None, None])])

def test_insert(self):
KeyData = self.M
# Existing values are not overwritten.
query = KeyData.update(data=KeyData.data['k1'].insert('v1-x'))
self.assertEqual(query.execute(), 5)
Expand All @@ -641,6 +654,7 @@ def test_insert(self):
'l2': [1, [3, 3], 7]})

def test_insert_json(self):
KeyData = self.M
set_json = KeyData.data['k1'].insert([0])
query = KeyData.update(data=set_json)
self.assertEqual(query.execute(), 5)
Expand All @@ -653,6 +667,7 @@ def test_insert_json(self):
'l2': [1, [3, 3], 7]})

def test_replace(self):
KeyData = self.M
# Only existing values are overwritten.
query = KeyData.update(data=KeyData.data['k1'].replace('v1-x'))
self.assertEqual(query.execute(), 5)
Expand All @@ -664,6 +679,7 @@ def test_replace(self):
self.assertData('e', {'l1': [0, 1, 2], 'l2': [1, [3, 3], 7]})

def test_replace_json(self):
KeyData = self.M
set_json = KeyData.data['k1'].replace([0])
query = KeyData.update(data=set_json)
self.assertEqual(query.execute(), 5)
Expand All @@ -675,6 +691,7 @@ def test_replace_json(self):
self.assertData('e', {'l1': [0, 1, 2], 'l2': [1, [3, 3], 7]})

def test_set(self):
KeyData = self.M
query = (KeyData
.update({KeyData.data: KeyData.data['k1'].set('v1-x')})
.where(KeyData.data['k1'] == 'v1'))
Expand All @@ -684,6 +701,7 @@ def test_set(self):
self.assertData('a', {'k1': 'v1-x', 'x1': {'y1': 'z1'}})

def test_set_json(self):
KeyData = self.M
set_json = KeyData.data['x1'].set({'y1': 'z1-x', 'y3': 'z3'})
query = (KeyData
.update({KeyData.data: set_json})
Expand All @@ -695,6 +713,7 @@ def test_set_json(self):
self.assertData('d', {'x1': {'y1': 'z1-x', 'y3': 'z3'}})

def test_append(self):
KeyData = self.M
for value in ('ix', [], ['c1'], ['c1', 'c2'], {}, {'k1': 'v1'},
{'k1': 'v1', 'k2': 'v2'}, None, 1):
KeyData.delete().execute()
Expand All @@ -710,7 +729,9 @@ def test_append(self):
.where(KeyData.key.startswith('a')))
self.assertEqual(query.execute(), 3)

query = KeyData.select().where(KeyData.key.startswith('a'))
query = (KeyData
.select(KeyData.key, fn.json(KeyData.data))
.where(KeyData.key.startswith('a')))
self.assertEqual(sorted((row.key, row.data) for row in query),
[('a0', [value]), ('a1', ['i1', value]),
('a2', ['i1', 'i2', value])])
Expand All @@ -720,14 +741,17 @@ def test_append(self):
.where(KeyData.key.startswith('n')))
self.assertEqual(query.execute(), 3)

query = KeyData.select().where(KeyData.key.startswith('n'))
query = (KeyData
.select(KeyData.key, fn.json(KeyData.data))
.where(KeyData.key.startswith('n')))
self.assertEqual(sorted((row.key, row.data) for row in query),
[('n0', {'arr': [value]}),
('n1', {'arr': ['i1', value]}),
('n2', {'arr': ['i1', 'i2', value]})])

@skip_unless(json_patch_installed())
def test_update(self):
KeyData = self.M
merged = KeyData.data.update({'x1': {'y1': 'z1-x', 'y3': 'z3'}})
query = (KeyData
.update({KeyData.data: merged})
Expand All @@ -740,6 +764,7 @@ def test_update(self):

@skip_unless(json_patch_installed())
def test_update_with_removal(self):
KeyData = self.M
m = KeyData.data.update({'k1': None, 'x1': {'y1': None, 'y3': 'z3'}})
query = KeyData.update(data=m).where(KeyData.data['x1']['y1'] == 'z1')
self.assertEqual(query.execute(), 2)
Expand All @@ -750,6 +775,7 @@ def test_update_with_removal(self):

@skip_unless(json_patch_installed())
def test_update_nested(self):
KeyData = self.M
merged = KeyData.data['x1'].update({'y1': 'z1-x', 'y3': 'z3'})
query = (KeyData
.update(data=merged)
Expand All @@ -762,6 +788,7 @@ def test_update_nested(self):

@skip_unless(json_patch_installed())
def test_updated_nested_with_removal(self):
KeyData = self.M
merged = KeyData.data['x1'].update({'o1': 'p1', 'y1': None})
nrows = (KeyData
.update(data=merged)
Expand All @@ -772,6 +799,7 @@ def test_updated_nested_with_removal(self):
self.assertData('d', {'x1': {'o1': 'p1', 'y2': 'z2'}})

def test_remove(self):
KeyData = self.M
query = (KeyData
.update(data=KeyData.data['k1'].remove())
.where(KeyData.data['k1'] == 'v1'))
Expand All @@ -787,14 +815,16 @@ def test_remove(self):
self.assertData('e', {'l1': [0, 1, 2], 'l2': [1, [3], 7]})

def test_simple_update(self):
KeyData = self.M
nrows = (KeyData
.update(data={'foo': 'bar'})
.where(KeyData.key.in_(['a', 'b']))
.execute())
for k in self.Q.where(KeyData.key.in_(['a', 'b'])):
self.assertEqual(k.data, {'foo': 'bar'})
self.assertData('a', {'foo': 'bar'})
self.assertData('b', {'foo': 'bar'})

def test_children(self):
KeyData = self.M
children = KeyData.data.children().alias('children')
query = (KeyData
.select(KeyData.key, children.c.fullkey.alias('fullkey'))
Expand All @@ -809,6 +839,7 @@ def test_children(self):
('e', '$.l1'), ('e', '$.l2')])

def test_tree(self):
KeyData = self.M
tree = KeyData.data.tree().alias('tree')
query = (KeyData
.select(tree.c.fullkey.alias('fullkey'))
Expand All @@ -823,6 +854,29 @@ def test_tree(self):
'$.x1.y2'])


@skip_unless(jsonb_installed(), 'requires sqlite jsonb support')
class TestJSONBFieldFunctions(TestJSONFieldFunctions):
requires = [JBData]
M = JBData

def assertData(self, key, expected):
q = JBData.select(fn.json(JBData.data)).where(JBData.key == key)
self.assertEqual(q.get().data, expected)

def test_extract_multiple(self):
# We need to override this, otherwise we end up with jsonb returned.
expr = fn.json(JBData.data.extract('$.k1', '$.k2'))
query = JBData.select(
JBData.key,
expr.python_value(json.loads).alias('keys'))
self.assertEqual(sorted((k.key, k.keys) for k in query), [
('a', ['v1', None]),
('b', [None, 'v2']),
('c', ['v1', 'v2']),
('d', [None, None]),
('e', [None, None])])


class TestSqliteExtensions(BaseTestCase):
def test_virtual_model(self):
class Test(VirtualModel):
Expand Down
4 changes: 3 additions & 1 deletion tests/sqlite_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
def json_installed():
if sqlite3.sqlite_version_info < (3, 9, 0):
return False
# Test in-memory DB to determine if the FTS5 extension is installed.
tmp_db = sqlite3.connect(':memory:')
try:
tmp_db.execute('select json(?)', (1337,))
Expand All @@ -22,6 +21,9 @@ def json_patch_installed():
def json_text_installed():
return sqlite3.sqlite_version_info >= (3, 38, 0)

def jsonb_installed():
return sqlite3.sqlite_version_info >= (3, 45, 0)


def compile_option(p):
if not hasattr(compile_option, '_pragma_cache'):
Expand Down

0 comments on commit adad27b

Please sign in to comment.