diff --git a/.travis.yml b/.travis.yml index e67b5f6c0..7f1897b3f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -37,9 +37,9 @@ matrix: env: - PEEWEE_TEST_BACKEND=cockroachdb before_install: - - wget -qO- https://binaries.cockroachdb.com/cockroach-v19.2.0.linux-amd64.tgz | tar xvz - - ./cockroach-v19.2.0.linux-amd64/cockroach start --insecure --background - - ./cockroach-v19.2.0.linux-amd64/cockroach sql --insecure -e 'create database peewee_test;' + - wget -qO- https://binaries.cockroachdb.com/cockroach-v20.1.1.linux-amd64.tgz | tar xvz + - ./cockroach-v20.1.1.linux-amd64/cockroach start --insecure --background + - ./cockroach-v20.1.1.linux-amd64/cockroach sql --insecure -e 'create database peewee_test;' allow_failures: addons: postgresql: "9.6" diff --git a/CHANGELOG.md b/CHANGELOG.md index a3d913f3e..ef7cbb25d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,47 @@ https://github.com/coleifer/peewee/releases ## master -[View commits](https://github.com/coleifer/peewee/compare/3.13.3...master) +[View commits](https://github.com/coleifer/peewee/compare/3.14.0...master) + +## 3.14.0 + +This release has been a bit overdue and there are numerous small improvements +and bug-fixes. The bugfix that prompted this release is #2293, which is a +regression in the Django-inspired `.filter()` APIs that could cause some +filter expressions to be discarded from the generated SQL. Many thanks for the +excellent bug report, Jakub. + +* Add an experimental helper, `shortcuts.resolve_multimodel_query()`, for + resolving multiple models used in a compound select query. +* Add a `lateral()` method to select query for use with lateral joins, refs + issue #2205. +* Added support for nested transactions (savepoints) in cockroach-db (requires + 20.1 or newer). +* Automatically escape wildcards passed to string-matching methods, refs #2224. +* Allow index-type to be specified on MySQL, refs #2242. 
+* Added a new API, `converter()` to be used for specifying a function to use to + convert a row-value pulled off the cursor, refs #2248. +* Add `set()` and `clear()` method to the bitfield flag descriptor, refs #2257. +* Add support for `range` types with `IN` and other expressions. +* Support CTEs bound to compound select queries, refs #2289. + +### Bug-fixes + +* Fix to return related object id when accessing via the object-id descriptor, + when the related object is not populated, refs #2162. +* Fix to ensure we do not insert a NULL value for a primary key. +* Fix to conditionally set the field/column on an added column in a migration, + refs #2171. +* Apply field conversion logic to model-class values. Relocates the logic from + issue #2131 and fixes #2185. +* Clone node before modifying it to be flat in an enclosed nodelist expr, fixes + issue #2200. +* Fix an invalid item assignment in nodelist, refs #2220. +* Fix an incorrect truthiness check used with `save()` and `only=`, refs #2269. +* Fix regression in `filter()` where using both `*args` and `**kwargs` caused + the expressions passed as `args` to be discarded. See #2293. + +[View commits](https://github.com/coleifer/peewee/compare/3.13.3...3.14.0) ## 3.13.3 diff --git a/docs/peewee/api.rst b/docs/peewee/api.rst index c1b825ab8..d749a481d 100644 --- a/docs/peewee/api.rst +++ b/docs/peewee/api.rst @@ -2916,6 +2916,28 @@ Fields # Query for sticky + favorite posts: query = Post.select().where(Post.is_sticky & Post.is_favorite) + When bulk-updating one or more bits in a :py:class:`BitField`, you can use + bitwise operators to set or clear one or more bits: + + .. code-block:: python + + # Set the 4th bit on all Post objects. + Post.update(flags=Post.flags | 8).execute() + + # Clear the 1st and 3rd bits on all Post objects. 
+ Post.update(flags=Post.flags & ~(1 | 4)).execute() + + For simple operations, the flags provide handy ``set()`` and ``clear()`` + methods for setting or clearing an individual bit: + + .. code-block:: python + + # Set the "is_deleted" bit on all posts. + Post.update(flags=Post.is_deleted.set()).execute() + + # Clear the "is_deleted" bit on all posts. + Post.update(flags=Post.is_deleted.clear()).execute() + .. py:method:: flag([value=None]) :param int value: Value associated with flag, typically a power of 2. @@ -4029,9 +4051,9 @@ Model or newer to take advantage of bulk inserts. .. note:: - SQLite has a default limit of 999 bound variables per statement. - This limit can be modified at compile-time or at run-time, **but** - if modifying at run-time, you can only specify a *lower* value than + SQLite has a default limit of bound variables per statement. This + limit can be modified at compile-time or at run-time, **but** if + modifying at run-time, you can only specify a *lower* value than the default limit. For more information, check out the following SQLite documents: @@ -4181,8 +4203,10 @@ Model * The primary-key value for the newly-created models will only be set if you are using Postgresql (which supports the ``RETURNING`` clause). - * SQLite generally has a limit of 999 bound parameters for a query, - so the batch size should be roughly 1000 / number-of-fields. + * SQLite generally has a limit of bound parameters for a query, + so the maximum batch size should be param-limit / number-of-fields. + This limit is typically 999 for Sqlite < 3.32.0, and 32766 for + newer versions. * When a batch-size is provided it is **strongly recommended** that you wrap the call in a transaction or savepoint using :py:meth:`Database.atomic`. Otherwise an error in a batch mid-way @@ -4225,7 +4249,9 @@ Model .. warning:: - * SQLite generally has a limit of 999 bound parameters for a query. + * SQLite generally has a limit of bound parameters for a query. 
+ This limit is typically 999 for Sqlite < 3.32.0, and 32766 for + newer versions. * When a batch-size is provided it is **strongly recommended** that you wrap the call in a transaction or savepoint using :py:meth:`Database.atomic`. Otherwise an error in a batch mid-way diff --git a/docs/peewee/database.rst b/docs/peewee/database.rst index 99d796ef5..febd2e42e 100644 --- a/docs/peewee/database.rst +++ b/docs/peewee/database.rst @@ -1315,7 +1315,7 @@ The connection handling code can be placed in a `middleware component def process_request(self, req, resp): database.connect() - def process_response(self, req, resp, resource): + def process_response(self, req, resp, resource, req_succeeded): if not database.is_closed(): database.close() diff --git a/docs/peewee/hacks.rst b/docs/peewee/hacks.rst index 3c79780ce..4c64e6bf5 100644 --- a/docs/peewee/hacks.rst +++ b/docs/peewee/hacks.rst @@ -438,7 +438,7 @@ some code that will tell us which tasks we should run at a given time: command = TextField() # Run this command. last_run = DateTimeField() # When was this run last? -Our logic will essentially boil down to:: +Our logic will essentially boil down to: .. code-block:: python diff --git a/docs/peewee/models.rst b/docs/peewee/models.rst index fef187e43..49079fa16 100644 --- a/docs/peewee/models.rst +++ b/docs/peewee/models.rst @@ -507,6 +507,28 @@ storing arbitrarily large bitmaps, you can instead use :py:class:`BigBitField`, which uses an automatically managed buffer of bytes, stored in a :py:class:`BlobField`. +When bulk-updating one or more bits in a :py:class:`BitField`, you can use +bitwise operators to set or clear one or more bits: + +.. code-block:: python + + # Set the 4th bit on all Post objects. + Post.update(flags=Post.flags | 8).execute() + + # Clear the 1st and 3rd bits on all Post objects. 
+ Post.update(flags=Post.flags & ~(1 | 4)).execute() + +For simple operations, the flags provide handy ``set()`` and ``clear()`` +methods for setting or clearing an individual bit: + +.. code-block:: python + + # Set the "is_deleted" bit on all posts. + Post.update(flags=Post.is_deleted.set()).execute() + + # Clear the "is_deleted" bit on all posts. + Post.update(flags=Post.is_deleted.clear()).execute() + Example usage: .. code-block:: python diff --git a/docs/peewee/playhouse.rst b/docs/peewee/playhouse.rst index 90275e249..6abe60ae8 100644 --- a/docs/peewee/playhouse.rst +++ b/docs/peewee/playhouse.rst @@ -685,7 +685,7 @@ currently: * :py:class:`JSONField` field type, for storing JSON data. * :py:class:`BinaryJSONField` field type for the ``jsonb`` JSON data type. * :py:class:`TSVectorField` field type, for storing full-text search data. -* :py:class:`DateTimeTZ` field type, a timezone-aware datetime field. +* :py:class:`DateTimeTZField` field type, a timezone-aware datetime field. In the future I would like to add support for more of postgresql's features. If there is a particular feature you would like to see added, please @@ -2548,6 +2548,20 @@ helpers for serializing models to dictionaries and vice-versa. Update a model instance with the given data dictionary. + +.. py:function:: resolve_multimodel_query(query[, key='_model_identifier']) + + :param query: a compound select query. + :param str key: key to use for storing model identifier + :return: an iterable cursor that yields the proper model instance for + each row selected in the compound select query. + + Helper for resolving rows returned in a compound select query to the + correct model instance type. For example, if you have a union of two + different tables, this helper will resolve each row to the proper model + when iterating over the query results. + + .. 
_signals: Signal support diff --git a/docs/peewee/querying.rst b/docs/peewee/querying.rst index 6f6a9f8f0..f22a4d0d2 100644 --- a/docs/peewee/querying.rst +++ b/docs/peewee/querying.rst @@ -158,14 +158,15 @@ It is also a good practice to wrap the bulk insert in a transaction: SQLite users should be aware of some caveats when using bulk inserts. Specifically, your SQLite3 version must be 3.7.11.0 or newer to take advantage of the bulk insert API. Additionally, by default SQLite limits - the number of bound variables in a SQL query to ``999``. + the number of bound variables in a SQL query to ``999`` for SQLite versions + prior to 3.32.0 (2020-05-22) and 32766 for SQLite versions after 3.32.0. Inserting rows in batches ^^^^^^^^^^^^^^^^^^^^^^^^^ Depending on the number of rows in your data source, you may need to break it -up into chunks. SQLite in particular typically has a `limit of 999 `_ -variables-per-query (batch size would then be roughly 1000 / row length). +up into chunks. SQLite in particular typically has a `limit of 999 or 32766 `_ +variables-per-query (batch size would then be 999 // row length or 32766 // row length). You can write a loop to batch your data into chunks (in which case it is **strongly recommended** you use a transaction): @@ -1054,7 +1055,7 @@ MySQL uses *Rand*: .. 
code-block:: python # Pick 5 lucky winners: - LotterNumber.select().order_by(fn.Rand()).limit(5) + LotteryNumber.select().order_by(fn.Rand()).limit(5) Paginating records ------------------ @@ -1940,7 +1941,7 @@ recursive CTE: level = Value(1).alias('level') path = Base.name.alias('path') base_case = (Base - .select(Base.name, Base.parent, level, path) + .select(Base.id, Base.name, Base.parent, level, path) .where(Base.parent.is_null()) .cte('base', recursive=True)) @@ -1949,7 +1950,7 @@ recursive CTE: rlevel = (base_case.c.level + 1).alias('level') rpath = base_case.c.path.concat('->').concat(RTerm.name).alias('path') recursive = (RTerm - .select(RTerm.name, RTerm.parent, rlevel, rpath) + .select(RTerm.id, RTerm.name, RTerm.parent, rlevel, rpath) .join(base_case, on=(RTerm.parent == base_case.c.id))) # The recursive CTE is created by taking the base case and UNION ALL with diff --git a/docs/peewee/sqlite_ext.rst b/docs/peewee/sqlite_ext.rst index be5fff8ea..0de772145 100644 --- a/docs/peewee/sqlite_ext.rst +++ b/docs/peewee/sqlite_ext.rst @@ -39,7 +39,7 @@ Instantiating a :py:class:`SqliteExtDatabase`: db = SqliteExtDatabase('my_app.db', pragmas=( ('cache_size', -1024 * 64), # 64MB page-cache. ('journal_mode', 'wal'), # Use WAL-mode (you should always use this!). - ('foreign_keys', 1)) # Enforce foreign-key constraints. + ('foreign_keys', 1))) # Enforce foreign-key constraints. 
APIs ---- diff --git a/examples/analytics/app.py b/examples/analytics/app.py index b7b81f815..05168b489 100644 --- a/examples/analytics/app.py +++ b/examples/analytics/app.py @@ -24,16 +24,17 @@ """ import datetime import os -from urlparse import parse_qsl, urlparse +from urllib.parse import parse_qsl, urlparse +import binascii -from flask import Flask, Response, abort, request, g +from flask import Flask, Response, abort, g, request from peewee import * -from playhouse.postgres_ext import HStoreField -from playhouse.postgres_ext import PostgresqlExtDatabase - +from playhouse.postgres_ext import HStoreField, PostgresqlExtDatabase # Analytics settings. -BEACON = '47494638396101000100800000dbdfef00000021f90401000000002c00000000010001000002024401003b'.decode('hex') # 1px gif. +# 1px gif. +BEACON = binascii.unhexlify( + '47494638396101000100800000dbdfef00000021f90401000000002c00000000010001000002024401003b') DATABASE_NAME = 'analytics' DOMAIN = 'http://analytics.yourdomain.com' # TODO: change me. JAVASCRIPT = """(function(id){ @@ -53,10 +54,12 @@ register_hstore=True, user='postgres') + class BaseModel(Model): class Meta: database = database + class Account(BaseModel): domain = CharField() @@ -65,6 +68,7 @@ def verify_url(self, url): url_domain = '.'.join(netloc.split('.')[-2:]) # Ignore subdomains. return self.domain == url_domain + class PageView(BaseModel): account = ForeignKeyField(Account, backref='pageviews') url = TextField() @@ -89,6 +93,7 @@ def create_from_request(cls, account, request): headers=dict(request.headers), params=params) + @app.route('/a.gif') def analyze(): # Make sure an account id and url were specified. 
@@ -113,6 +118,7 @@ def analyze(): response.headers['Cache-Control'] = 'private, no-cache' return response + @app.route('/a.js') def script(): account_id = request.args.get('id') @@ -122,17 +128,21 @@ def script(): mimetype='text/javascript') return Response('', mimetype='text/javascript') + @app.errorhandler(404) def not_found(e): return Response('

Not found.

') # Request handlers -- these two hooks are provided by flask and we will use them # to create and tear down a database connection on each request. + + @app.before_request def before_request(): g.db = database g.db.connection() + @app.after_request def after_request(response): g.db.close() diff --git a/examples/analytics/requirements.txt b/examples/analytics/requirements.txt index 9a6384279..6b642d229 100644 --- a/examples/analytics/requirements.txt +++ b/examples/analytics/requirements.txt @@ -1,2 +1,3 @@ peewee flask +psycopg2 diff --git a/examples/diary.py b/examples/diary.py index 943e0aefb..b1188cab3 100755 --- a/examples/diary.py +++ b/examples/diary.py @@ -12,6 +12,7 @@ # command-line. db = SqlCipherDatabase(None) + class Entry(Model): content = TextField() timestamp = DateTimeField(default=datetime.datetime.now) @@ -19,27 +20,31 @@ class Entry(Model): class Meta: database = db + def initialize(passphrase): - db.init('diary.db', passphrase=passphrase, kdf_iter=64000) + db.init('diary.db', passphrase=passphrase) db.create_tables([Entry]) + def menu_loop(): choice = None while choice != 'q': for key, value in menu.items(): print('%s) %s' % (key, value.__doc__)) - choice = raw_input('Action: ').lower().strip() + choice = input('Action: ').lower().strip() if choice in menu: menu[choice]() + def add_entry(): """Add entry""" print('Enter your entry. Press ctrl+d when finished.') data = sys.stdin.read().strip() - if data and raw_input('Save entry? [Yn] ') != 'n': + if data and input('Save entry? [Yn] ') != 'n': Entry.create(content=data) print('Saved successfully.') + def view_entries(search_query=None): """View previous entries""" query = Entry.select().order_by(Entry.timestamp.desc()) @@ -54,16 +59,18 @@ def view_entries(search_query=None): print('n) next entry') print('d) delete entry') print('q) return to main menu') - action = raw_input('Choice? (Ndq) ').lower().strip() + action = input('Choice? 
(Ndq) ').lower().strip() if action == 'q': break elif action == 'd': entry.delete_instance() break + def search_entries(): """Search entries""" - view_entries(raw_input('Search query: ')) + view_entries(input('Search query: ')) + menu = OrderedDict([ ('a', add_entry), diff --git a/peewee.py b/peewee.py index 1ed95b034..4279fe6e5 100644 --- a/peewee.py +++ b/peewee.py @@ -65,7 +65,7 @@ mysql = None -__version__ = '3.13.3' +__version__ = '3.14.0' __all__ = [ 'AsIs', 'AutoField', @@ -159,6 +159,7 @@ def emit(self, record): buffer_type = buffer izip_longest = itertools.izip_longest callable_ = callable + multi_types = (list, tuple, frozenset, set) exec('def reraise(tp, value, tb=None): raise tp, value, tb') def print_(s): sys.stdout.write(s) @@ -176,6 +177,7 @@ def print_(s): buffer_type = memoryview basestring = str long = int + multi_types = (list, tuple, frozenset, set, range) print_ = getattr(builtins, 'print') izip_longest = itertools.zip_longest def reraise(tp, value, tb=None): @@ -616,8 +618,6 @@ def literal(self, keyword): def value(self, value, converter=None, add_param=True): if converter: value = converter(value) - if isinstance(value, Node): - return self.sql(value) elif converter is None and self.state.converter: # Explicitly check for None so that "False" can be used to signify # that no conversion should be applied. @@ -626,6 +626,13 @@ def value(self, value, converter=None, add_param=True): if isinstance(value, Node): with self(converter=None): return self.sql(value) + elif is_model(value): + # Under certain circumstances, we could end-up treating a model- + # class itself as a value. This check ensures that we drop the + # table alias into the query instead of trying to parameterize a + # model (for instance, passing a model as a function argument). 
+ with self.scope_column(): + return self.sql(value) self._values.append(value) return self.literal(self.state.param or '?') if add_param else self @@ -1098,6 +1105,12 @@ def __sql__(self, ctx): class ColumnBase(Node): + _converter = None + + @Node.copy + def converter(self, converter=None): + self._converter = converter + def alias(self, alias): if alias: return Alias(self, alias) @@ -1172,24 +1185,30 @@ def __ne__(self, rhs): def is_null(self, is_null=True): op = OP.IS if is_null else OP.IS_NOT return Expression(self, op, None) + + def _escape_like_expr(self, s, template): + if s.find('_') >= 0 or s.find('%') >= 0 or s.find('\\') >= 0: + s = s.replace('\\', '\\\\').replace('_', '\\_').replace('%', '\\%') + return NodeList((template % s, SQL('ESCAPE'), '\\')) + return template % s def contains(self, rhs): if isinstance(rhs, Node): rhs = Expression('%', OP.CONCAT, Expression(rhs, OP.CONCAT, '%')) else: - rhs = '%%%s%%' % rhs + rhs = self._escape_like_expr(rhs, '%%%s%%') return Expression(self, OP.ILIKE, rhs) def startswith(self, rhs): if isinstance(rhs, Node): rhs = Expression(rhs, OP.CONCAT, '%') else: - rhs = '%s%%' % rhs + rhs = self._escape_like_expr(rhs, '%s%%') return Expression(self, OP.ILIKE, rhs) def endswith(self, rhs): if isinstance(rhs, Node): rhs = Expression('%', OP.CONCAT, rhs) else: - rhs = '%%%s' % rhs + rhs = self._escape_like_expr(rhs, '%%%s') return Expression(self, OP.ILIKE, rhs) def between(self, lo, hi): return Expression(self, OP.BETWEEN, NodeList((lo, SQL('AND'), hi))) @@ -1243,6 +1262,7 @@ class WrappedNode(ColumnBase): def __init__(self, node): self.node = node self._coerce = getattr(node, '_coerce', True) + self._converter = getattr(node, '_converter', None) def is_alias(self): return self.node.is_alias() @@ -1334,12 +1354,10 @@ def __sql__(self, ctx): class Value(ColumnBase): - _multi_types = (list, tuple, frozenset, set) - def __init__(self, value, converter=None, unpack=True): self.value = value self.converter = converter - self.multi 
= isinstance(self.value, self._multi_types) and unpack + self.multi = unpack and isinstance(self.value, multi_types) if self.multi: self.values = [] for item in self.value: @@ -1353,14 +1371,6 @@ def __sql__(self, ctx): # For multi-part values (e.g. lists of IDs). return ctx.sql(EnclosedNodeList(self.values)) - # Under certain circumstances, we could end-up treating a model-class - # itself as a value. This check ensures that we drop the table alias - # into the query instead of trying to parameterize a model (for - # instance, when passing a model as a function argument). - if is_model(self.value): - with ctx.scope_column(): - return ctx.sql(self.value) - return ctx.value(self.value, self.converter) @@ -1736,10 +1746,12 @@ def __init__(self, nodes, glue=' ', parens=False): self.nodes = nodes self.glue = glue self.parens = parens - if parens and len(self.nodes) == 1: - if isinstance(self.nodes[0], Expression): - # Hack to avoid double-parentheses. - self.nodes[0].flat = True + if parens and len(self.nodes) == 1 and \ + isinstance(self.nodes[0], Expression) and \ + not self.nodes[0].flat: + # Hack to avoid double-parentheses. + self.nodes = (self.nodes[0].clone(),) + self.nodes[0].flat = True def __sql__(self, ctx): n_nodes = len(self.nodes) @@ -2222,6 +2234,9 @@ def __sql__(self, ctx): if ctx.scope == SCOPE_COLUMN: return self.apply_column(ctx) + # Call parent method to handle any CTEs. + super(CompoundSelectQuery, self).__sql__(ctx) + outer_parens = ctx.subquery or (ctx.scope == SCOPE_SOURCE) with ctx(parentheses=outer_parens): # Should the left-hand query be wrapped in parentheses? 
@@ -2248,7 +2263,7 @@ def __sql__(self, ctx): class Select(SelectBase): def __init__(self, from_list=None, columns=None, group_by=None, having=None, distinct=None, windows=None, for_update=None, - for_update_of=None, nowait=None, **kwargs): + for_update_of=None, nowait=None, lateral=None, **kwargs): super(Select, self).__init__(**kwargs) self._from_list = (list(from_list) if isinstance(from_list, tuple) else from_list) or [] @@ -2259,6 +2274,7 @@ def __init__(self, from_list=None, columns=None, group_by=None, self._for_update = for_update # XXX: consider reorganizing. self._for_update_of = for_update_of self._for_update_nowait = nowait + self._lateral = lateral self._distinct = self._simple_distinct = None if distinct: @@ -2341,6 +2357,10 @@ def for_update(self, for_update=True, of=None, nowait=None): self._for_update_of = of self._for_update_nowait = nowait + @Node.copy + def lateral(self, lateral=True): + self._lateral = lateral + def _get_query_key(self): return self._alias @@ -2351,6 +2371,9 @@ def __sql__(self, ctx): if ctx.scope == SCOPE_COLUMN: return self.apply_column(ctx) + if self._lateral and ctx.scope == SCOPE_SOURCE: + ctx.literal('LATERAL ') + is_subquery = ctx.subquery state = { 'converter': None, @@ -2772,13 +2795,19 @@ def __sql__(self, ctx): index_name = Entity(self._name) table_name = self._table + ctx.sql(index_name) + if self._using is not None and \ + ctx.state.index_using_precedes_table: + ctx.literal(' USING %s' % self._using) # MySQL style. + (ctx - .sql(index_name) .literal(' ON ') .sql(table_name) .literal(' ')) - if self._using is not None: - ctx.literal('USING %s ' % self._using) + + if self._using is not None and not \ + ctx.state.index_using_precedes_table: + ctx.literal('USING %s ' % self._using) # Postgres/default. 
ctx.sql(EnclosedNodeList([ SQL(expr) if isinstance(expr, basestring) else expr @@ -2950,6 +2979,7 @@ class Database(_callable_context_manager): compound_select_parentheses = CSQ_PARENTHESES_NEVER for_update = False index_schema_prefix = False + index_using_precedes_table = False limit_max = None nulls_ordering = False returning_clause = False @@ -2973,7 +3003,7 @@ def __init__(self, database, thread_safe=True, autorollback=False, self.thread_safe = thread_safe if thread_safe: self._state = _ConnectionLocal() - self._lock = threading.Lock() + self._lock = threading.RLock() else: self._state = _ConnectionState() self._lock = _NoopLock() @@ -3122,6 +3152,7 @@ def get_context_options(self): 'conflict_update': self.conflict_update, 'for_update': self.for_update, 'index_schema_prefix': self.index_schema_prefix, + 'index_using_precedes_table': self.index_using_precedes_table, 'limit_max': self.limit_max, 'nulls_ordering': self.nulls_ordering, } @@ -3912,6 +3943,7 @@ class MySQLDatabase(Database): commit_select = True compound_select_parentheses = CSQ_PARENTHESES_UNNESTED for_update = True + index_using_precedes_table = True limit_max = 2 ** 64 - 1 safe_create_index = False safe_drop_index = False @@ -4417,7 +4449,12 @@ def __init__(self, field): def __get__(self, instance, instance_type=None): if instance is not None: - return instance.__data__.get(self.field.name) + value = instance.__data__.get(self.field.name) + # Pull the object-id from the related object if it is not set. 
+ if value is None and self.field.name in instance.__rel__: + rel_obj = instance.__rel__[self.field.name] + value = getattr(rel_obj, self.field.rel_field.name) + return value return self.field def __set__(self, instance, value): @@ -4717,13 +4754,18 @@ def flag(self, value=None): else: self.__current_flag = value << 1 - class FlagDescriptor(object): + class FlagDescriptor(ColumnBase): def __init__(self, field, value): self._field = field self._value = value + super(FlagDescriptor, self).__init__() + def clear(self): + return self._field.bin_and(~self._value) + def set(self): + return self._field.bin_or(self._value) def __get__(self, instance, instance_type=None): if instance is None: - return self._field.bin_and(self._value) != 0 + return self value = getattr(instance, self._field.name) or 0 return (value & self._value) != 0 def __set__(self, instance, is_set): @@ -4735,6 +4777,8 @@ def __set__(self, instance, is_set): else: value &= ~self._value setattr(instance, self._field.name, value) + def __sql__(self, ctx): + return ctx.sql(self._field.bin_and(self._value) != 0) return FlagDescriptor(self, value) @@ -6185,12 +6229,14 @@ def __enter__(self): self._orig_database = [] for model in self.models: self._orig_database.append(model._meta.database) - model.bind(self.database, self.bind_refs, self.bind_backrefs) + model.bind(self.database, self.bind_refs, self.bind_backrefs, + _exclude=set(self.models)) return self.models def __exit__(self, exc_type, exc_val, exc_tb): for model, db in zip(self.models, self._orig_database): - model.bind(db, self.bind_refs, self.bind_backrefs) + model.bind(db, self.bind_refs, self.bind_backrefs, + _exclude=set(self.models)) class Model(with_metaclass(ModelBase, Node)): @@ -6475,7 +6521,7 @@ def save(self, force_insert=False, only=None): pk_value = self._pk else: pk_field = pk_value = None - if only: + if only is not None: field_dict = self._prune_fields(field_dict, only) elif self._meta.only_save_dirty and not force_insert: field_dict = 
self._prune_fields(field_dict, self.dirty_fields) @@ -6486,6 +6532,9 @@ def save(self, force_insert=False, only=None): self._populate_unsaved_relations(field_dict) rows = 1 + if self._meta.auto_increment and pk_value is None: + field_dict.pop(pk_field.name, None) + if pk_value is not None and not force_insert: if self._meta.composite_key: for pk_part_name in pk_field.field_names: @@ -6562,13 +6611,17 @@ def __sql__(self, ctx): converter=self._meta.primary_key.db_value)) @classmethod - def bind(cls, database, bind_refs=True, bind_backrefs=True): + def bind(cls, database, bind_refs=True, bind_backrefs=True, _exclude=None): is_different = cls._meta.database is not database cls._meta.set_database(database) if bind_refs or bind_backrefs: + if _exclude is None: + _exclude = set() G = cls._meta.model_graph(refs=bind_refs, backrefs=bind_backrefs) for _, model, is_backref in G: - model._meta.set_database(database) + if model not in _exclude: + model._meta.set_database(database) + _exclude.add(model) return is_different @classmethod @@ -7092,11 +7145,16 @@ def convert_dict_to_node(self, qdict): def filter(self, *args, **kwargs): # normalize args and kwargs into a new expression - dq_node = ColumnBase() - if args: - dq_node &= reduce(operator.and_, [a.clone() for a in args]) - if kwargs: - dq_node &= DQ(**kwargs) + if args and kwargs: + dq_node = (reduce(operator.and_, [a.clone() for a in args]) & + DQ(**kwargs)) + elif args: + dq_node = (reduce(operator.and_, [a.clone() for a in args]) & + ColumnBase()) + elif kwargs: + dq_node = DQ(**kwargs) & ColumnBase() + else: + return self.clone() # dq_node should now be an Expression, lhs = Node(), rhs = ... 
q = collections.deque([dq_node]) @@ -7122,7 +7180,8 @@ def filter(self, *args, **kwargs): else: q.append(piece) - dq_node = dq_node.rhs + if not args or not kwargs: + dq_node = dq_node.lhs query = self.clone() for field in dq_joins: @@ -7325,6 +7384,8 @@ def _initialize_columns(self): fields[idx] = node if not raw_node.is_alias(): self.columns[idx] = node.name + elif isinstance(node, ColumnBase) and raw_node._converter: + converters[idx] = raw_node._converter elif isinstance(node, Function) and node._coerce: if node._python_value is not None: converters[idx] = node._python_value diff --git a/playhouse/cockroachdb.py b/playhouse/cockroachdb.py index afb110762..1d4ff9f68 100644 --- a/playhouse/cockroachdb.py +++ b/playhouse/cockroachdb.py @@ -4,7 +4,6 @@ from peewee import * from peewee import _atomic from peewee import _manual -from peewee import _transaction from peewee import ColumnMetadata # (name, data_type, null, primary_key, table, default) from peewee import ForeignKeyMetadata # (column, dest_table, dest_column, table). from peewee import IndexMetadata @@ -18,6 +17,8 @@ ArrayField = BinaryJSONField = IntervalField = JSONField = None +NESTED_TX_MIN_VERSION = 200100 + TXN_ERR_MSG = ('CockroachDB does not support nested transactions. You may ' 'alternatively use the @transaction context-manager/decorator, ' 'which only wraps the outer-most block in transactional logic. 
' @@ -56,6 +57,7 @@ class CockroachDatabase(PostgresqlDatabase): for_update = False nulls_ordering = False + release_after_rollback = True def __init__(self, *args, **kwargs): kwargs.setdefault('user', 'root') @@ -66,7 +68,7 @@ def _set_server_version(self, conn): curs = conn.cursor() curs.execute('select version()') raw, = curs.fetchone() - match_obj = re.match('^CockroachDB.+?v(\d+)\.(\d+)\.(\d+)', raw) + match_obj = re.match(r'^CockroachDB.+?v(\d+)\.(\d+)\.(\d+)', raw) if match_obj is not None: clean = '%d%02d%02d' % tuple(int(i) for i in match_obj.groups()) self.server_version = int(clean) # 19.1.5 -> 190105. @@ -137,13 +139,14 @@ def begin(self, system_time=None, priority=None): commit=False) def atomic(self, system_time=None, priority=None): - return _crdb_atomic(self, system_time, priority) - - def transaction(self, system_time=None, priority=None): - return _transaction(self, system_time, priority) + if self.server_version < NESTED_TX_MIN_VERSION: + return _crdb_atomic(self, system_time, priority) + return super(CockroachDatabase, self).atomic(system_time, priority) def savepoint(self): - raise NotImplementedError(TXN_ERR_MSG) + if self.server_version < NESTED_TX_MIN_VERSION: + raise NotImplementedError(TXN_ERR_MSG) + return super(CockroachDatabase, self).savepoint() def retry_transaction(self, max_attempts=None, system_time=None, priority=None): diff --git a/playhouse/migrate.py b/playhouse/migrate.py index a056aec49..f4b1d5414 100644 --- a/playhouse/migrate.py +++ b/playhouse/migrate.py @@ -234,7 +234,12 @@ def alter_add_column(self, table, column_name, field): # Make field null at first. ctx = self.make_context() field_null, field.null = field.null, True - field.name = field.column_name = column_name + + # Set the field's column-name and name, if it is not set or doesn't + # match the new value. 
+ if field.column_name != column_name: + field.name = field.column_name = column_name + (self ._alter_table(ctx, table) .literal(' ADD COLUMN ') diff --git a/playhouse/shortcuts.py b/playhouse/shortcuts.py index 1772cf1d3..cefa6e437 100644 --- a/playhouse/shortcuts.py +++ b/playhouse/shortcuts.py @@ -1,5 +1,6 @@ from peewee import * from peewee import Alias +from peewee import CompoundSelectQuery from peewee import SENTINEL from peewee import callable_ @@ -226,3 +227,26 @@ def execute_sql(self, sql, params=None, commit=SENTINEL): self.connect() return super(ReconnectMixin, self).execute_sql(sql, params, commit) + + +def resolve_multimodel_query(query, key='_model_identifier'): + mapping = {} + accum = [query] + while accum: + curr = accum.pop() + if isinstance(curr, CompoundSelectQuery): + accum.extend((curr.lhs, curr.rhs)) + continue + + model_class = curr.model + name = model_class._meta.table_name + mapping[name] = model_class + curr._returning.append(Value(name).alias(key)) + + def wrapped_iterator(): + for row in query.dicts().iterator(): + identifier = row.pop(key) + model = mapping[identifier] + yield model(**row) + + return wrapped_iterator() diff --git a/setup.py b/setup.py index 00c242536..3897c9f86 100644 --- a/setup.py +++ b/setup.py @@ -72,7 +72,7 @@ def _have_sqlite_extension_support(): customize_compiler(compiler) success = False try: - compiler.link_executable( + compiler.link_shared_object( compiler.compile([src_file], output_dir=tmp_dir), bin_file, libraries=['sqlite3']) diff --git a/tests/base.py b/tests/base.py index c7b782e91..91ee8700b 100644 --- a/tests/base.py +++ b/tests/base.py @@ -14,6 +14,7 @@ from peewee import sqlite3 from playhouse.mysql_ext import MySQLConnectorDatabase from playhouse.cockroachdb import CockroachDatabase +from playhouse.cockroachdb import NESTED_TX_MIN_VERSION logger = logging.getLogger('peewee') @@ -109,6 +110,13 @@ def new_connection(**kwargs): if not IS_MYSQL_ADVANCED_FEATURES: logger.warning('MySQL too old to 
test certain advanced features.') +if IS_CRDB: + db.connect() + IS_CRDB_NESTED_TX = db.server_version >= NESTED_TX_MIN_VERSION + db.close() +else: + IS_CRDB_NESTED_TX = False + class TestModel(Model): class Meta: diff --git a/tests/cockroachdb.py b/tests/cockroachdb.py index a3bc759b7..021125bb2 100644 --- a/tests/cockroachdb.py +++ b/tests/cockroachdb.py @@ -152,7 +152,7 @@ def insert_row(db): KV.create(k='k1', v=1) with self.database.atomic(): - self.assertRaises(NotImplementedError, run_transaction, + self.assertRaises(Exception, run_transaction, self.database, insert_row) self.assertEqual(KV.select().count(), 0) diff --git a/tests/db_tests.py b/tests/db_tests.py index 24486c6a9..106426f50 100644 --- a/tests/db_tests.py +++ b/tests/db_tests.py @@ -242,6 +242,39 @@ class B(Base): db.close() alt_db.close() + def test_bind_regression(self): + class Base(Model): + class Meta: + database = None + + class A(Base): pass + class B(Base): pass + class AB(Base): + a = ForeignKeyField(A) + b = ForeignKeyField(B) + + self.assertTrue(A._meta.database is None) + + db = get_in_memory_db() + with db.bind_ctx([A, B]): + self.assertEqual(A._meta.database, db) + self.assertEqual(B._meta.database, db) + self.assertEqual(AB._meta.database, db) + + self.assertTrue(A._meta.database is None) + self.assertTrue(B._meta.database is None) + self.assertTrue(AB._meta.database is None) + + class C(Base): + a = ForeignKeyField(A) + + with db.bind_ctx([C], bind_refs=False): + self.assertEqual(C._meta.database, db) + self.assertTrue(A._meta.database is None) + + self.assertTrue(C._meta.database is None) + self.assertTrue(A._meta.database is None) + def test_batch_commit(self): class PatchCommitDatabase(SqliteDatabase): commits = 0 diff --git a/tests/fields.py b/tests/fields.py index 6f7e899e6..0817f23a6 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -243,7 +243,7 @@ def test_extract_parts(self): .tuples()) row, = query - if IS_SQLITE or IS_MYSQL or IS_CRDB: + if IS_SQLITE or IS_MYSQL: 
self.assertEqual(row, (2011, 1, 2, 11, 12, 13, 2012, 2, 3, 3, 13, 37)) else: @@ -561,6 +561,31 @@ def test_ip_field(self): class TestBitFields(ModelTestCase): requires = [Bits] + def test_bit_field_update(self): + def assertFlags(expected): + query = Bits.select().order_by(Bits.id) + self.assertEqual([b.flags for b in query], expected) + + # Bits - flags (1=sticky, 2=favorite, 4=minimized) + for i in range(1, 5): + Bits.create(flags=i) + + Bits.update(flags=Bits.flags & ~2).execute() + assertFlags([1, 0, 1, 4]) + + Bits.update(flags=Bits.flags | 2).execute() + assertFlags([3, 2, 3, 6]) + + Bits.update(flags=Bits.is_favorite.clear()).execute() + assertFlags([1, 0, 1, 4]) + + Bits.update(flags=Bits.is_favorite.set()).execute() + assertFlags([3, 2, 3, 6]) + + # Clear multiple bits in one operation. + Bits.update(flags=Bits.flags & ~(1 | 4)).execute() + assertFlags([2, 2, 2, 2]) + def test_bit_field_auto_flag(self): class Bits2(TestModel): flags = BitField() @@ -620,6 +645,13 @@ def test_bit_field(self): query = Bits.select().where((Bits.flags & 1) == 1).order_by(Bits.id) self.assertEqual([x.id for x in query], [b1.id, b3.id]) + # Test combining multiple bit expressions. 
+ query = Bits.select().where(Bits.is_sticky & Bits.is_favorite) + self.assertEqual([x.id for x in query], [b3.id]) + + query = Bits.select().where(Bits.is_sticky & ~Bits.is_favorite) + self.assertEqual([x.id for x in query], [b1.id]) + def test_bigbit_field_instance_data(self): b = Bits() values_to_set = (1, 11, 63, 31, 55, 48, 100, 99) diff --git a/tests/model_save.py b/tests/model_save.py index f0841b825..8b7f918fb 100644 --- a/tests/model_save.py +++ b/tests/model_save.py @@ -163,3 +163,10 @@ def test_save_no_data2(self): def test_save_no_data3(self): t5 = T5.create() self.assertRaises(ValueError, t5.save) + + def test_save_only_no_data(self): + t5 = T5.create(val=1) + t5.val = 2 + self.assertRaises(ValueError, t5.save, only=[]) + t5_db = T5.get(T5.id == t5.id) + self.assertEqual(t5_db.val, 1) diff --git a/tests/model_sql.py b/tests/model_sql.py index b2dfe5c10..e8b35f37e 100644 --- a/tests/model_sql.py +++ b/tests/model_sql.py @@ -206,6 +206,18 @@ def test_filter_simple(self): 'WHERE ((("t1"."id" >= ?) AND ("t1"."id" < ?)) AND ' '("t1"."username" = ?))'), [1, 5, 'huey']) + query = User.filter(~DQ(id=1), username__in=('foo', 'bar')) + self.assertSQL(query, ( + 'SELECT "t1"."id", "t1"."username" FROM "users" AS "t1" ' + 'WHERE (NOT ("t1"."id" = ?) AND ("t1"."username" IN (?, ?)))'), + [1, 'foo', 'bar']) + + query = User.filter((DQ(id=1) | DQ(id=2)), username__in=('foo', 'bar')) + self.assertSQL(query, ( + 'SELECT "t1"."id", "t1"."username" FROM "users" AS "t1" ' + 'WHERE ((("t1"."id" = ?) 
OR ("t1"."id" = ?)) AND ' + '("t1"."username" IN (?, ?)))'), [1, 2, 'foo', 'bar']) + def test_filter_expressions(self): query = User.filter( DQ(username__in=['huey', 'zaizee']) | diff --git a/tests/models.py b/tests/models.py index ba49ec948..4b694470a 100644 --- a/tests/models.py +++ b/tests/models.py @@ -631,9 +631,10 @@ def test_populate_unsaved_relations(self): self.assertTrue(user.save()) self.assertTrue(user.id is not None) - self.assertTrue(tweet.user_id is None) - self.assertTrue(tweet.save()) - self.assertEqual(tweet.user_id, user.id) + with self.assertQueryCount(1): + self.assertEqual(tweet.user_id, user.id) + self.assertTrue(tweet.save()) + self.assertEqual(tweet.user_id, user.id) tweet_db = Tweet.get(Tweet.content == 'foo') self.assertEqual(tweet_db.user.username, 'charlie') diff --git a/tests/postgres.py b/tests/postgres.py index 3fef03e69..09cc94a2d 100644 --- a/tests/postgres.py +++ b/tests/postgres.py @@ -936,3 +936,26 @@ def test_lateral_top_n(self): ('b', 'b4'), ('b', 'b7'), ('c', None)]) + + @requires_models(User, Tweet) + def test_lateral_helper(self): + self.create_data() + + subq = (Tweet + .select(Tweet.content, Tweet.timestamp) + .where(Tweet.user == User.id) + .order_by(Tweet.timestamp.desc()) + .limit(2) + .lateral()) + + query = (User + .select(User, subq.c.content) + .join(subq, on=True) + .order_by(subq.c.timestamp.desc(nulls='last'))) + with self.assertQueryCount(1): + results = [(u.username, u.tweet.content) for u in query] + self.assertEqual(results, [ + ('a', 'a10'), + ('b', 'b7'), + ('b', 'b4'), + ('a', 'a2')]) diff --git a/tests/regressions.py b/tests/regressions.py index ce9621521..cc73c531e 100644 --- a/tests/regressions.py +++ b/tests/regressions.py @@ -5,6 +5,8 @@ from peewee import * from playhouse.hybrid import * +from playhouse.migrate import migrate +from playhouse.migrate import SchemaMigrator from .base import BaseTestCase from .base import IS_MYSQL @@ -1206,8 +1208,7 @@ class BCTweet(TestModel): class 
TestBulkCreateWithFK(ModelTestCase): - requires = [BCUser, BCTweet] - + @requires_models(BCUser, BCTweet) def test_bulk_create_with_fk(self): u1 = BCUser.create(username='u1') u2 = BCUser.create(username='u2') @@ -1219,6 +1220,35 @@ def test_bulk_create_with_fk(self): self.assertEqual(BCTweet.select().where(BCTweet.user == 'u1').count(), 4) self.assertEqual(BCTweet.select().where(BCTweet.user != 'u1').count(), 0) + u = BCUser(username='u3') + t = BCTweet(user=u, content='tx') + with self.assertQueryCount(2): + BCUser.bulk_create([u]) + BCTweet.bulk_create([t]) + + with self.assertQueryCount(1): + t_db = (BCTweet + .select(BCTweet, BCUser) + .join(BCUser) + .where(BCUser.username == 'u3') + .get()) + self.assertEqual(t_db.content, 'tx') + self.assertEqual(t_db.user.username, 'u3') + + @requires_postgresql + @requires_models(User, Tweet) + def test_bulk_create_related_objects(self): + u = User(username='u1') + t = Tweet(user=u, content='t1') + with self.assertQueryCount(2): + User.bulk_create([u]) + Tweet.bulk_create([t]) + + with self.assertQueryCount(1): + t_db = Tweet.select(Tweet, User).join(User).get() + self.assertEqual(t_db.content, 't1') + self.assertEqual(t_db.user.username, 'u1') + class UUIDReg(TestModel): id = UUIDField(primary_key=True, default=uuid.uuid4) @@ -1239,3 +1269,170 @@ def test_bulk_update_uuid_pk(self): r1_db, r2_db = UUIDReg.select().order_by(UUIDReg.key) self.assertEqual(r1_db.key, 'k1-x') self.assertEqual(r2_db.key, 'k2-x') + + +class TestSaveClearingPK(ModelTestCase): + requires = [User, Tweet] + + def test_save_clear_pk(self): + u = User.create(username='u1') + t1 = Tweet.create(content='t1', user=u) + orig_id, t1.id = t1.id, None + t1.content = 't2' + t1.save() + self.assertTrue(t1.id is not None) + self.assertTrue(t1.id != orig_id) + tweets = [t.content for t in u.tweets.order_by(Tweet.id)] + self.assertEqual(tweets, ['t1', 't2']) + + +class Bits(TestModel): + b1 = BitField(default=1) + b1_1 = b1.flag(1) + b1_2 = b1.flag(2) + + b2 = 
BitField(default=0) + b2_1 = b2.flag() + b2_2 = b2.flag() + + +class TestBitFieldName(ModelTestCase): + requires = [Bits] + + def assertBits(self, bf, expected): + b1_1, b1_2, b2_1, b2_2 = expected + self.assertEqual(bf.b1_1, b1_1) + self.assertEqual(bf.b1_2, b1_2) + self.assertEqual(bf.b2_1, b2_1) + self.assertEqual(bf.b2_2, b2_2) + + def test_bit_field_name(self): + bf = Bits.create() + self.assertBits(bf, (True, False, False, False)) + + bf.b1_1 = False + bf.b1_2 = True + bf.b2_1 = True + bf.save() + self.assertBits(bf, (False, True, True, False)) + + bf = Bits.get(Bits.id == bf.id) + self.assertBits(bf, (False, True, True, False)) + + self.assertEqual(bf.b1, 2) + self.assertEqual(bf.b2, 1) + + self.assertEqual(Bits.select().where(Bits.b1_2).count(), 1) + self.assertEqual(Bits.select().where(Bits.b2_2).count(), 0) + + +class FKMA(TestModel): + name = TextField() + +class FKMB(TestModel): + name = TextField() + fkma = ForeignKeyField(FKMA, backref='fkmb_set', null=True) + + +class TestFKMigrationRegression(ModelTestCase): + requires = [FKMA, FKMB] + + def test_fk_migration(self): + migrator = SchemaMigrator.from_database(self.database) + migrate(migrator.drop_column( + FKMB._meta.table_name, + FKMB.fkma.column_name)) + + migrate(migrator.add_column( + FKMB._meta.table_name, + FKMB.fkma.column_name, + FKMB.fkma)) + + fa = FKMA.create(name='fa') + FKMB.create(name='fb', fkma=fa) + obj = FKMB.select().first() + self.assertEqual(obj.name, 'fb') + + +class ModelTypeField(CharField): + def db_value(self, value): + if value is not None: + return value._meta.name + def python_value(self, value): + if value is not None: + return {'user': User, 'tweet': Tweet}[value] + + +class MTF(TestModel): + name = TextField() + mtype = ModelTypeField() + + +class TestFieldValueRegression(ModelTestCase): + requires = [MTF] + + def test_field_value_regression(self): + u = MTF.create(name='user', mtype=User) + u_db = MTF.get() + + self.assertEqual(u_db.name, 'user') + 
self.assertTrue(u_db.mtype is User) + + +class NLM(TestModel): + a = IntegerField() + b = IntegerField() + +class TestRegressionNodeListClone(ModelTestCase): + requires = [NLM] + + def test_node_list_clone_expr(self): + expr = (NLM.a + NLM.b) + query = NLM.select(expr.alias('expr')).order_by(expr).distinct(expr) + self.assertSQL(query, ( + 'SELECT DISTINCT ON ("t1"."a" + "t1"."b") ' + '("t1"."a" + "t1"."b") AS "expr" ' + 'FROM "nlm" AS "t1" ' + 'ORDER BY ("t1"."a" + "t1"."b")'), []) + + +class LK(TestModel): + key = TextField() + +class TestLikeEscape(ModelTestCase): + requires = [LK] + + def assertNames(self, expr, expected): + query = LK.select().where(expr).order_by(LK.id) + self.assertEqual([lk.key for lk in query], expected) + + def test_like_escape(self): + names = ('foo', 'foo%', 'foo%bar', 'foo_bar', 'fooxba', 'fooba') + LK.insert_many([(n,) for n in names]).execute() + + cases = ( + (LK.key.contains('bar'), ['foo%bar', 'foo_bar']), + (LK.key.contains('%'), ['foo%', 'foo%bar']), + (LK.key.contains('_'), ['foo_bar']), + (LK.key.contains('o%b'), ['foo%bar']), + (LK.key.startswith('foo%'), ['foo%', 'foo%bar']), + (LK.key.startswith('foo_'), ['foo_bar']), + (LK.key.startswith('bar'), []), + (LK.key.endswith('ba'), ['fooxba', 'fooba']), + (LK.key.endswith('_bar'), ['foo_bar']), + (LK.key.endswith('fo'), []), + ) + for expr, expected in cases: + self.assertNames(expr, expected) + + def test_like_escape_backslash(self): + names = ('foo_bar\\baz', 'bar\\', 'fbar\\baz', 'foo_bar') + LK.insert_many([(n,) for n in names]).execute() + + cases = ( + (LK.key.contains('\\'), ['foo_bar\\baz', 'bar\\', 'fbar\\baz']), + (LK.key.contains('_bar\\'), ['foo_bar\\baz']), + (LK.key.contains('bar\\'), ['foo_bar\\baz', 'bar\\', 'fbar\\baz']), + ) + for expr, expected in cases: + self.assertNames(expr, expected) diff --git a/tests/results.py b/tests/results.py index 3059c4969..fb1618845 100644 --- a/tests/results.py +++ b/tests/results.py @@ -1,3 +1,5 @@ +import datetime + from 
peewee import * from .base import get_in_memory_db @@ -171,3 +173,38 @@ def test_dict_flattening(self): (1, 't1', 'u1'), (2, 't2', 'u1'), (3, 't3', 'u1')]) + + +class Reg(TestModel): + key = TextField() + ts = DateTimeField() + +class TestSpecifyConverter(ModelTestCase): + requires = [Reg] + + def test_specify_converter(self): + D = lambda d: datetime.datetime(2020, 1, d) + for i in range(1, 4): + Reg.create(key='k%s' % i, ts=D(i)) + + RA = Reg.alias() + subq = RA.select(RA.key, RA.ts, RA.ts.alias('aliased')) + + ra_a = subq.c.aliased.alias('aliased') + q = (Reg + .select(Reg.key, subq.c.ts.alias('ts'), + ra_a.converter(Reg.ts.python_value)) + .join(subq, on=(Reg.key == subq.c.key).alias('rsub')) + .order_by(Reg.key)) + results = [(r.key, r.ts, r.aliased) for r in q.objects()] + self.assertEqual(results, [ + ('k1', D(1), D(1)), + ('k2', D(2), D(2)), + ('k3', D(3), D(3))]) + + results2 = [(r.key, r.rsub.ts, r.rsub.aliased) + for r in q] + self.assertEqual(results, [ + ('k1', D(1), D(1)), + ('k2', D(2), D(2)), + ('k3', D(3), D(3))]) diff --git a/tests/schema.py b/tests/schema.py index c9511d403..58bca8e30 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -174,6 +174,13 @@ class Meta: ('CREATE INDEX "event_timestamp" ON "event" ' 'USING BRIN ("timestamp")', [])]) + # Check that we support MySQL-style USING clause. 
+ idx, = Event._meta.fields_to_index() + self.assertSQL(idx, ( + 'CREATE INDEX IF NOT EXISTS "event_timestamp" ' + 'USING BRIN ON "event" ("timestamp")'), [], + index_using_precedes_table=True) + def test_model_indexes_custom_tablename(self): class KV(TestModel): key = TextField() diff --git a/tests/shortcuts.py b/tests/shortcuts.py index a7f795ee7..b36ad0dae 100644 --- a/tests/shortcuts.py +++ b/tests/shortcuts.py @@ -622,3 +622,42 @@ def test_reconnect_mixin(self): curs = self.database.execute_sql(sql) self.assertEqual(curs.fetchone(), (2,)) self.assertEqual(self.database._close_counter, 1) + + +class MMA(TestModel): + key = TextField() + value = IntegerField() + +class MMB(TestModel): + key = TextField() + +class MMC(TestModel): + key = TextField() + value = IntegerField() + misc = TextField(null=True) + + +class TestResolveMultiModelQuery(ModelTestCase): + requires = [MMA, MMB, MMC] + + def test_resolve_multimodel_query(self): + MMA.insert_many([('k0', 0), ('k1', 1)]).execute() + MMB.insert_many([('k10',), ('k11',)]).execute() + MMC.insert_many([('k20', 20, 'a'), ('k21', 21, 'b')]).execute() + + mma = MMA.select(MMA.key, MMA.value) + mmb = MMB.select(MMB.key, Value(99).alias('value')) + mmc = MMC.select(MMC.key, MMC.value) + query = (mma | mmb | mmc).order_by(SQL('1')) + data = [obj for obj in resolve_multimodel_query(query)] + + expected = [ + MMA(key='k0', value=0), MMA(key='k1', value=1), + MMB(key='k10', value=99), MMB(key='k11', value=99), + MMC(key='k20', value=20), MMC(key='k21', value=21)] + self.assertEqual(len(data), len(expected)) + + for row, exp_row in zip(data, expected): + self.assertEqual(row.__class__, exp_row.__class__) + self.assertEqual(row.key, exp_row.key) + self.assertEqual(row.value, exp_row.value) diff --git a/tests/sql.py b/tests/sql.py index 0286ee338..43f186473 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -71,6 +71,14 @@ def test_select_in_list_of_values(self): 'WHERE ("t1"."name" IN (?, ?))')) self.assertEqual(sorted(params), 
['charlie', 'huey']) + query = (Person + .select() + .where(Person.id.in_(range(1, 10, 2)))) + self.assertSQL(query, ( + 'SELECT "t1"."id", "t1"."name", "t1"."dob" ' + 'FROM "person" AS "t1" ' + 'WHERE ("t1"."id" IN (?, ?, ?, ?, ?))'), [1, 3, 5, 7, 9]) + def test_select_subselect_function(self): # For functions whose only argument is a subquery, we do not need to # include additional parentheses -- in fact, some databases will report @@ -797,6 +805,36 @@ def test_where_convert_to_is_null(self): 'SELECT "t1"."id", "t1"."content", "t1"."user_id" ' 'FROM "notes" AS "t1" WHERE ("t1"."user_id" IS ?)'), [None]) + def test_like_escape(self): + T = Table('tbl', ('key',)) + def assertLike(expr, expected): + query = T.select().where(expr) + sql, params = __sql__(T.select().where(expr)) + match_obj = re.search(r'\("t1"."key" (ILIKE[^\)]+)\)', sql) + if match_obj is None: + raise AssertionError('LIKE expression not found in query.') + like, = match_obj.groups() + self.assertEqual((like, params), expected) + + cases = ( + (T.key.contains('base'), ('ILIKE ?', ['%base%'])), + (T.key.contains('x_y'), ("ILIKE ? ESCAPE ?", ['%x\\_y%', '\\'])), + (T.key.contains('__y'), ("ILIKE ? ESCAPE ?", ['%\\_\\_y%', '\\'])), + (T.key.contains('%'), ("ILIKE ? ESCAPE ?", ['%\\%%', '\\'])), + (T.key.contains('_%'), ("ILIKE ? ESCAPE ?", ['%\\_\\%%', '\\'])), + (T.key.startswith('base'), ("ILIKE ?", ['base%'])), + (T.key.startswith('x_y'), ("ILIKE ? ESCAPE ?", ['x\\_y%', '\\'])), + (T.key.startswith('x%'), ("ILIKE ? ESCAPE ?", ['x\\%%', '\\'])), + (T.key.startswith('_%'), ("ILIKE ? ESCAPE ?", ['\\_\\%%', '\\'])), + (T.key.endswith('base'), ("ILIKE ?", ['%base'])), + (T.key.endswith('x_y'), ("ILIKE ? ESCAPE ?", ['%x\\_y', '\\'])), + (T.key.endswith('x%'), ("ILIKE ? ESCAPE ?", ['%x\\%', '\\'])), + (T.key.endswith('_%'), ("ILIKE ? 
ESCAPE ?", ['%\\_\\%', '\\'])), + ) + + for expr, expected in cases: + assertLike(expr, expected) + class TestInsertQuery(BaseTestCase): def test_insert_simple(self): @@ -1729,6 +1767,17 @@ def test_parentheses(self): 'WHERE ("t2"."username" = "t1"."name"))'), []) +class TestExpressionSQL(BaseTestCase): + def test_parentheses_functions(self): + expr = (User.c.income + 100) + expr2 = expr * expr + query = User.select(fn.sum(expr), fn.avg(expr2)) + self.assertSQL(query, ( + 'SELECT sum("t1"."income" + ?), ' + 'avg(("t1"."income" + ?) * ("t1"."income" + ?)) ' + 'FROM "users" AS "t1"'), [100, 100, 100]) + + #Person = Table('person', ['id', 'name', 'dob']) class TestOnConflictSqlite(BaseTestCase): diff --git a/tests/sqliteq.py b/tests/sqliteq.py index 1c5bed513..fb67a6cfa 100644 --- a/tests/sqliteq.py +++ b/tests/sqliteq.py @@ -80,15 +80,17 @@ def test_query_execution(self): self.database.start() - users = list(qr) - huey = User.create(name='huey') - mickey = User.create(name='mickey') + try: + users = list(qr) + huey = User.create(name='huey') + mickey = User.create(name='mickey') - self.assertTrue(huey.id is not None) - self.assertTrue(mickey.id is not None) - self.assertEqual(self.database.queue_size(), 0) + self.assertTrue(huey.id is not None) + self.assertTrue(mickey.id is not None) + self.assertEqual(self.database.queue_size(), 0) - self.database.stop() + finally: + self.database.stop() def create_thread(self, fn, *args): raise NotImplementedError diff --git a/tests/transactions.py b/tests/transactions.py index 7c601b785..8ff012908 100644 --- a/tests/transactions.py +++ b/tests/transactions.py @@ -2,6 +2,7 @@ from .base import DatabaseTestCase from .base import IS_CRDB +from .base import IS_CRDB_NESTED_TX from .base import IS_SQLITE from .base import ModelTestCase from .base import db @@ -23,7 +24,8 @@ def _save(self, *vals): def requires_nested(fn): - return skip_if(IS_CRDB, 'nested transaction support is required')(fn) + return skip_if(IS_CRDB and not 
IS_CRDB_NESTED_TX, + 'nested transaction support is required')(fn) class TestTransaction(BaseTransactionTestCase): @@ -280,6 +282,15 @@ def test_session(self): self.assertTrue(db.session_rollback()) self.assertRegister([1]) + def test_session_with_closed_db(self): + db.close() + self.assertTrue(db.session_start()) + self.assertFalse(db.is_closed()) + self.assertRaises(OperationalError, db.close) + self._save(1) + self.assertTrue(db.session_rollback()) + self.assertRegister([]) + def test_session_inside_context_manager(self): with db.atomic(): self.assertTrue(db.session_start())