diff --git a/docs/conf.py b/docs/conf.py index 4bf7ca7e0..3fd57b360 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,6 +11,8 @@ # All configuration values have a default; values that are commented out # serve to show the default. +RTD_NEW_THEME = True + import sys, os # If extensions (or modules to document with autodoc) are in another directory, @@ -48,9 +50,9 @@ # built documents. # # The short X.Y version. -version = '2.1.5' +version = '2.1.6' # The full version, including alpha/beta/rc tags. -release = '2.1.5' +release = '2.1.6' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/index.rst b/docs/index.rst index ebababf78..d3a4acacf 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -45,6 +45,14 @@ Contents: peewee/playhouse peewee/upgrading +Note +---- + +Hi, I'm Charlie the author of peewee. I hope that you enjoy using this library +as much as I've enjoyed writing it. Peewee wouldn't be anywhere near as useful +without people like you, so thank you. If you find any bugs, odd behavior, or +have an idea for a new feature please don't hesitate to `open an issue `_ on GitHub. + Indices and tables ================== diff --git a/docs/peewee/api.rst b/docs/peewee/api.rst index 2339890ae..e13b206da 100644 --- a/docs/peewee/api.rst +++ b/docs/peewee/api.rst @@ -606,6 +606,23 @@ Fields related to. +.. py:class:: CompositeKey(*fields) + + Specify a composite primary key for a model. Unlike the other fields, a + composite key is defined in the model's ``Meta`` class after the fields + have been defined. It takes as parameters the string names of the fields + to use as the primary key: + + .. code-block:: python + + class BlogTagThrough(Model): + blog = ForeignKeyField(Blog, related_name='tags') + tag = ForeignKeyField(Tag, related_name='blogs') + + class Meta: + primary_key = CompositeKey('blog', 'tag') + + .. 
_query-types: Query Types @@ -1496,6 +1513,23 @@ Database and its subclasses :param Field date_field: field instance storing a datetime, date or time. :rtype: an expression object. + .. py:method:: sql_error_handler(exception, sql, params, require_commit) + + This hook is called when an error is raised executing a query, allowing + your application to inject custom error handling behavior. The default + implementation simply reraises the exception. + + .. code-block:: python + + class SqliteDatabaseCustom(SqliteDatabase): + def sql_error_handler(self, exception, sql, params, require_commit): + # Perform some custom behavior, for example close the + # connection to the database. + self.close() + + # Re-raise the exception. + raise exception + .. py:class:: SqliteDatabase(Database) diff --git a/docs/peewee/cookbook.rst b/docs/peewee/cookbook.rst index d855363a0..10bc59ea1 100644 --- a/docs/peewee/cookbook.rst +++ b/docs/peewee/cookbook.rst @@ -178,9 +178,11 @@ instantiate your database with ``threadlocals=True`` (*recommended*): concurrent_db = SqliteDatabase('stats.db', threadlocals=True) -The above implementation stores connection state in a thread local and will only -use that connection for a given thread. Pysqlite can share a connection across -threads, so if you would prefer to reuse a connection in multiple threads: +With the above peewee stores connection state in a thread local; each thread gets its +own separate connection. + +Alternatively, Python sqlite3 module can share a connection across different threads, +but you have to disable runtime checks to reuse the single connection: .. 
code-block:: python diff --git a/docs/peewee/playhouse.rst b/docs/peewee/playhouse.rst index a65446315..399b17ba7 100644 --- a/docs/peewee/playhouse.rst +++ b/docs/peewee/playhouse.rst @@ -16,6 +16,7 @@ specific functionality: Modules which expose higher-level python constructs: +* :ref:`djpeewee` * :ref:`gfk` * :ref:`kv` * :ref:`proxy` @@ -26,6 +27,7 @@ As well as tools for working with databases: * :ref:`pwiz` * :ref:`migrate` +* :ref:`csv_loader` * :ref:`test_utils` @@ -764,6 +766,115 @@ sqlite_ext API notes The above code is exactly what the :py:meth:`match` function provides. +.. _djpeewee: + +Django Integration +------------------ + +The Django ORM provides a very high-level abstraction over SQL and as a consequence is in some ways +`limited in terms of flexibility or expressiveness `_. I +wrote a `blog post `_ +describing my search for a "missing link" between Django's ORM and the SQL it +generates, concluding that no such layer exists. The ``djpeewee`` module attempts +to provide an easy-to-use, structured layer for generating SQL queries for use +with Django's ORM. + +A couple use-cases might be: + +* Joining on fields that are not related by foreign key (for example UUID fields). +* Performing aggregate queries on calculated values. +* Features that Django does not support such as ``CASE`` statements. +* Utilizing SQL functions that Django does not support, such as ``SUBSTR``. +* Replacing nearly-identical SQL queries with reusable, composable data-structures. + +Below is an example of how you might use this: + +.. code-block:: python + + # Django model. + class Event(models.Model): + start_time = models.DateTimeField() + end_time = models.DateTimeField() + title = models.CharField(max_length=255) + + # Suppose we want to find all events that are longer than an hour. Django + # does not support this, but we can use peewee. 
+ from playhouse.djpeewee import translate + P = translate(Event) + query = (P.Event + .select() + .where( + (P.Event.end_time - P.Event.start_time) > timedelta(hours=1))) + + # Now feed our peewee query into Django's `raw()` method: + sql, params = query.sql() + Event.objects.raw(sql, params) + +Foreign keys and Many-to-many relationships +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :py:func:`translate` function will recursively traverse the graph of models +and return a dictionary populated with everything it finds. Back-references are +not searched by default, but can be included by specifying ``backrefs=True``. + +Example: + +.. code-block:: pycon + + >>> from django.contrib.auth.models import User, Group + >>> from playhouse.djpeewee import translate + >>> translate(User, Group) + {'ContentType': peewee.ContentType, + 'Group': peewee.Group, + 'Group_permissions': peewee.Group_permissions, + 'Permission': peewee.Permission, + 'User': peewee.User, + 'User_groups': peewee.User_groups, + 'User_user_permissions': peewee.User_user_permissions} + +As you can see in the example above, although only `User` and `Group` were passed +in to :py:func:`translate`, several other models which are related by foreign key +were also created. Additionally, the many-to-many "through" tables were created +as separate models since peewee does not abstract away these types of relationships. + +Using the above models it is possible to construct joins. The following example +will get all users who belong to a group that starts with the letter "A": + +.. code-block:: pycon + + >>> P = translate(User, Group) + >>> query = P.User.select().join(P.User_groups).join(P.Group).where( + ... fn.Lower(fn.Substr(P.Group.name, 1, 1)) == 'a') + >>> sql, params = query.sql() + >>> print sql # formatted for legibility + SELECT t1."id", t1."password", ... 
+ FROM "auth_user" AS t1 + INNER JOIN "auth_user_groups" AS t2 ON (t1."id" = t2."user_id") + INNER JOIN "auth_group" AS t3 ON (t2."group_id" = t3."id") + WHERE (Lower(Substr(t3."name", %s, %s)) = %s) + +djpeewee API +^^^^^^^^^^^^ + +.. py:function:: translate(*models, **options) + + Translate the given Django models into roughly equivalent peewee models + suitable for use constructing queries. Foreign keys and many-to-many relationships + will be followed and models generated, although back references are not traversed. + + :param models: One or more Django model classes. + :param options: A dictionary of options, see note below. + :returns: A dict-like object containing the generated models, but which supports + dotted-name style lookups. + + The following are valid options: + + * ``recurse``: Follow foreign keys and many to many (default: ``True``). + * ``max_depth``: Maximum depth to recurse (default: ``None``, unlimited). + * ``backrefs``: Follow backrefs (default: ``False``). + * ``exclude``: A list of models to exclude. + + .. _gfk: Generic foreign keys @@ -1310,6 +1421,87 @@ Renaming a table migrator.rename_table(Story, 'stories') +.. _csv_loader: + +CSV Loader +---------- + +This module contains helpers for loading CSV data into a database. CSV files can +be introspected to generate an appropriate model class for working with the data. +This makes it really easy to explore the data in a CSV file using Peewee and SQL. + +Here is how you would load a CSV file into an in-memory SQLite database. The +call to :py:func:`load_csv` returns a :py:class:`Model` instance suitable for +working with the CSV data: + +.. code-block:: python + + from playhouse.csv_loader import load_csv + db = SqliteDatabase(':memory:') + ZipToTZ = load_csv(db, 'zip_to_tz.csv') + +Now we can run queries using the new model. + +.. code-block:: pycon + + # Get the timezone for a zipcode. + >>> ZipToTZ.get(ZipToTZ.zip == 66047).timezone + 'US/Central' + + # Get all the zipcodes for my town. 
+ >>> [row.zip for row in ZipToTZ.select().where( + ... (ZipToTZ.city == 'Lawrence') & (ZipToTZ.state == 'KS'))] + [66044, 66045, 66046, 66047, 66049] + +For more information and examples check out this `blog post `_. + + +CSV Loader API +^^^^^^^^^^^^^^ + +.. py:function:: load_csv(db_or_model, filename[, fields=None[, field_names=None[, has_header=True[, sample_size=10[, converter=None[, db_table=None[, **reader_kwargs]]]]]]]) + + Load a CSV file into the provided database or model class, returning a + :py:class:`Model` suitable for working with the CSV data. + + :param db_or_model: Either a :py:class:`Database` instance or a :py:class:`Model` class. If a model is not provided, one will be automatically generated for you. + :param str filename: Path of CSV file to load. + :param list fields: A list of :py:class:`Field` instances mapping to each column in the CSV. This allows you to manually specify the column types. If not provided, and a model is not provided, the field types will be determined automatically. + :param list field_names: A list of strings to use as field names for each column in the CSV. If not provided, and a model is not provided, the field names will be determined by looking at the header row of the file. If no header exists, then the fields will be given generic names. + :param bool has_header: Whether the first row is a header. + :param int sample_size: Number of rows to look at when introspecting data types. If set to ``0``, then a generic field type will be used for all fields. + :param RowConverter converter: a :py:class:`RowConverter` instance to use for introspecting the CSV. If not provided, one will be created. + :param str db_table: The name of the database table to load data into. If this value is not provided, it will be determined using the filename of the CSV file. If a model is provided, this value is ignored. 
+ :param reader_kwargs: Arbitrary keyword arguments to pass to the ``csv.reader`` object, such as the dialect, separator, etc. + :rtype: A :py:class:`Model` suitable for querying the CSV data. + + Basic example -- field names and types will be introspected: + + .. code-block:: python + + from playhouse.csv_loader import * + db = SqliteDatabase(':memory:') + User = load_csv(db, 'users.csv') + + Using a pre-defined model: + + .. code-block:: python + + class ZipToTZ(Model): + zip = IntegerField() + timezone = CharField() + + load_csv(ZipToTZ, 'zip_to_tz.csv') + + Specifying fields: + + .. code-block:: python + + fields = [DecimalField(), IntegerField(), IntegerField(), DateField()] + field_names = ['amount', 'from_acct', 'to_acct', 'timestamp'] + Payments = load_csv(db, 'payments.csv', fields=fields, field_names=field_names, has_header=False) + + .. _test_utils: Test Utils diff --git a/peewee.py b/peewee.py index 7f6859b5f..2386dabdd 100644 --- a/peewee.py +++ b/peewee.py @@ -7,7 +7,6 @@ # ///' # // # ' -from __future__ import with_statement import datetime import decimal import logging @@ -21,6 +20,7 @@ from inspect import isclass __all__ = [ + 'BareField', 'BigIntegerField', 'BlobField', 'BooleanField', @@ -123,9 +123,7 @@ def _sqlite_date_part(lookup_type, datetime_string): return getattr(dt, lookup_type) if psycopg2: - import psycopg2.extensions - psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) - psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY) + from psycopg2 import extensions as pg_extensions # Peewee logger = logging.getLogger('peewee') @@ -433,6 +431,10 @@ def python_value(self, value): def __hash__(self): return hash(self.name + '.' 
+ self.model_class.__name__) +class BareField(Field): + db_field = 'bare' + template = '' + class IntegerField(Field): db_field = 'int' coerce = int @@ -632,7 +634,10 @@ def __set__(self, instance, value): instance._data[self.att_name] = value.get_id() instance._obj_cache[self.att_name] = value else: + orig_value = instance._data.get(self.att_name) instance._data[self.att_name] = value + if orig_value != value and self.att_name in instance._obj_cache: + del instance._obj_cache[self.att_name] class ReverseRelationDescriptor(object): def __init__(self, field): @@ -736,6 +741,7 @@ def __set__(self, instance, value): class QueryCompiler(object): field_map = { + 'bare': '', 'bigint': 'BIGINT', 'blob': 'BLOB', 'bool': 'SMALLINT', @@ -1651,9 +1657,10 @@ def count(self): # defaults to a count() of the primary key return self.aggregate(convert=False) or 0 - def wrapped_count(self): + def wrapped_count(self, clear_limit=True): clone = self.order_by() - clone._limit = clone._offset = None + if clear_limit: + clone._limit = clone._offset = None sql, params = clone.sql() wrapped = 'SELECT COUNT(1) FROM (%s) AS wrapped_select' % sql @@ -2005,10 +2012,16 @@ class PostgresqlDatabase(Database): reserved_tables = ['user'] sequences = True + register_unicode = True + def _connect(self, database, **kwargs): if not psycopg2: raise ImproperlyConfigured('psycopg2 must be installed.') - return psycopg2.connect(database=database, **kwargs) + conn = psycopg2.connect(database=database, **kwargs) + if self.register_unicode: + pg_extensions.register_type(pg_extensions.UNICODE, conn) + pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn) + return conn def last_insert_id(self, cursor, model): seq = model._meta.primary_key.sequence diff --git a/playhouse/csv_loader.py b/playhouse/csv_loader.py new file mode 100644 index 000000000..ab0ce4f09 --- /dev/null +++ b/playhouse/csv_loader.py @@ -0,0 +1,264 @@ +""" +Peewee helper for loading CSV data into a database. 
+ +Load the users CSV file into the database and return a Model for accessing +the data: + + from playhouse.csv_loader import load_csv + db = SqliteDatabase(':memory:') + User = load_csv(db, 'users.csv') + +Provide explicit field types and/or field names: + + fields = [IntegerField(), IntegerField(), DateTimeField(), DecimalField()] + field_names = ['from_acct', 'to_acct', 'timestamp', 'amount'] + Payments = load_csv(db, 'payments.csv', fields, field_names) +""" +import csv +import datetime +import os +import re +from collections import OrderedDict +from contextlib import contextmanager + +from peewee import * +from peewee import Database + + +class _CSVReader(object): + @contextmanager + def get_reader(self, filename, **reader_kwargs): + fh = open(filename, 'r') + reader = csv.reader(fh, **reader_kwargs) + yield reader + fh.close() + +class RowConverter(_CSVReader): + """ + Simple introspection utility to convert a CSV file into a list of headers + and column types. + + :param database: a peewee Database object. + :param bool has_header: whether the first row of CSV is a header row. 
+ :param int sample_size: number of rows to introspect + """ + date_formats = [ + '%Y-%m-%d', + '%m/%d/%Y'] + + datetime_formats = [ + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d %H:%M:%S.%f'] + + def __init__(self, database, has_header=True, sample_size=10): + self.database = database + self.has_header = has_header + self.sample_size = sample_size + + def matches_date(self, value, formats): + for fmt in formats: + try: + datetime.datetime.strptime(value, fmt) + except ValueError: + pass + else: + return True + + def field(field_class, **field_kwargs): + def decorator(fn): + fn.field = lambda: field_class(**field_kwargs) + return fn + return decorator + + @field(IntegerField, default=0) + def is_integer(self, value): + return value.isdigit() + + @field(FloatField, default=0) + def is_float(self, value): + try: + float(value) + except (ValueError, TypeError): + pass + else: + return True + + @field(DateTimeField, null=True) + def is_datetime(self, value): + return self.matches_date(value, self.datetime_formats) + + @field(DateField, null=True) + def is_date(self, value): + return self.matches_date(value, self.date_formats) + + @field(BareField, default='') + def default(self, value): + return True + + def extract_rows(self, filename, **reader_kwargs): + """ + Extract `self.sample_size` rows from the CSV file and analyze their + data-types. + + :param str filename: A string filename. + :param reader_kwargs: Arbitrary parameters to pass to the CSV reader. + :returns: A 2-tuple containing a list of headers and list of rows + read from the CSV file. 
+ """ + rows = [] + rows_to_read = self.sample_size + with self.get_reader(filename, **reader_kwargs) as reader: + if self.has_header: + rows_to_read += 1 + for i, row in enumerate(reader): + rows.append(row) + if i == self.sample_size: + break + if self.has_header: + header, rows = rows[0], rows[1:] + else: + header = ['field_%d' % i for i in range(len(rows[0]))] + return header, rows + + def get_checks(self): + """Return a list of functions to use when testing values.""" + return [ + self.is_date, + self.is_datetime, + self.is_integer, + self.is_float, + self.default] + + def analyze(self, rows): + """ + Analyze the given rows and try to determine the type of value stored. + + :param list rows: A list-of-lists containing one or more rows from a + csv file. + :returns: A list of peewee Field objects for each column in the CSV. + """ + transposed = zip(*rows) + checks = self.get_checks() + column_types = [] + for i, column in enumerate(transposed): + # Remove any empty values. + col_vals = [val for val in column if val != ''] + for check in checks: + results = set(check(val) for val in col_vals) + if all(results): + column_types.append(check.field()) + break + + return column_types + + +class Loader(_CSVReader): + """ + Load the contents of a CSV file into a database and return a model class + suitable for working with the CSV data. + + :param db_or_model: a peewee Database instance or a Model class. + :param str filename: the filename of the CSV file. + :param list fields: A list of peewee Field() instances appropriate to + the values in the CSV file. + :param list field_names: A list of names to use for the fields. + :param bool has_header: Whether the first row of the CSV file is a header. + :param int sample_size: Number of rows to introspect if fields are not + defined. + :param converter: A RowConverter instance to use. + :param str db_table: Name of table to store data in (if not specified, the + table name will be derived from the CSV filename). 
+ :param reader_kwargs: Arbitrary arguments to pass to the CSV reader. + """ + def __init__(self, db_or_model, filename, fields=None, field_names=None, + has_header=True, sample_size=10, converter=None, + db_table=None, **reader_kwargs): + self.filename = filename + self.fields = fields + self.field_names = field_names + self.has_header = has_header + self.sample_size = sample_size + self.converter = converter + self.reader_kwargs = reader_kwargs + + if isinstance(db_or_model, Database): + self.database = db_or_model + self.model = None + self.db_table = (db_table or + os.path.splitext(os.path.basename(filename))[0]) + else: + self.model = db_or_model + self.database = self.model._meta.database + self.db_table = self.model._meta.db_table + self.fields = self.model._meta.get_fields() + self.field_names = self.model._meta.get_field_names() + # If using an auto-incrementing primary key, ignore it. + if self.model._meta.auto_increment: + self.fields = self.fields[1:] + self.field_names = self.field_names[1:] + + def clean_field_name(self, s): + return re.sub('[^a-z0-9]+', '_', s.lower()) + + def get_converter(self): + return self.converter or RowConverter( + self.database, + has_header=self.has_header, + sample_size=self.sample_size) + + def analyze_csv(self): + converter = self.get_converter() + header, rows = converter.extract_rows( + self.filename, + **self.reader_kwargs) + if rows: + self.fields = converter.analyze(rows) + else: + self.fields = [converter.default.field() for _ in header] + if not self.field_names: + self.field_names = map(self.clean_field_name, header) + + def get_model_class(self, field_names, fields): + if self.model: + return self.model + attrs = dict(zip(field_names, fields)) + attrs['_auto_pk'] = PrimaryKeyField() + klass = type(self.db_table.title(), (Model,), attrs) + klass._meta.database = self.database + klass._meta.db_table = self.db_table + return klass + + def load(self): + if not self.fields: + self.analyze_csv() + if not 
self.field_names and not self.has_header: + self.field_names = [ + 'field_%d' % i for i in range(len(self.fields))] + + with self.get_reader(self.filename, **self.reader_kwargs) as reader: + if not self.field_names: + self.field_names = map(self.clean_field_name, reader.next()) + elif self.has_header: + reader.next() + + ModelClass = self.get_model_class(self.field_names, self.fields) + + with self.database.transaction(): + ModelClass.create_table(True) + for row in reader: + insert = {} + for field_name, value in zip(self.field_names, row): + if value: + insert[field_name] = value.decode('utf-8') + if insert: + ModelClass.insert(**insert).execute() + + return ModelClass + +def load_csv(db_or_model, filename, fields=None, field_names=None, + has_header=True, sample_size=10, converter=None, + db_table=None, **reader_kwargs): + loader = Loader(db_or_model, filename, fields, field_names, has_header, + sample_size, converter, db_table, **reader_kwargs) + return loader.load() +load_csv.__doc__ = Loader.__doc__ diff --git a/playhouse/djpeewee.py b/playhouse/djpeewee.py new file mode 100644 index 000000000..24d2d309b --- /dev/null +++ b/playhouse/djpeewee.py @@ -0,0 +1,197 @@ +""" +Simple translation of Django model classes to peewee model classes. +""" +from functools import partial +import logging + +from peewee import * + +logger = logging.getLogger('peewee.playhouse.djpeewee') + +class AttrDict(dict): + def __getattr__(self, attr): + return self[attr] + +class DjangoTranslator(object): + def __init__(self): + self._field_map = self.get_django_field_map() + + def get_django_field_map(self): + from django.db.models import fields as djf + return [ + (djf.AutoField, PrimaryKeyField), + (djf.BigIntegerField, BigIntegerField), + #(djf.BinaryField, BlobField), + (djf.BooleanField, BooleanField), + (djf.CharField, CharField), + (djf.DateTimeField, DateTimeField), # Extends DateField. 
+ (djf.DateField, DateField), + (djf.DecimalField, DecimalField), + (djf.FilePathField, CharField), + (djf.FloatField, FloatField), + (djf.IntegerField, IntegerField), + (djf.NullBooleanField, partial(BooleanField, null=True)), + (djf.TextField, TextField), + (djf.TimeField, TimeField), + (djf.related.ForeignKey, ForeignKeyField), + ] + + def convert_field(self, field): + converted = None + for django_field, peewee_field in self._field_map: + if isinstance(field, django_field): + converted = peewee_field + break + return converted + + def _translate_model(self, + model, + mapping, + max_depth=None, + backrefs=False, + exclude=None): + if exclude and model in exclude: + return + + if max_depth is None: + max_depth = -1 + + from django.db.models import fields as djf + options = model._meta + if mapping.get(options.object_name): + return + mapping[options.object_name] = None + + attrs = {} + # Sort fields such that nullable fields appear last. + field_key = lambda field: (field.null and 1 or 0, field) + for model_field in sorted(options.fields, key=field_key): + # Get peewee equivalent for this field type. + converted = self.convert_field(model_field) + + # Special-case ForeignKey fields. + if converted is ForeignKeyField: + if max_depth != 0: + related_model = model_field.rel.to + model_name = related_model._meta.object_name + # If we haven't processed the related model yet, do so now. + if model_name not in mapping: + mapping[model_name] = None # Avoid endless recursion. + self._translate_model( + related_model, + mapping, + max_depth=max_depth - 1, + backrefs=backrefs, + exclude=exclude) + if mapping[model_name] is None: + # Cycle detected, put an integer field here. 
+ logger.warn('Cycle detected: %s: %s', + model_field.name, model_name) + attrs[model_field.name] = IntegerField( + db_column=model_field.get_attname()) + else: + related_name = (model_field.rel.related_name or + model_field.related_query_name()) + if related_name.endswith('+'): + related_name = '__%s:%s:%s' % ( + options, + model_field.name, + related_name.strip('+')) + + attrs[model_field.name] = ForeignKeyField( + mapping[model_name], + related_name=related_name) + + else: + attrs[model_field.name] = IntegerField( + db_column=model_field.get_attname()) + + elif converted: + attrs[model_field.name] = converted() + + klass = type(options.object_name, (Model,), attrs) + klass._meta.db_table = options.db_table + klass._meta.database.interpolation = '%s' + mapping[options.object_name] = klass + + if backrefs: + # Follow back-references for foreign keys. + for rel_obj in options.get_all_related_objects(): + if rel_obj.model._meta.object_name in mapping: + continue + self._translate_model( + rel_obj.model, + mapping, + max_depth=max_depth - 1, + backrefs=backrefs, + exclude=exclude) + + # Load up many-to-many relationships. + for many_to_many in options.many_to_many: + if not isinstance(many_to_many, djf.related.ManyToManyField): + continue + self._translate_model( + many_to_many.rel.through, + mapping, + max_depth=max_depth, # Do not decrement. + backrefs=backrefs, + exclude=exclude) + + + def translate_models(self, *models, **options): + """ + Generate a group of peewee models analogous to the provided Django models + for the purposes of creating queries. + + :param model: A Django model class. + :param options: A dictionary of options, see note below. + :returns: A dictionary mapping model names to peewee model classes. 
+ :rtype: dict + + Recognized options: + `recurse`: Follow foreign keys (default: True) + `max_depth`: Max depth to recurse (default: None, unlimited) + `backrefs`: Follow backrefs (default: False) + `exclude`: A list of models to exclude + + Example:: + + # Map Django models to peewee models. Foreign keys and M2M will be + # traversed as well. + peewee = translate(Account) + + # Generate query using peewee. + PUser = peewee['User'] + PAccount = peewee['Account'] + query = PUser.select().join(PAccount).where(PAccount.acct_type == 'foo') + + # Django raw query. + users = User.objects.raw(*query.sql()) + """ + mapping = AttrDict() + recurse = options.get('recurse', True) + max_depth = options.get('max_depth', None) + backrefs = options.get('backrefs', False) + exclude = options.get('exclude', None) + if not recurse and max_depth: + raise ValueError('Error, you cannot specify a max_depth when ' + 'recurse=False.') + elif not recurse: + max_depth = 0 + elif recurse and max_depth is None: + max_depth = -1 + + for model in models: + self._translate_model( + model, + mapping, + max_depth=max_depth, + backrefs=backrefs, + exclude=exclude) + return mapping + +try: + import django + translate = DjangoTranslator().translate_models +except ImportError: + pass diff --git a/playhouse/tests_csv_loader.py b/playhouse/tests_csv_loader.py new file mode 100644 index 000000000..75aa63973 --- /dev/null +++ b/playhouse/tests_csv_loader.py @@ -0,0 +1,114 @@ +import csv +import unittest +from contextlib import contextmanager +from datetime import date +from StringIO import StringIO +from textwrap import dedent + +from playhouse.csv_loader import * + + +class TestRowConverter(RowConverter): + @contextmanager + def get_reader(self, csv_data, **reader_kwargs): + reader = csv.reader(StringIO(csv_data), **reader_kwargs) + yield reader + +class TestLoader(Loader): + @contextmanager + def get_reader(self, csv_data, **reader_kwargs): + reader = csv.reader(StringIO(csv_data), **reader_kwargs) + 
yield reader + + def get_converter(self): + return self.converter or TestRowConverter( + self.database, + has_header=self.has_header, + sample_size=self.sample_size) + +db = SqliteDatabase(':memory:') + +class TestCSVConversion(unittest.TestCase): + header = 'id,name,dob,salary,is_admin' + simple = '10,"F1 L1",1983-01-01,10000,t' + float_sal = '20,"F2 L2",1983-01-02,20000.5,f' + only_name = ',"F3 L3",,,' + mismatch = 'foo,F4 L4,dob,sal,x' + + def setUp(self): + db.execute_sql('drop table if exists csv_test;') + + def build_csv(self, *lines): + return '\r\n'.join(lines) + + def load(self, *lines, **loader_kwargs): + csv = self.build_csv(*lines) + loader_kwargs['filename'] = csv + loader_kwargs.setdefault('db_table', 'csv_test') + loader_kwargs.setdefault('db_or_model', db) + return TestLoader(**loader_kwargs).load() + + def assertData(self, ModelClass, expected): + name_field = ModelClass._meta.get_fields()[2] + query = ModelClass.select().order_by(name_field).tuples() + self.assertEqual([row[1:] for row in query], expected) + + def test_defaults(self): + ModelClass = self.load( + self.header, + self.simple, + self.float_sal, + self.only_name) + self.assertData(ModelClass, [ + (10, 'F1 L1', date(1983, 1, 1), 10000., 't'), + (20, 'F2 L2', date(1983, 1, 2), 20000.5, 'f'), + (0, 'F3 L3', None, 0., ''), + ]) + + def test_no_header(self): + ModelClass = self.load( + self.simple, + self.float_sal, + field_names=['f1', 'f2', 'f3', 'f4', 'f5'], + has_header=False) + self.assertEqual(ModelClass._meta.get_field_names(), [ + '_auto_pk', 'f1', 'f2', 'f3', 'f4', 'f5']) + self.assertData(ModelClass, [ + (10, 'F1 L1', date(1983, 1, 1), 10000., 't'), + (20, 'F2 L2', date(1983, 1, 2), 20000.5, 'f')]) + + def test_no_header_no_fieldnames(self): + ModelClass = self.load( + self.simple, + self.float_sal, + has_header=False) + self.assertEqual(ModelClass._meta.get_field_names(), [ + '_auto_pk', 'field_0', 'field_1', 'field_2', 'field_3', 'field_4']) + + def test_mismatch_types(self): + 
ModelClass = self.load( + self.header, + self.simple, + self.mismatch) + self.assertData(ModelClass, [ + ('10', 'F1 L1', '1983-01-01', '10000', 't'), + ('foo', 'F4 L4', 'dob', 'sal', 'x')]) + + def test_fields(self): + fields = [ + IntegerField(), + CharField(), + DateField(), + FloatField(), + CharField()] + ModelClass = self.load( + self.header, + self.simple, + self.float_sal, + fields=fields) + self.assertEqual( + map(type, fields), + map(type, ModelClass._meta.get_fields()[1:])) + self.assertData(ModelClass, [ + (10, 'F1 L1', date(1983, 1, 1), 10000., 't'), + (20, 'F2 L2', date(1983, 1, 2), 20000.5, 'f')]) diff --git a/playhouse/tests_djpeewee.py b/playhouse/tests_djpeewee.py new file mode 100644 index 000000000..58c531c55 --- /dev/null +++ b/playhouse/tests_djpeewee.py @@ -0,0 +1,263 @@ +from datetime import timedelta +import unittest + +from peewee import * +from peewee import print_ +try: + import django +except ImportError: + django = None + + +if django is not None: + from django.conf import settings + settings.configure( + DATABASES={ + 'default': { + 'engine': 'django.db.backends.sqlite3', + 'name': ':memory:'}}, + ) + from django.db import models + from playhouse.djpeewee import translate + + # Django model definitions. 
+ class Simple(models.Model): + char_field = models.CharField(max_length=1) + int_field = models.IntegerField() + + class User(models.Model): + username = models.CharField(max_length=255) + + class Meta: + db_table = 'user_tbl' + + class Post(models.Model): + author = models.ForeignKey(User, related_name='posts') + content = models.TextField() + + class Comment(models.Model): + post = models.ForeignKey(Post, related_name='comments') + commenter = models.ForeignKey(User, related_name='comments') + comment = models.TextField() + + class Tag(models.Model): + tag = models.CharField() + posts = models.ManyToManyField(Post) + + class Event(models.Model): + start_time = models.DateTimeField() + end_time = models.DateTimeField() + title = models.CharField() + + class Meta: + db_table = 'events_tbl' + + class Category(models.Model): + parent = models.ForeignKey('self', null=True) + + class A(models.Model): + a_field = models.IntegerField() + b = models.ForeignKey('B', null=True, related_name='as') + + class B(models.Model): + a = models.ForeignKey(A, related_name='bs') + + class C(models.Model): + b = models.ForeignKey(B, related_name='cs') + + class Parent(models.Model): + pass + + class Child(Parent): + pass + + + class TestDjPeewee(unittest.TestCase): + def assertFields(self, model, expected): + self.assertEqual(len(model._meta.fields), len(expected)) + zipped = zip(model._meta.get_fields(), expected) + for (model_field, (name, field_type)) in zipped: + self.assertEqual(model_field.name, name) + self.assertTrue(type(model_field) is field_type) + + def test_simple(self): + P = translate(Simple) + self.assertEqual(P.keys(), ['Simple']) + self.assertFields(P['Simple'], [ + ('id', PrimaryKeyField), + ('char_field', CharField), + ('int_field', IntegerField), + ]) + + def test_graph(self): + P = translate(User, Tag, Comment) + self.assertEqual(sorted(P.keys()), [ + 'Comment', + 'Post', + 'Tag', + 'Tag_posts', + 'User']) + + # Test the models that were found. 
+ user = P['User'] + self.assertFields(user, [ + ('id', PrimaryKeyField), + ('username', CharField)]) + self.assertEqual(user.posts.rel_model, P['Post']) + self.assertEqual(user.comments.rel_model, P['Comment']) + + post = P['Post'] + self.assertFields(post, [ + ('id', PrimaryKeyField), + ('author', ForeignKeyField), + ('content', TextField)]) + self.assertEqual(post.comments.rel_model, P['Comment']) + + comment = P['Comment'] + self.assertFields(comment, [ + ('id', PrimaryKeyField), + ('post', ForeignKeyField), + ('commenter', ForeignKeyField), + ('comment', TextField)]) + + tag = P['Tag'] + self.assertFields(tag, [ + ('id', PrimaryKeyField), + ('tag', CharField)]) + + thru = P['Tag_posts'] + self.assertFields(thru, [ + ('id', PrimaryKeyField), + ('tag', ForeignKeyField), + ('post', ForeignKeyField)]) + + def test_fk_query(self): + trans = translate(User, Post, Comment, Tag) + U = trans['User'] + P = trans['Post'] + C = trans['Comment'] + + query = (U.select() + .join(P) + .join(C) + .where(C.comment == 'test')) + sql, params = query.sql() + self.assertEqual( + sql, + 'SELECT t1."id", t1."username" FROM "user_tbl" AS t1 ' + 'INNER JOIN "playhouse_post" AS t2 ' + 'ON (t1."id" = t2."author_id") ' + 'INNER JOIN "playhouse_comment" AS t3 ' + 'ON (t2."id" = t3."post_id") WHERE (t3."comment" = %s)') + self.assertEqual(params, ['test']) + + def test_m2m_query(self): + trans = translate(Post, Tag) + P = trans['Post'] + U = trans['User'] + T = trans['Tag'] + TP = trans['Tag_posts'] + + query = (P.select() + .join(TP) + .join(T) + .where(T.tag == 'test')) + sql, params = query.sql() + self.assertEqual( + sql, + 'SELECT t1."id", t1."author_id", t1."content" ' + 'FROM "playhouse_post" AS t1 ' + 'INNER JOIN "playhouse_tag_posts" AS t2 ' + 'ON (t1."id" = t2."post_id") ' + 'INNER JOIN "playhouse_tag" AS t3 ' + 'ON (t2."tag_id" = t3."id") WHERE (t3."tag" = %s)') + self.assertEqual(params, ['test']) + + def test_docs_example(self): + # The docs don't lie. 
+ PEvent = translate(Event)['Event'] + hour = timedelta(hours=1) + query = (PEvent + .select() + .where( + (PEvent.end_time - PEvent.start_time) > hour)) + sql, params = query.sql() + self.assertEqual( + sql, + 'SELECT t1."id", t1."start_time", t1."end_time", t1."title" ' + 'FROM "events_tbl" AS t1 ' + 'WHERE ((t1."end_time" - t1."start_time") > %s)') + self.assertEqual(params, [hour]) + + def test_self_referential(self): + trans = translate(Category) + self.assertFields(trans['Category'], [ + ('id', PrimaryKeyField), + ('parent', IntegerField)]) + + def test_cycle(self): + trans = translate(A) + self.assertFields(trans['A'], [ + ('id', PrimaryKeyField), + ('a_field', IntegerField), + ('b', ForeignKeyField)]) + self.assertFields(trans['B'], [ + ('id', PrimaryKeyField), + ('a', IntegerField)]) + + trans = translate(B) + self.assertFields(trans['A'], [ + ('id', PrimaryKeyField), + ('a_field', IntegerField), + ('b', IntegerField)]) + self.assertFields(trans['B'], [ + ('id', PrimaryKeyField), + ('a', ForeignKeyField)]) + + def test_max_depth(self): + trans = translate(C, max_depth=1) + self.assertFields(trans['C'], [ + ('id', PrimaryKeyField), + ('b', ForeignKeyField)]) + self.assertFields(trans['B'], [ + ('id', PrimaryKeyField), + ('a', IntegerField)]) + + def test_exclude(self): + trans = translate(Comment, exclude=(User,)) + self.assertFields(trans['Post'], [ + ('id', PrimaryKeyField), + ('author', IntegerField), + ('content', TextField)]) + self.assertEqual( + trans['Post'].comments.rel_model, + trans['Comment']) + + self.assertFields(trans['Comment'], [ + ('id', PrimaryKeyField), + ('post', ForeignKeyField), + ('commenter', IntegerField), + ('comment', TextField)]) + + def test_backrefs(self): + trans = translate(User, backrefs=True) + self.assertEqual(sorted(trans.keys()), [ + 'Comment', + 'Post', + 'User']) + + def test_inheritance(self): + trans = translate(Parent) + self.assertEqual(trans.keys(), ['Parent']) + self.assertFields(trans['Parent'], [ + ('id', 
PrimaryKeyField),]) + + trans = translate(Child) + self.assertEqual(sorted(trans.keys()), ['Child', 'Parent']) + self.assertFields(trans['Child'], [ + ('id', PrimaryKeyField), + ('parent_ptr', ForeignKeyField)]) + + +else: + print_('Skipping djpeewee tests, Django not found.') diff --git a/pwiz.py b/pwiz.py index a86a25a0e..9251f576a 100755 --- a/pwiz.py +++ b/pwiz.py @@ -250,7 +250,7 @@ class SqliteIntrospector(Introspector): 'text': TextField, 'time': TimeField, } - re_foreign_key = '"?(.+?)"?\s+.+\s+references (.*) \(["|]?(.*)["|]?\)' + re_foreign_key = '["\[]?(.+?)["\]]?\s+.+\s+references ["\[]?(.+?)["\]]? \(["|\[]?(.+?)["|\]]?\)' re_varchar = r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$' def get_conn_class(self): diff --git a/runtests.py b/runtests.py index 228fbaacd..954fd601b 100755 --- a/runtests.py +++ b/runtests.py @@ -25,6 +25,8 @@ def get_option_parser(): cases = optparse.OptionGroup(parser, 'Individual test module options') cases.add_option('--apsw', dest='apsw', default=False, action='store_true', help='apsw tests (requires apsw)') + cases.add_option('--csv', dest='csv', default=False, action='store_true', help='csv tests') + cases.add_option('--djpeewee', dest='djpeewee', default=False, action='store_true', help='djpeewee tests') cases.add_option('--gfk', dest='gfk', default=False, action='store_true', help='gfk tests') cases.add_option('--kv', dest='kv', default=False, action='store_true', help='key/value store tests') cases.add_option('--migrations', dest='migrations', default=False, action='store_true', help='migration helper tests (requires psycopg2)') @@ -50,6 +52,12 @@ def collect_modules(options): modules.append(tests_apsw) except ImportError: print_('Unable to import apsw tests, skipping') + if xtra(options.csv): + from playhouse import tests_csv_loader + modules.append(tests_csv_loader) + if xtra(options.djpeewee): + from playhouse import tests_djpeewee + modules.append(tests_djpeewee) if xtra(options.gfk): from playhouse import tests_gfk 
modules.append(tests_gfk) diff --git a/setup.py b/setup.py index e97860b7c..c4a3c3223 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name='peewee', - version="2.1.5", + version='2.1.6', description='a little orm', long_description=readme, author='Charles Leifer', diff --git a/tests.py b/tests.py index 8cd7ffadd..dd019b0b6 100644 --- a/tests.py +++ b/tests.py @@ -1,6 +1,5 @@ # encoding=utf-8 -from __future__ import with_statement import datetime import decimal import logging @@ -1545,6 +1544,25 @@ def test_fk_exceptions(self): self.assertEqual(b.user, u) self.assertRaises(User.DoesNotExist, getattr, b2, 'user') + def test_fk_cache_invalidated(self): + u1 = self.create_user('u1') + u2 = self.create_user('u2') + b = Blog.create(user=u1, title='b') + + blog = Blog.get(Blog.pk == b) + qc = len(self.queries()) + self.assertEqual(blog.user.id, u1.id) + self.assertEqual(len(self.queries()), qc + 1) + + blog.user = u2.id + self.assertEqual(blog.user.id, u2.id) + self.assertEqual(len(self.queries()), qc + 2) + + # No additional query. + blog.user = u2.id + self.assertEqual(blog.user.id, u2.id) + self.assertEqual(len(self.queries()), qc + 2) + def test_fk_ints(self): c1 = Category.create(name='c1') c2 = Category.create(name='c2', parent=c1.id) @@ -1736,6 +1754,9 @@ def test_counting(self): uc = User.select().where(User.username == 'u1').join(Blog).distinct().count() self.assertEqual(uc, 1) + self.assertEqual( + User.select().limit(1).wrapped_count(clear_limit=False), 1) + def test_ordering(self): u1 = User.create(username='u1') u2 = User.create(username='u2') @@ -3149,3 +3170,38 @@ def test_sequence_shared(self): elif TEST_VERBOSITY > 0: print_('Skipping "sequence" tests') + +if database_class is PostgresqlDatabase: + class TestUnicodeConversion(ModelTestCase): + requires = [User] + + def setUp(self): + super(TestUnicodeConversion, self).setUp() + + # Create a user object with UTF-8 encoded username. 
+ ustr = ulit('Ísland') + self.user = User.create(username=ustr) + + def tearDown(self): + super(TestUnicodeConversion, self).tearDown() + test_db.register_unicode = True + test_db.close() + + def reset_encoding(self, encoding): + test_db.close() + conn = test_db.get_conn() + conn.set_client_encoding(encoding) + + def test_unicode_conversion(self): + # Turn off unicode conversion on a per-connection basis. + test_db.register_unicode = False + self.reset_encoding('LATIN1') + + u = User.get(User.id == self.user.id) + self.assertFalse(u.username == self.user.username) + + test_db.register_unicode = True + self.reset_encoding('LATIN1') + + u = User.get(User.id == self.user.id) + self.assertEqual(u.username, self.user.username)