From be281015f3051edc39baa8bd5247c6665d06643d Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Sat, 19 Oct 2013 09:42:43 -0500 Subject: [PATCH 01/39] Docs for sql_error_handler. --- docs/peewee/api.rst | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/peewee/api.rst b/docs/peewee/api.rst index 2339890ae..3ec787973 100644 --- a/docs/peewee/api.rst +++ b/docs/peewee/api.rst @@ -1496,6 +1496,23 @@ Database and its subclasses :param Field date_field: field instance storing a datetime, date or time. :rtype: an expression object. + .. py:method:: sql_error_handler(exception, sql, params, require_commit) + + This hook is called when an error is raised executing a query, allowing + your application to inject custom error handling behavior. The default + implementation simply reraises the exception. + + .. code-block:: python + + class SqliteDatabaseCustom(SqliteDatabase): + def sql_error_handler(self, exception, sql, params, require_commit): + # Perform some custom behavior, for example close the + # connection to the database. + self.close() + + # Re-raise the exception. + raise exception + .. py:class:: SqliteDatabase(Database) From fa487a97af4f5821b757499a1decec544a25a351 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Sat, 19 Oct 2013 09:50:30 -0500 Subject: [PATCH 02/39] Documenting composite key. --- docs/peewee/api.rst | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/peewee/api.rst b/docs/peewee/api.rst index 3ec787973..e13b206da 100644 --- a/docs/peewee/api.rst +++ b/docs/peewee/api.rst @@ -606,6 +606,23 @@ Fields related to. +.. py:class:: CompositeKey(*fields) + + Specify a composite primary key for a model. Unlike the other fields, a + composite key is defined in the model's ``Meta`` class after the fields + have been defined. It takes as parameters the string names of the fields + to use as the primary key: + + .. 
code-block:: python + + class BlogTagThrough(Model): + blog = ForeignKeyField(Blog, related_name='tags') + tag = ForeignKeyField(Tag, related_name='blogs') + + class Meta: + primary_key = CompositeKey('blog', 'tag') + + .. _query-types: Query Types From 262bdbb7b78153ec4c656553fc6479e35b9abb34 Mon Sep 17 00:00:00 2001 From: Alexey Shamrin Date: Wed, 23 Oct 2013 00:55:55 +0400 Subject: [PATCH 03/39] cookbook.rst: clarify threadlocals vs check_same_threads --- docs/peewee/cookbook.rst | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/peewee/cookbook.rst b/docs/peewee/cookbook.rst index d855363a0..10bc59ea1 100644 --- a/docs/peewee/cookbook.rst +++ b/docs/peewee/cookbook.rst @@ -178,9 +178,11 @@ instantiate your database with ``threadlocals=True`` (*recommended*): concurrent_db = SqliteDatabase('stats.db', threadlocals=True) -The above implementation stores connection state in a thread local and will only -use that connection for a given thread. Pysqlite can share a connection across -threads, so if you would prefer to reuse a connection in multiple threads: +With the above peewee stores connection state in a thread local; each thread gets its +own separate connection. + +Alternatively, Python sqlite3 module can share a connection across different threads, +but you have to disable runtime checks to reuse the single connection: .. code-block:: python From f1607416474ae99ccc16b058172cbce9d0ebd230 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 23 Oct 2013 11:03:00 -0500 Subject: [PATCH 04/39] Register unicode converters on a per-connection basis when using postgresql. 
--- peewee.py | 12 ++++++++---- tests.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/peewee.py b/peewee.py index 7f6859b5f..083b23820 100644 --- a/peewee.py +++ b/peewee.py @@ -123,9 +123,7 @@ def _sqlite_date_part(lookup_type, datetime_string): return getattr(dt, lookup_type) if psycopg2: - import psycopg2.extensions - psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) - psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY) + from psycopg2 import extensions as pg_extensions # Peewee logger = logging.getLogger('peewee') @@ -2005,10 +2003,16 @@ class PostgresqlDatabase(Database): reserved_tables = ['user'] sequences = True + register_unicode = True + def _connect(self, database, **kwargs): if not psycopg2: raise ImproperlyConfigured('psycopg2 must be installed.') - return psycopg2.connect(database=database, **kwargs) + conn = psycopg2.connect(database=database, **kwargs) + if self.register_unicode: + pg_extensions.register_type(pg_extensions.UNICODE, conn) + pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn) + return conn def last_insert_id(self, cursor, model): seq = model._meta.primary_key.sequence diff --git a/tests.py b/tests.py index 8cd7ffadd..8c7a6388d 100644 --- a/tests.py +++ b/tests.py @@ -3149,3 +3149,38 @@ def test_sequence_shared(self): elif TEST_VERBOSITY > 0: print_('Skipping "sequence" tests') + +if database_class is PostgresqlDatabase: + class TestUnicodeConversion(ModelTestCase): + requires = [User] + + def setUp(self): + super(TestUnicodeConversion, self).setUp() + + # Create a user object with UTF-8 encoded username. 
+ ustr = ulit('Ísland') + self.user = User.create(username=ustr) + + def tearDown(self): + super(TestUnicodeConversion, self).tearDown() + test_db.register_unicode = True + test_db.close() + + def reset_encoding(self, encoding): + test_db.close() + conn = test_db.get_conn() + conn.set_client_encoding(encoding) + + def test_unicode_conversion(self): + # Turn off unicode conversion on a per-connection basis. + test_db.register_unicode = False + self.reset_encoding('LATIN1') + + u = User.get(User.id == self.user.id) + self.assertFalse(u.username == self.user.username) + + test_db.register_unicode = True + self.reset_encoding('LATIN1') + + u = User.get(User.id == self.user.id) + self.assertEqual(u.username, self.user.username) From ecc87583d57b86c1dba6d1983899c64692043c2a Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 23 Oct 2013 12:25:34 -0500 Subject: [PATCH 05/39] Invalidate related obj cache, fixes #243 --- peewee.py | 3 +++ tests.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/peewee.py b/peewee.py index 083b23820..5a5c0fd4c 100644 --- a/peewee.py +++ b/peewee.py @@ -630,7 +630,10 @@ def __set__(self, instance, value): instance._data[self.att_name] = value.get_id() instance._obj_cache[self.att_name] = value else: + orig_value = instance._data.get(self.att_name) instance._data[self.att_name] = value + if orig_value != value and self.att_name in instance._obj_cache: + del instance._obj_cache[self.att_name] class ReverseRelationDescriptor(object): def __init__(self, field): diff --git a/tests.py b/tests.py index 8c7a6388d..2b3164fce 100644 --- a/tests.py +++ b/tests.py @@ -1545,6 +1545,25 @@ def test_fk_exceptions(self): self.assertEqual(b.user, u) self.assertRaises(User.DoesNotExist, getattr, b2, 'user') + def test_fk_cache_invalidated(self): + u1 = self.create_user('u1') + u2 = self.create_user('u2') + b = Blog.create(user=u1, title='b') + + blog = Blog.get(Blog.pk == b) + qc = len(self.queries()) + 
self.assertEqual(blog.user.id, u1.id) + self.assertEqual(len(self.queries()), qc + 1) + + blog.user = u2.id + self.assertEqual(blog.user.id, u2.id) + self.assertEqual(len(self.queries()), qc + 2) + + # No additional query. + blog.user = u2.id + self.assertEqual(blog.user.id, u2.id) + self.assertEqual(len(self.queries()), qc + 2) + def test_fk_ints(self): c1 = Category.create(name='c1') c2 = Category.create(name='c2', parent=c1.id) From 7ccdab51ee546877c85b9cd633dadef37a45e9e6 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Fri, 1 Nov 2013 09:04:36 -0500 Subject: [PATCH 06/39] Adding note --- docs/index.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/index.rst b/docs/index.rst index ebababf78..d3a4acacf 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -45,6 +45,14 @@ Contents: peewee/playhouse peewee/upgrading +Note +---- + +Hi, I'm Charlie the author of peewee. I hope that you enjoy using this library +as much as I've enjoyed writing it. Peewee wouldn't be anywhere near as useful +without people like you, so thank you. If you find any bugs, odd behavior, or +have an idea for a new feature please don't hesitate to `open an issue `_ on GitHub. + Indices and tables ================== From 0ec5e726ed028271c201fbfd2c4ae9a0bc353e9a Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Mon, 4 Nov 2013 13:45:30 -0600 Subject: [PATCH 07/39] Use new theme. --- docs/conf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/conf.py b/docs/conf.py index 4bf7ca7e0..a1280467c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,6 +11,8 @@ # All configuration values have a default; values that are commented out # serve to show the default. +RTD_NEW_THEME = True + import sys, os # If extensions (or modules to document with autodoc) are in another directory, From 12016ad3889645ef0810a4fe426379cf8493e052 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 6 Nov 2013 21:15:34 -0600 Subject: [PATCH 08/39] Adding a "bare" field for use with sqlite. 
--- peewee.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/peewee.py b/peewee.py index 5a5c0fd4c..a7ed2f281 100644 --- a/peewee.py +++ b/peewee.py @@ -21,6 +21,7 @@ from inspect import isclass __all__ = [ + 'BareField', 'BigIntegerField', 'BlobField', 'BooleanField', @@ -431,6 +432,10 @@ def python_value(self, value): def __hash__(self): return hash(self.name + '.' + self.model_class.__name__) +class BareField(Field): + db_field = 'bare' + template = '' + class IntegerField(Field): db_field = 'int' coerce = int @@ -737,6 +742,7 @@ def __set__(self, instance, value): class QueryCompiler(object): field_map = { + 'bare': '', 'bigint': 'BIGINT', 'blob': 'BLOB', 'bool': 'SMALLINT', From 66e632e8f72e38f2368254e43f3f88b75b2c75b0 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 6 Nov 2013 21:15:46 -0600 Subject: [PATCH 09/39] Adding a CSV loader to playhouse. --- playhouse/csv_loader.py | 209 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 playhouse/csv_loader.py diff --git a/playhouse/csv_loader.py b/playhouse/csv_loader.py new file mode 100644 index 000000000..1dc1c3f6f --- /dev/null +++ b/playhouse/csv_loader.py @@ -0,0 +1,209 @@ +""" +Peewee helper for loading CSV data into a database. + +db = SqliteDatabase(':memory:') + +# Load the users CSV file into the database and return a Model for accessing +# the data. +User = load_csv.load(db, 'users.csv') + +# Provide column types. +Payments = load_csv.load(db, 'payments.csv', (CharField, DecimalField)) +""" +import csv +import datetime +import os +import re +from collections import OrderedDict + +from peewee import * + + +class RowConverter(object): + """ + Simple introspection utility to convert a CSV file into a list of headers + and column types. 
+ """ + date_formats = [ + '%Y-%m-%d', + '%m/%d/%Y'] + + datetime_formats = [ + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d %H:%M:%S.%f'] + + def __init__(self, database, has_header=True, sample_size=10): + self.database = database + self.has_header = has_header + self.sample_size = sample_size + + def matches_date(self, value, formats): + for fmt in formats: + try: + datetime.datetime.strptime(value, fmt) + except ValueError: + pass + else: + return True + + def field(field_class, **field_kwargs): + def decorator(fn): + fn.field = lambda: field_class(**field_kwargs) + return fn + return decorator + + @field(IntegerField, default=0) + def is_integer(self, value): + return value.isdigit() + + @field(FloatField, default=0) + def is_float(self, value): + try: + float(value) + except (ValueError, TypeError): + pass + else: + return True + + @field(DateTimeField, null=True) + def is_datetime(self, value): + return self.matches_date(value, self.datetime_formats) + + @field(DateField, null=True) + def is_date(self, value): + return self.matches_date(value, self.date_formats) + + @field(BooleanField, default=False) + def is_boolean(self, value): + return value.lower() in ('t', 'f', 'true', 'false') + + @field(BareField, null=True) + def default(self, value): + return True + + def extract_rows(self, filename, **reader_kwargs): + """ + Extract `self.sample_size` rows from the CSV file and analyze their + data-types. + + :param str filename: A string filename. + :param reader_kwargs: Arbitrary parameters to pass to the CSV reader. + :returns: A 2-tuple containing a list of headers and list of rows + read from the CSV file. 
+ """ + rows = [] + rows_to_read = self.sample_size + with open(filename) as fh: + reader = csv.reader(fh, **reader_kwargs) + if self.has_header: + rows_to_read += 1 + for i, row in enumerate(reader): + rows.append(row) + if i == self.sample_size: + break + if self.has_header: + header, rows = rows[0], rows[1:] + else: + header = ['field_%d' for i in range(len(rows[0]))] + return header, rows + + def get_checks(self): + """Return a list of functions to use when testing values.""" + return [ + self.is_date, + self.is_datetime, + self.is_integer, + self.is_float, + self.is_boolean, + self.default] + + def analyze(self, rows): + """ + Analyze the given rows and try to determine the type of value stored. + + :param list rows: A list-of-lists containing one or more rows from a + csv file. + :returns: A list of peewee Field objects for each column in the CSV. + """ + transposed = zip(*rows) + checks = self.get_checks() + column_types = [] + for i, column in enumerate(transposed): + # Remove any empty values. 
+ col_vals = [val for val in column if val != ''] + for check in checks: + results = set(check(val) for val in col_vals) + if all(results): + column_types.append(check.field()) + break + + return column_types + + +class Loader(object): + def __init__(self, database, filename, fields=None, field_names=None, + has_header=True, converter=None, **reader_kwargs): + self.database = database + self.filename = filename + self.fields = fields + self.field_names = field_names + self.has_header = has_header + self.converter = converter + self.reader_kwargs = reader_kwargs + + def clean_field_name(self, s): + return re.sub('[^a-z0-9]+', '_', s.lower()) + + def get_converter(self): + return self.converter or RowConverter( + self.database, + has_header=self.has_header) + + def analyze_csv(self): + converter = self.get_converter() + header, rows = converter.extract_rows(self.filename) + self.fields = converter.analyze(rows) + if not self.field_names: + self.field_names = map(self.clean_field_name, header) + + def get_reader(self): + fh = open(self.filename, 'r') + return csv.reader(fh, **self.reader_kwargs) + + def model_class(self, field_names, fields): + model_name = os.path.splitext(os.path.basename(self.filename))[0] + attrs = dict(zip(field_names, fields)) + klass = type(model_name.title(), (Model,), attrs) + klass._meta.database = self.database + return klass + + def load(self): + if not self.fields: + self.analyze_csv() + if not self.field_names and not self.has_header: + self.field_names = ['field_%s' for i in range(len(self.fields))] + + reader = self.get_reader() + if not self.field_names: + self.field_names = map(self.clean_field_name, reader.next()) + elif self.has_header: + reader.next() + + ModelClass = self.model_class(self.field_names, self.fields) + + with self.database.transaction(): + ModelClass.create_table(True) + for row in reader: + inst = ModelClass() + for i, field_name in enumerate(self.field_names): + if row[i]: + setattr(inst, field_name, row[i]) + 
inst.save() + + return ModelClass + +def load(database, filename, fields=None, field_names=None, has_header=True, + converter=None, **reader_kwargs): + loader = Loader(database, filename, fields, field_names, has_header, + converter, **reader_kwargs) + return loader.load() From c83a8318d2fc1ffafd9966d660133a65b1511bb6 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 6 Nov 2013 21:22:08 -0600 Subject: [PATCH 10/39] Adding docstrings. --- playhouse/csv_loader.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/playhouse/csv_loader.py b/playhouse/csv_loader.py index 1dc1c3f6f..9aabfa4a1 100644 --- a/playhouse/csv_loader.py +++ b/playhouse/csv_loader.py @@ -1,14 +1,18 @@ """ Peewee helper for loading CSV data into a database. -db = SqliteDatabase(':memory:') +Load the users CSV file into the database and return a Model for accessing +the data: -# Load the users CSV file into the database and return a Model for accessing -# the data. -User = load_csv.load(db, 'users.csv') + from playhouse.csv_loader import load_csv + db = SqliteDatabase(':memory:') + User = load_csv(db, 'users.csv') -# Provide column types. -Payments = load_csv.load(db, 'payments.csv', (CharField, DecimalField)) +Provide explicit field types and/or field names: + + fields = [IntegerField(), IntegerField(), DateTimeField(), DecimalField()] + field_names = ['from_acct', 'to_acct', 'timestamp', 'amount'] + Payments = load_csv(db, 'payments.csv', fields, field_names) """ import csv import datetime @@ -141,6 +145,19 @@ def analyze(self, rows): class Loader(object): + """ + Load the contents of a CSV file into a database and return a model class + suitable for working with the CSV data. + + :param database: a peewee Database instance. + :param str filename: the filename of the CSV file. + :param list fields: A list of peewee Field() instances appropriate to + the values in the CSV file. 
+ :param list field_names: A list of names to use for the fields. + :param bool has_header: Whether the first row of the CSV file is a header. + :param converter: A RowConverter instance to use. + :param reader_kwargs: Arbitrary arguments to pass to the CSV reader. + """ def __init__(self, database, filename, fields=None, field_names=None, has_header=True, converter=None, **reader_kwargs): self.database = database @@ -202,8 +219,8 @@ def load(self): return ModelClass -def load(database, filename, fields=None, field_names=None, has_header=True, - converter=None, **reader_kwargs): +def load_csv(database, filename, fields=None, field_names=None, has_header=True, + converter=None, **reader_kwargs): loader = Loader(database, filename, fields, field_names, has_header, converter, **reader_kwargs) return loader.load() From 6eecf202451d5171c9459cc1a347910aab45ca6c Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 6 Nov 2013 21:31:53 -0600 Subject: [PATCH 11/39] Missing reader kwargs. --- playhouse/csv_loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/playhouse/csv_loader.py b/playhouse/csv_loader.py index 9aabfa4a1..c7c9e1078 100644 --- a/playhouse/csv_loader.py +++ b/playhouse/csv_loader.py @@ -178,7 +178,9 @@ def get_converter(self): def analyze_csv(self): converter = self.get_converter() - header, rows = converter.extract_rows(self.filename) + header, rows = converter.extract_rows( + self.filename, + **self.reader_kwargs) self.fields = converter.analyze(rows) if not self.field_names: self.field_names = map(self.clean_field_name, header) From 9c7c80b52a62b8ac2f1927938242ebd29e668120 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Thu, 7 Nov 2013 04:36:05 -0600 Subject: [PATCH 12/39] Cleanups. 
--- playhouse/csv_loader.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/playhouse/csv_loader.py b/playhouse/csv_loader.py index c7c9e1078..63031fb8e 100644 --- a/playhouse/csv_loader.py +++ b/playhouse/csv_loader.py @@ -81,7 +81,7 @@ def is_date(self, value): def is_boolean(self, value): return value.lower() in ('t', 'f', 'true', 'false') - @field(BareField, null=True) + @field(BareField, default='') def default(self, value): return True @@ -192,6 +192,7 @@ def get_reader(self): def model_class(self, field_names, fields): model_name = os.path.splitext(os.path.basename(self.filename))[0] attrs = dict(zip(field_names, fields)) + attrs['_auto_pk'] = PrimaryKeyField() klass = type(model_name.title(), (Model,), attrs) klass._meta.database = self.database return klass @@ -213,11 +214,13 @@ def load(self): with self.database.transaction(): ModelClass.create_table(True) for row in reader: - inst = ModelClass() - for i, field_name in enumerate(self.field_names): - if row[i]: - setattr(inst, field_name, row[i]) - inst.save() + insert = {} + for field_name, value in zip(self.field_names, row): + if value: + value = value.decode('utf-8') + insert[field_name] = value + if insert: + ModelClass.insert(**insert).execute() return ModelClass From e1af1957ecb093f779d1cc48bf2b83220f6b38ca Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Thu, 7 Nov 2013 05:16:52 -0600 Subject: [PATCH 13/39] Allow empty sample size. --- playhouse/csv_loader.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/playhouse/csv_loader.py b/playhouse/csv_loader.py index 63031fb8e..dfab4fcb7 100644 --- a/playhouse/csv_loader.py +++ b/playhouse/csv_loader.py @@ -27,6 +27,10 @@ class RowConverter(object): """ Simple introspection utility to convert a CSV file into a list of headers and column types. + + :param database: a peewee Database object. + :param bool has_header: whether the first row of CSV is a header row. 
+ :param int sample_size: number of rows to introspect """ date_formats = [ '%Y-%m-%d', @@ -159,12 +163,14 @@ class Loader(object): :param reader_kwargs: Arbitrary arguments to pass to the CSV reader. """ def __init__(self, database, filename, fields=None, field_names=None, - has_header=True, converter=None, **reader_kwargs): + has_header=True, sample_size=10, converter=None, + **reader_kwargs): self.database = database self.filename = filename self.fields = fields self.field_names = field_names self.has_header = has_header + self.sample_size = sample_size self.converter = converter self.reader_kwargs = reader_kwargs @@ -174,14 +180,18 @@ def clean_field_name(self, s): def get_converter(self): return self.converter or RowConverter( self.database, - has_header=self.has_header) + has_header=self.has_header, + sample_size=self.sample_size) def analyze_csv(self): converter = self.get_converter() header, rows = converter.extract_rows( self.filename, **self.reader_kwargs) - self.fields = converter.analyze(rows) + if rows: + self.fields = converter.analyze(rows) + else: + self.fields = [converter.default.field() for _ in header] if not self.field_names: self.field_names = map(self.clean_field_name, header) @@ -225,7 +235,7 @@ def load(self): return ModelClass def load_csv(database, filename, fields=None, field_names=None, has_header=True, - converter=None, **reader_kwargs): + sample_size=10, converter=None, **reader_kwargs): loader = Loader(database, filename, fields, field_names, has_header, - converter, **reader_kwargs) + sample_size, converter, **reader_kwargs) return loader.load() From 9b664a39214dd38df911f6dff2e39da6ea539401 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Thu, 7 Nov 2013 05:27:42 -0600 Subject: [PATCH 14/39] Allow a database *or* model to be specified. 
--- playhouse/csv_loader.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/playhouse/csv_loader.py b/playhouse/csv_loader.py index dfab4fcb7..45f5c512c 100644 --- a/playhouse/csv_loader.py +++ b/playhouse/csv_loader.py @@ -21,6 +21,7 @@ from collections import OrderedDict from peewee import * +from peewee import Database class RowConverter(object): @@ -153,7 +154,7 @@ class Loader(object): Load the contents of a CSV file into a database and return a model class suitable for working with the CSV data. - :param database: a peewee Database instance. + :param db_or_model: a peewee Database instance or a Model class. :param str filename: the filename of the CSV file. :param list fields: A list of peewee Field() instances appropriate to the values in the CSV file. @@ -162,10 +163,9 @@ class Loader(object): :param converter: A RowConverter instance to use. :param reader_kwargs: Arbitrary arguments to pass to the CSV reader. """ - def __init__(self, database, filename, fields=None, field_names=None, + def __init__(self, db_or_model, filename, fields=None, field_names=None, has_header=True, sample_size=10, converter=None, **reader_kwargs): - self.database = database self.filename = filename self.fields = fields self.field_names = field_names @@ -174,6 +174,19 @@ def __init__(self, database, filename, fields=None, field_names=None, self.converter = converter self.reader_kwargs = reader_kwargs + if isinstance(db_or_model, Database): + self.database = db_or_model + self.model = None + else: + self.model = db_or_model + self.database = self.model._meta.database + self.fields = self.model._meta.get_fields() + self.field_names = self.model._meta.get_field_names() + # If using an auto-incrementing primary key, ignore it. 
+ if self.model._meta.auto_increment: + self.fields = self.fields[1:] + self.field_names = self.field_names[1:] + def clean_field_name(self, s): return re.sub('[^a-z0-9]+', '_', s.lower()) @@ -199,7 +212,9 @@ def get_reader(self): fh = open(self.filename, 'r') return csv.reader(fh, **self.reader_kwargs) - def model_class(self, field_names, fields): + def get_model_class(self, field_names, fields): + if self.model: + return self.model model_name = os.path.splitext(os.path.basename(self.filename))[0] attrs = dict(zip(field_names, fields)) attrs['_auto_pk'] = PrimaryKeyField() @@ -219,7 +234,7 @@ def load(self): elif self.has_header: reader.next() - ModelClass = self.model_class(self.field_names, self.fields) + ModelClass = self.get_model_class(self.field_names, self.fields) with self.database.transaction(): ModelClass.create_table(True) @@ -227,15 +242,14 @@ def load(self): insert = {} for field_name, value in zip(self.field_names, row): if value: - value = value.decode('utf-8') - insert[field_name] = value + insert[field_name] = value.decode('utf-8') if insert: ModelClass.insert(**insert).execute() return ModelClass -def load_csv(database, filename, fields=None, field_names=None, has_header=True, - sample_size=10, converter=None, **reader_kwargs): - loader = Loader(database, filename, fields, field_names, has_header, +def load_csv(db_or_model, filename, fields=None, field_names=None, + has_header=True, sample_size=10, converter=None, **reader_kwargs): + loader = Loader(db_or_model, filename, fields, field_names, has_header, sample_size, converter, **reader_kwargs) return loader.load() From b0d422d32051bb3e497c1474da630e7be910596e Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Thu, 7 Nov 2013 05:38:45 -0600 Subject: [PATCH 15/39] Allow specifying database table. 
--- playhouse/csv_loader.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/playhouse/csv_loader.py b/playhouse/csv_loader.py index 45f5c512c..79305378a 100644 --- a/playhouse/csv_loader.py +++ b/playhouse/csv_loader.py @@ -165,7 +165,7 @@ class Loader(object): """ def __init__(self, db_or_model, filename, fields=None, field_names=None, has_header=True, sample_size=10, converter=None, - **reader_kwargs): + db_table=None, **reader_kwargs): self.filename = filename self.fields = fields self.field_names = field_names @@ -177,9 +177,11 @@ def __init__(self, db_or_model, filename, fields=None, field_names=None, if isinstance(db_or_model, Database): self.database = db_or_model self.model = None + self.db_table = os.path.splitext(os.path.basename(filename))[0] else: self.model = db_or_model self.database = self.model._meta.database + self.db_table = self.model._meta.db_table self.fields = self.model._meta.get_fields() self.field_names = self.model._meta.get_field_names() # If using an auto-incrementing primary key, ignore it. 
@@ -215,11 +217,11 @@ def get_reader(self): def get_model_class(self, field_names, fields): if self.model: return self.model - model_name = os.path.splitext(os.path.basename(self.filename))[0] attrs = dict(zip(field_names, fields)) attrs['_auto_pk'] = PrimaryKeyField() - klass = type(model_name.title(), (Model,), attrs) + klass = type(self.db_table.title(), (Model,), attrs) klass._meta.database = self.database + klass._meta.db_table = self.db_table return klass def load(self): @@ -249,7 +251,8 @@ def load(self): return ModelClass def load_csv(db_or_model, filename, fields=None, field_names=None, - has_header=True, sample_size=10, converter=None, **reader_kwargs): + has_header=True, sample_size=10, converter=None, + db_table=None, **reader_kwargs): loader = Loader(db_or_model, filename, fields, field_names, has_header, - sample_size, converter, **reader_kwargs) + sample_size, converter, db_table, **reader_kwargs) return loader.load() From 89a66d3171224bbda67ef3b66f7f5ea8a0de8b42 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Thu, 7 Nov 2013 05:41:01 -0600 Subject: [PATCH 16/39] Doc updates. --- playhouse/csv_loader.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/playhouse/csv_loader.py b/playhouse/csv_loader.py index 79305378a..b1b7530e5 100644 --- a/playhouse/csv_loader.py +++ b/playhouse/csv_loader.py @@ -160,7 +160,11 @@ class Loader(object): the values in the CSV file. :param list field_names: A list of names to use for the fields. :param bool has_header: Whether the first row of the CSV file is a header. + :param int sample_size: Number of rows to introspect if fields are not + defined. :param converter: A RowConverter instance to use. + :param str db_table: Name of table to store data in (if not specified, the + table name will be derived from the CSV filename). :param reader_kwargs: Arbitrary arguments to pass to the CSV reader. 
""" def __init__(self, db_or_model, filename, fields=None, field_names=None, @@ -256,3 +260,4 @@ def load_csv(db_or_model, filename, fields=None, field_names=None, loader = Loader(db_or_model, filename, fields, field_names, has_header, sample_size, converter, db_table, **reader_kwargs) return loader.load() +load_csv.__doc__ = Loader.__doc__ From 845fb38c5ec160808fb09298096bd278650c523e Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Thu, 7 Nov 2013 05:53:49 -0600 Subject: [PATCH 17/39] Make sure to *use* the db table. --- playhouse/csv_loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/playhouse/csv_loader.py b/playhouse/csv_loader.py index b1b7530e5..66d5fa441 100644 --- a/playhouse/csv_loader.py +++ b/playhouse/csv_loader.py @@ -181,7 +181,8 @@ def __init__(self, db_or_model, filename, fields=None, field_names=None, if isinstance(db_or_model, Database): self.database = db_or_model self.model = None - self.db_table = os.path.splitext(os.path.basename(filename))[0] + self.db_table = (db_table or + os.path.splitext(os.path.basename(filename))[0]) else: self.model = db_or_model self.database = self.model._meta.database From f62938eb2b7fd547c6fd58f6168c43c08a1db938 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Thu, 7 Nov 2013 08:30:28 -0600 Subject: [PATCH 18/39] Starting tests for csv loader. 
--- playhouse/csv_loader.py | 59 ++++++++++++++++----------------- playhouse/tests_csv_loader.py | 62 +++++++++++++++++++++++++++++++++++ runtests.py | 4 +++ 3 files changed, 95 insertions(+), 30 deletions(-) create mode 100644 playhouse/tests_csv_loader.py diff --git a/playhouse/csv_loader.py b/playhouse/csv_loader.py index 66d5fa441..32062960f 100644 --- a/playhouse/csv_loader.py +++ b/playhouse/csv_loader.py @@ -19,12 +19,21 @@ import os import re from collections import OrderedDict +from contextlib import contextmanager from peewee import * from peewee import Database -class RowConverter(object): +class _CSVReader(object): + @contextmanager + def get_reader(self, filename, **reader_kwargs): + fh = open(filename, 'r') + reader = csv.reader(fh, **reader_kwargs) + yield reader + fh.close() + +class RowConverter(_CSVReader): """ Simple introspection utility to convert a CSV file into a list of headers and column types. @@ -82,10 +91,6 @@ def is_datetime(self, value): def is_date(self, value): return self.matches_date(value, self.date_formats) - @field(BooleanField, default=False) - def is_boolean(self, value): - return value.lower() in ('t', 'f', 'true', 'false') - @field(BareField, default='') def default(self, value): return True @@ -102,8 +107,7 @@ def extract_rows(self, filename, **reader_kwargs): """ rows = [] rows_to_read = self.sample_size - with open(filename) as fh: - reader = csv.reader(fh, **reader_kwargs) + with self.get_reader(filename, **reader_kwargs) as reader: if self.has_header: rows_to_read += 1 for i, row in enumerate(reader): @@ -123,7 +127,6 @@ def get_checks(self): self.is_datetime, self.is_integer, self.is_float, - self.is_boolean, self.default] def analyze(self, rows): @@ -149,7 +152,7 @@ def analyze(self, rows): return column_types -class Loader(object): +class Loader(_CSVReader): """ Load the contents of a CSV file into a database and return a model class suitable for working with the CSV data. 
@@ -215,10 +218,6 @@ def analyze_csv(self): if not self.field_names: self.field_names = map(self.clean_field_name, header) - def get_reader(self): - fh = open(self.filename, 'r') - return csv.reader(fh, **self.reader_kwargs) - def get_model_class(self, field_names, fields): if self.model: return self.model @@ -235,23 +234,23 @@ def load(self): if not self.field_names and not self.has_header: self.field_names = ['field_%s' for i in range(len(self.fields))] - reader = self.get_reader() - if not self.field_names: - self.field_names = map(self.clean_field_name, reader.next()) - elif self.has_header: - reader.next() - - ModelClass = self.get_model_class(self.field_names, self.fields) - - with self.database.transaction(): - ModelClass.create_table(True) - for row in reader: - insert = {} - for field_name, value in zip(self.field_names, row): - if value: - insert[field_name] = value.decode('utf-8') - if insert: - ModelClass.insert(**insert).execute() + with self.get_reader(self.filename, **self.reader_kwargs) as reader: + if not self.field_names: + self.field_names = map(self.clean_field_name, reader.next()) + elif self.has_header: + reader.next() + + ModelClass = self.get_model_class(self.field_names, self.fields) + + with self.database.transaction(): + ModelClass.create_table(True) + for row in reader: + insert = {} + for field_name, value in zip(self.field_names, row): + if value: + insert[field_name] = value.decode('utf-8') + if insert: + ModelClass.insert(**insert).execute() return ModelClass diff --git a/playhouse/tests_csv_loader.py b/playhouse/tests_csv_loader.py new file mode 100644 index 000000000..9c77659e1 --- /dev/null +++ b/playhouse/tests_csv_loader.py @@ -0,0 +1,62 @@ +import csv +import unittest +from contextlib import contextmanager +from datetime import date +from StringIO import StringIO +from textwrap import dedent + +from playhouse.csv_loader import * + + +class TestRowConverter(RowConverter): + @contextmanager + def get_reader(self, csv_data, 
**reader_kwargs): + reader = csv.reader(StringIO(csv_data), **reader_kwargs) + yield reader + +class TestLoader(Loader): + @contextmanager + def get_reader(self, csv_data, **reader_kwargs): + reader = csv.reader(StringIO(csv_data), **reader_kwargs) + yield reader + + def get_converter(self): + return self.converter or TestRowConverter( + self.database, + has_header=self.has_header, + sample_size=self.sample_size) + +db = SqliteDatabase(':memory:') + +class TestCSVConversion(unittest.TestCase): + header = 'id,name,dob,salary,is_admin' + simple = '10,"F1 L1",1983-01-01,10000,t' + float_sal = '20,"F2 L2",1983-01-02,20000.5,f' + only_name = ',"F3 L3",,,' + mismatch = 'foo,F4 L4,dob,sal,x' + + def build_csv(self, *lines): + return '\r\n'.join(lines) + + def load(self, *lines, **loader_kwargs): + csv = self.build_csv(*lines) + loader_kwargs['filename'] = csv + loader_kwargs.setdefault('db_table', 'csv_test') + loader_kwargs.setdefault('db_or_model', db) + return TestLoader(**loader_kwargs).load() + + def assertData(self, ModelClass, expected): + query = ModelClass.select().order_by(ModelClass.name).tuples() + self.assertEqual([row[1:] for row in query], expected) + + def test_defaults(self): + ModelClass = self.load( + self.header, + self.simple, + self.float_sal, + self.only_name) + self.assertData(ModelClass, [ + (10, 'F1 L1', date(1983, 1, 1), 10000., 't'), + (20, 'F2 L2', date(1983, 1, 2), 20000.5, 'f'), + (0, 'F3 L3', None, 0., ''), + ]) diff --git a/runtests.py b/runtests.py index 228fbaacd..9a1b22e10 100755 --- a/runtests.py +++ b/runtests.py @@ -25,6 +25,7 @@ def get_option_parser(): cases = optparse.OptionGroup(parser, 'Individual test module options') cases.add_option('--apsw', dest='apsw', default=False, action='store_true', help='apsw tests (requires apsw)') + cases.add_option('--csv', dest='csv', default=False, action='store_true', help='csv tests') cases.add_option('--gfk', dest='gfk', default=False, action='store_true', help='gfk tests') 
cases.add_option('--kv', dest='kv', default=False, action='store_true', help='key/value store tests') cases.add_option('--migrations', dest='migrations', default=False, action='store_true', help='migration helper tests (requires psycopg2)') @@ -50,6 +51,9 @@ def collect_modules(options): modules.append(tests_apsw) except ImportError: print_('Unable to import apsw tests, skipping') + if xtra(options.csv): + from playhouse import tests_csv_loader + modules.append(tests_csv_loader) if xtra(options.gfk): from playhouse import tests_gfk modules.append(tests_gfk) From d25c9672271dea9c3fa4c1605923b55aa4b53e9f Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Thu, 7 Nov 2013 09:37:01 -0600 Subject: [PATCH 19/39] Continuing tests --- playhouse/csv_loader.py | 5 +++-- playhouse/tests_csv_loader.py | 18 +++++++++++++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/playhouse/csv_loader.py b/playhouse/csv_loader.py index 32062960f..ab0ce4f09 100644 --- a/playhouse/csv_loader.py +++ b/playhouse/csv_loader.py @@ -117,7 +117,7 @@ def extract_rows(self, filename, **reader_kwargs): if self.has_header: header, rows = rows[0], rows[1:] else: - header = ['field_%d' for i in range(len(rows[0]))] + header = ['field_%d' % i for i in range(len(rows[0]))] return header, rows def get_checks(self): @@ -232,7 +232,8 @@ def load(self): if not self.fields: self.analyze_csv() if not self.field_names and not self.has_header: - self.field_names = ['field_%s' for i in range(len(self.fields))] + self.field_names = [ + 'field_%d' % i for i in range(len(self.fields))] with self.get_reader(self.filename, **self.reader_kwargs) as reader: if not self.field_names: diff --git a/playhouse/tests_csv_loader.py b/playhouse/tests_csv_loader.py index 9c77659e1..b5ec5bc43 100644 --- a/playhouse/tests_csv_loader.py +++ b/playhouse/tests_csv_loader.py @@ -35,6 +35,9 @@ class TestCSVConversion(unittest.TestCase): only_name = ',"F3 L3",,,' mismatch = 'foo,F4 L4,dob,sal,x' + def setUp(self): + 
db.execute_sql('drop table if exists csv_test;') + def build_csv(self, *lines): return '\r\n'.join(lines) @@ -46,7 +49,8 @@ def load(self, *lines, **loader_kwargs): return TestLoader(**loader_kwargs).load() def assertData(self, ModelClass, expected): - query = ModelClass.select().order_by(ModelClass.name).tuples() + name_field = ModelClass._meta.get_fields()[2] + query = ModelClass.select().order_by(name_field).tuples() self.assertEqual([row[1:] for row in query], expected) def test_defaults(self): @@ -60,3 +64,15 @@ def test_defaults(self): (20, 'F2 L2', date(1983, 1, 2), 20000.5, 'f'), (0, 'F3 L3', None, 0., ''), ]) + + def test_no_header(self): + ModelClass = self.load( + self.simple, + self.float_sal, + field_names=['f1', 'f2', 'f3', 'f4', 'f5'], + has_header=False) + self.assertData(ModelClass, [ + (10, 'F1 L1', date(1983, 1, 1), 10000., 't'), + (20, 'F2 L2', date(1983, 1, 2), 20000.5, 'f')]) + self.assertEqual(ModelClass._meta.get_field_names(), [ + '_auto_pk', 'f1', 'f2', 'f3', 'f4', 'f5']) From 933b166b7d50dcf26f98d6e893f46e9ccb58722d Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Thu, 7 Nov 2013 11:07:18 -0600 Subject: [PATCH 20/39] Adding more tests. 
--- playhouse/tests_csv_loader.py | 38 ++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/playhouse/tests_csv_loader.py b/playhouse/tests_csv_loader.py index b5ec5bc43..75aa63973 100644 --- a/playhouse/tests_csv_loader.py +++ b/playhouse/tests_csv_loader.py @@ -71,8 +71,44 @@ def test_no_header(self): self.float_sal, field_names=['f1', 'f2', 'f3', 'f4', 'f5'], has_header=False) + self.assertEqual(ModelClass._meta.get_field_names(), [ + '_auto_pk', 'f1', 'f2', 'f3', 'f4', 'f5']) self.assertData(ModelClass, [ (10, 'F1 L1', date(1983, 1, 1), 10000., 't'), (20, 'F2 L2', date(1983, 1, 2), 20000.5, 'f')]) + + def test_no_header_no_fieldnames(self): + ModelClass = self.load( + self.simple, + self.float_sal, + has_header=False) self.assertEqual(ModelClass._meta.get_field_names(), [ - '_auto_pk', 'f1', 'f2', 'f3', 'f4', 'f5']) + '_auto_pk', 'field_0', 'field_1', 'field_2', 'field_3', 'field_4']) + + def test_mismatch_types(self): + ModelClass = self.load( + self.header, + self.simple, + self.mismatch) + self.assertData(ModelClass, [ + ('10', 'F1 L1', '1983-01-01', '10000', 't'), + ('foo', 'F4 L4', 'dob', 'sal', 'x')]) + + def test_fields(self): + fields = [ + IntegerField(), + CharField(), + DateField(), + FloatField(), + CharField()] + ModelClass = self.load( + self.header, + self.simple, + self.float_sal, + fields=fields) + self.assertEqual( + map(type, fields), + map(type, ModelClass._meta.get_fields()[1:])) + self.assertData(ModelClass, [ + (10, 'F1 L1', date(1983, 1, 1), 10000., 't'), + (20, 'F2 L2', date(1983, 1, 2), 20000.5, 'f')]) From a67f99f84e32b543bc7c08db4a3ac7ff3488a337 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Fri, 8 Nov 2013 08:47:16 -0600 Subject: [PATCH 21/39] Docs for csv loader. 
--- docs/peewee/playhouse.rst | 82 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/docs/peewee/playhouse.rst b/docs/peewee/playhouse.rst index a65446315..0242b622b 100644 --- a/docs/peewee/playhouse.rst +++ b/docs/peewee/playhouse.rst @@ -26,6 +26,7 @@ As well as tools for working with databases: * :ref:`pwiz` * :ref:`migrate` +* :ref:`csv_loader` * :ref:`test_utils` @@ -1310,6 +1311,87 @@ Renaming a table migrator.rename_table(Story, 'stories') +.. _csv_loader: + +CSV Loader +---------- + +This module contains helpers for loading CSV data into a database. CSV files can +be introspected to generate an appropriate model class for working with the data. +This makes it really easy to explore the data in a CSV file using Peewee and SQL. + +Here is how you would load a CSV file into an in-memory SQLite database. The +call to :py:func:`load_csv` returns a :py:class:`Model` instance suitable for +working with the CSV data: + +.. code-block:: python + + from playhouse.csv_loader import load_csv + db = SqliteDatabase(':memory:') + ZipToTZ = load_csv(db, 'zip_to_tz.csv') + +Now we can run queries using the new model. + +.. code-block:: pycon + + # Get the timezone for a zipcode. + >>> ZipToTZ.get(ZipToTZ.zip == 66047).timezone + 'US/Central' + + # Get all the zipcodes for my town. + >>> [row.zip for row in ZipToTZ.select().where( + ... (ZipToTZ.city == 'Lawrence') & (ZipToTZ.state == 'KS'))] + [66044, 66045, 66046, 66047, 66049] + +For more information and examples check out this `blog post `_. + + +CSV Loader API +^^^^^^^^^^^^^^ + +.. py:function:: load_csv(db_or_model, filename[, fields=None[, field_names=None[, has_header=True[, sample_size=10[, converter=None[, db_table=None[, **reader_kwargs]]]]]]]) + + Load a CSV file into the provided database or model class, returning a + :py:class:`Model` suitable for working with the CSV data. + + :param db_or_model: Either a :py:class:`Database` instance or a :py:class:`Model` class.
If a model is not provided, one will be automatically generated for you. + :param str filename: Path of CSV file to load. + :param list fields: A list of :py:class:`Field` instances mapping to each column in the CSV. This allows you to manually specify the column types. If not provided, and a model is not provided, the field types will be determined automatically. + :param list field_names: A list of strings to use as field names for each column in the CSV. If not provided, and a model is not provided, the field names will be determined by looking at the header row of the file. If no header exists, then the fields will be given generic names. + :param bool has_header: Whether the first row is a header. + :param int sample_size: Number of rows to look at when introspecting data types. If set to ``0``, then a generic field type will be used for all fields. + :param RowConverter converter: a :py:class:`RowConverter` instance to use for introspecting the CSV. If not provided, one will be created. + :param str db_table: The name of the database table to load data into. If this value is not provided, it will be determined using the filename of the CSV file. If a model is provided, this value is ignored. + :param reader_kwargs: Arbitrary keyword arguments to pass to the ``csv.reader`` object, such as the dialect, separator, etc. + :rtype: A :py:class:`Model` suitable for querying the CSV data. + + Basic example -- field names and types will be introspected: + + .. code-block:: python + + from playhouse.csv_loader import * + db = SqliteDatabase(':memory:') + User = load_csv(db, 'users.csv') + + Using a pre-defined model: + + .. code-block:: python + + class ZipToTZ(Model): + zip = IntegerField() + timezone = CharField() + + load_csv(ZipToTZ, 'zip_to_tz.csv') + + Specifying fields: + + .. 
code-block:: python + + fields = [DecimalField(), IntegerField(), IntegerField(), DateField()] + field_names = ['amount', 'from_acct', 'to_acct', 'timestamp'] + Payments = load_csv(db, 'payments.csv', fields=fields, field_names=field_names, has_header=False) + + .. _test_utils: Test Utils From 186475cc8f48cf58b38f5fe9266068da97660aa0 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Sat, 9 Nov 2013 11:02:41 -0600 Subject: [PATCH 22/39] Remove with statement import, fixes #245 --- peewee.py | 1 - 1 file changed, 1 deletion(-) diff --git a/peewee.py b/peewee.py index a7ed2f281..d41285631 100644 --- a/peewee.py +++ b/peewee.py @@ -7,7 +7,6 @@ # ///' # // # ' -from __future__ import with_statement import datetime import decimal import logging From a6884c52cecf0ef7a263a9894f0268a3fd01f69a Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Mon, 11 Nov 2013 08:50:54 -0600 Subject: [PATCH 23/39] Handle square brackets in sqlite. --- pwiz.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pwiz.py b/pwiz.py index a86a25a0e..9251f576a 100755 --- a/pwiz.py +++ b/pwiz.py @@ -250,7 +250,7 @@ class SqliteIntrospector(Introspector): 'text': TextField, 'time': TimeField, } - re_foreign_key = '"?(.+?)"?\s+.+\s+references (.*) \(["|]?(.*)["|]?\)' + re_foreign_key = '["\[]?(.+?)["\]]?\s+.+\s+references ["\[]?(.+?)["\]]? 
\(["|\[]?(.+?)["|\]]?\)' re_varchar = r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$' def get_conn_class(self): From 37e6880bc4fe37843baf4131e2c6a248b1d282c7 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Mon, 11 Nov 2013 22:32:51 -0600 Subject: [PATCH 24/39] Removing future in tests.py, refs #245 --- tests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests.py b/tests.py index 2b3164fce..790f8ee29 100644 --- a/tests.py +++ b/tests.py @@ -1,6 +1,5 @@ # encoding=utf-8 -from __future__ import with_statement import datetime import decimal import logging From 869aa2ce8e4136c494e0796d9accf4189554527c Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 13 Nov 2013 18:04:33 -0600 Subject: [PATCH 25/39] Adding django integration. --- playhouse/djpeewee.py | 121 ++++++++++++++++++++++++++++ playhouse/tests_djpeewee.py | 152 ++++++++++++++++++++++++++++++++++++ runtests.py | 4 + 3 files changed, 277 insertions(+) create mode 100644 playhouse/djpeewee.py create mode 100644 playhouse/tests_djpeewee.py diff --git a/playhouse/djpeewee.py b/playhouse/djpeewee.py new file mode 100644 index 000000000..0bbb5e024 --- /dev/null +++ b/playhouse/djpeewee.py @@ -0,0 +1,121 @@ +""" +Simple translation of Django model classes to peewee model classes. +""" +from functools import partial +import logging + +from peewee import * + +logger = logging.getLogger('peewee.playhouse.djpeewee') + +class DjangoTranslator(object): + def __init__(self): + self._field_map = self.get_django_field_map() + + def get_django_field_map(self): + from django.db.models import fields as djf + return [ + (djf.AutoField, PrimaryKeyField), + (djf.BigIntegerField, BigIntegerField), + #(djf.BinaryField, BlobField), + (djf.BooleanField, BooleanField), + (djf.CharField, CharField), + (djf.DateTimeField, DateTimeField), # Extends DateField. 
+ (djf.DateField, DateField), + (djf.DecimalField, DecimalField), + (djf.FilePathField, CharField), + (djf.FloatField, FloatField), + (djf.IntegerField, IntegerField), + (djf.NullBooleanField, partial(BooleanField, null=True)), + (djf.TextField, TextField), + (djf.TimeField, TimeField), + (djf.related.ForeignKey, ForeignKeyField), + ] + + def _translate_model(self, model, mapping): + from django.db.models import fields as djf + options = model._meta + attrs = {} + # Sort fields such that nullable fields appear last. + field_key = lambda field: (field.null and 1 or 0, field) + for model_field in sorted(options.fields, key=field_key): + # Find the appropriate peewee field class. + converted = None + for django_field, peewee_field in self._field_map: + if isinstance(model_field, django_field): + converted = peewee_field + break + + # Special-case ForeignKey fields. + if converted is ForeignKeyField: + related_model = model_field.rel.to + model_name = related_model._meta.object_name + # If we haven't processed the related model yet, do so now. + if model_name not in mapping: + mapping[model_name] = None # Avoid endless recursion. + self._translate_model(related_model, mapping) + if mapping[model_name] is None: + # Put an integer field here. + logger.warn('Cycle detected: %s: %s', + model_field.name, model_name) + mapping[model_name] = IntegerField( + db_column=model_field.get_attname()) + else: + related_name = model_field.related_query_name() + if related_name.endswith('+'): + related_name = '__ignore_%s_%s_%s' % ( + model_field.name, + model_name, + related_name.strip('+')) + + attrs[model_field.name] = ForeignKeyField( + mapping[model_name], + related_name=related_name) + + elif converted: + attrs[model_field.name] = converted() + + klass = type(options.object_name, (Model,), attrs) + klass._meta.db_table = options.db_table + mapping[options.object_name] = klass + + # Load up many-to-many relationships. 
+ for many_to_many in options.many_to_many: + if not isinstance(many_to_many, djf.related.ManyToManyField): + continue + self._translate_model(many_to_many.rel.through, mapping) + + + def translate_models(self, *models): + """ + Generate a group of peewee models analogous to the provided Django models + for the purposes of creating queries. + + :param model: A Django model class. + :returns: A dictionary mapping model names to peewee model classes. + :rtype: dict + + Example:: + + # Map Django models to peewee models. Foreign keys and M2M will be + # traversed as well. + peewee = translate(Account) + + # Generate query using peewee. + PUser = peewee['User'] + PAccount = peewee['Account'] + query = PUser.select().join(PAccount).where(PAccount.acct_type == 'foo') + + # Django raw query. + users = User.objects.raw(*query.sql()) + """ + mapping = {} + for model in models: + self._translate_model(model, mapping) + return mapping + +try: + import django + translate = DjangoTranslator().translate_models +except ImportError: + pass diff --git a/playhouse/tests_djpeewee.py b/playhouse/tests_djpeewee.py new file mode 100644 index 000000000..3f08d04f7 --- /dev/null +++ b/playhouse/tests_djpeewee.py @@ -0,0 +1,152 @@ +import unittest + +from peewee import * +from peewee import print_ +try: + import django +except ImportError: + django = None + + +if django is not None: + from django.conf import settings + settings.configure( + DATABASES={ + 'default': { + 'engine': 'django.db.backends.sqlite3', + 'name': ':memory:'}}, + ) + from django.db import models + from playhouse.djpeewee import translate + + # Django model definitions.
+ class Simple(models.Model): + char_field = models.CharField(max_length=1) + int_field = models.IntegerField() + + class User(models.Model): + username = models.CharField(max_length=255) + + class Meta: + db_table = 'user_tbl' + + class Post(models.Model): + author = models.ForeignKey(User, related_name='posts') + content = models.TextField() + + class Comment(models.Model): + post = models.ForeignKey(Post, related_name='comments') + commenter = models.ForeignKey(User, related_name='comments') + comment = models.TextField() + + class Tag(models.Model): + tag = models.CharField() + posts = models.ManyToManyField(Post) + + class TestDjPeewee(unittest.TestCase): + def assertFields(self, model, expected): + zipped = zip(model._meta.get_fields(), expected) + for (model_field, (name, field_type)) in zipped: + self.assertEqual(model_field.name, name) + self.assertEqual(type(model_field), field_type) + + def test_simple(self): + P = translate(Simple) + self.assertEqual(P.keys(), ['Simple']) + self.assertFields(P['Simple'], [ + ('id', PrimaryKeyField), + ('char_field', CharField), + ('int_field', IntegerField), + ]) + + def test_graph(self): + P = translate(User, Tag, Comment) + self.assertEqual(sorted(P.keys()), [ + 'Comment', + 'Post', + 'Tag', + 'Tag_posts', + 'User']) + + # Test the models that were found. 
+ user = P['User'] + self.assertFields(user, [ + ('id', PrimaryKeyField), + ('username', CharField)]) + self.assertEqual(user.posts.rel_model, P['Post']) + self.assertEqual(user.comments.rel_model, P['Comment']) + + post = P['Post'] + self.assertFields(post, [ + ('id', PrimaryKeyField), + ('author', ForeignKeyField), + ('content', TextField)]) + self.assertEqual(post.comments.rel_model, P['Comment']) + + comment = P['Comment'] + self.assertFields(comment, [ + ('id', PrimaryKeyField), + ('post', ForeignKeyField), + ('commenter', ForeignKeyField), + ('comment', TextField)]) + + tag = P['Tag'] + self.assertFields(tag, [ + ('id', PrimaryKeyField), + ('tag', CharField)]) + + thru = P['Tag_posts'] + self.assertFields(thru, [ + ('id', PrimaryKeyField), + ('tag', ForeignKeyField), + ('post', ForeignKeyField)]) + + def test_fk_query(self): + trans = translate(User, Post, Comment, Tag) + U = trans['User'] + P = trans['Post'] + C = trans['Comment'] + + query = (U.select() + .join(P) + .join(C) + .where(C.comment == 'test')) + sql, params = query.sql() + self.assertEqual( + sql, + 'SELECT t1."id", t1."username" FROM "user_tbl" AS t1 ' + 'INNER JOIN "playhouse_post" AS t2 ' + 'ON (t1."id" = t2."author_id") ' + 'INNER JOIN "playhouse_comment" AS t3 ' + 'ON (t2."id" = t3."post_id") WHERE (t3."comment" = ?)') + self.assertEqual(params, ['test']) + + def test_m2m_query(self): + trans = translate(Post, Tag) + P = trans['Post'] + U = trans['User'] + T = trans['Tag'] + TP = trans['Tag_posts'] + + query = (P.select(P, U) + .join(U) + .switch(P) + .join(TP) + .join(T) + .where(T.tag == 'test')) + sql, params = query.sql() + self.assertEqual( + sql, + 'SELECT t1."id", t1."author_id", t1."content", ' + 't4."id", t4."username" FROM "playhouse_post" AS t1 ' + 'INNER JOIN "user_tbl" AS t4 ' + 'ON (t1."author_id" = t4."id") ' + 'INNER JOIN "playhouse_tag_posts" AS t2 ' + 'ON (t1."id" = t2."post_id") ' + 'INNER JOIN "playhouse_tag" AS t3 ' + 'ON (t2."tag_id" = t3."id") WHERE (t3."tag" = ?)') + 
self.assertEqual(params, ['test']) + + +else: + print_('Skipping djpeewee tests, Django not found.') diff --git a/runtests.py b/runtests.py index 9a1b22e10..954fd601b 100755 --- a/runtests.py +++ b/runtests.py @@ -26,6 +26,7 @@ def get_option_parser(): cases = optparse.OptionGroup(parser, 'Individual test module options') cases.add_option('--apsw', dest='apsw', default=False, action='store_true', help='apsw tests (requires apsw)') cases.add_option('--csv', dest='csv', default=False, action='store_true', help='csv tests') + cases.add_option('--djpeewee', dest='djpeewee', default=False, action='store_true', help='djpeewee tests') cases.add_option('--gfk', dest='gfk', default=False, action='store_true', help='gfk tests') cases.add_option('--kv', dest='kv', default=False, action='store_true', help='key/value store tests') cases.add_option('--migrations', dest='migrations', default=False, action='store_true', help='migration helper tests (requires psycopg2)') @@ -54,6 +55,9 @@ def collect_modules(options): if xtra(options.csv): from playhouse import tests_csv_loader modules.append(tests_csv_loader) + if xtra(options.djpeewee): + from playhouse import tests_djpeewee + modules.append(tests_djpeewee) if xtra(options.gfk): from playhouse import tests_gfk modules.append(tests_gfk) From 6a598a0b5de1ff5842f2df40cee649396f77417a Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 13 Nov 2013 18:37:43 -0600 Subject: [PATCH 26/39] Fix re-defining classes, one new test. 
--- playhouse/djpeewee.py | 7 +++++-- playhouse/tests_djpeewee.py | 35 ++++++++++++++++++++++++++++------- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/playhouse/djpeewee.py b/playhouse/djpeewee.py index 0bbb5e024..690658614 100644 --- a/playhouse/djpeewee.py +++ b/playhouse/djpeewee.py @@ -35,6 +35,9 @@ def get_django_field_map(self): def _translate_model(self, model, mapping): from django.db.models import fields as djf options = model._meta + if options.object_name in mapping: + return + attrs = {} # Sort fields such that nullable fields appear last. field_key = lambda field: (field.null and 1 or 0, field) @@ -63,9 +66,9 @@ def _translate_model(self, model, mapping): else: related_name = model_field.related_query_name() if related_name.endswith('+'): - related_name = '__ignore_%s_%s_%s' % ( + related_name = '__%s:%s:%s' % ( + options, model_field.name, - model_name, related_name.strip('+')) attrs[model_field.name] = ForeignKeyField( diff --git a/playhouse/tests_djpeewee.py b/playhouse/tests_djpeewee.py index 3f08d04f7..306635d0b 100644 --- a/playhouse/tests_djpeewee.py +++ b/playhouse/tests_djpeewee.py @@ -1,3 +1,4 @@ +from datetime import timedelta import unittest from peewee import * @@ -43,6 +44,14 @@ class Tag(models.Model): tag = models.CharField() posts = models.ManyToManyField(Post) + class Event(models.Model): + start_time = models.DateTimeField() + end_time = models.DateTimeField() + title = models.CharField() + + class Meta: + db_table = 'events_tbl' + class TestDjPeewee(unittest.TestCase): def assertFields(self, model, expected): zipped = zip(model._meta.get_fields(), expected) @@ -128,25 +137,37 @@ def test_m2m_query(self): T = trans['Tag'] TP = trans['Tag_posts'] - query = (P.select(P, U) - .join(U) - .switch(P) + query = (P.select() .join(TP) .join(T) .where(T.tag == 'test')) sql, params = query.sql() self.assertEqual( sql, - 'SELECT t1."id", t1."author_id", t1."content", ' - 't4."id", t4."username" FROM "playhouse_post" AS t1 ' - 
'INNER JOIN "user_tbl" AS t4 ' - 'ON (t1."author_id" = t4."id") ' + 'SELECT t1."id", t1."author_id", t1."content" ' + 'FROM "playhouse_post" AS t1 ' 'INNER JOIN "playhouse_tag_posts" AS t2 ' 'ON (t1."id" = t2."post_id") ' 'INNER JOIN "playhouse_tag" AS t3 ' 'ON (t2."tag_id" = t3."id") WHERE (t3."tag" = ?)') self.assertEqual(params, ['test']) + def test_docs_example(self): + # The docs don't lie. + PEvent = translate(Event)['Event'] + hour = timedelta(hours=1) + query = (PEvent + .select() + .where( + (PEvent.end_time - PEvent.start_time) > hour)) + sql, params = query.sql() + self.assertEqual( + sql, + 'SELECT t1."id", t1."start_time", t1."end_time", t1."title" ' + 'FROM "events_tbl" AS t1 ' + 'WHERE ((t1."end_time" - t1."start_time") > ?)') + self.assertEqual(params, [hour]) + else: print_('Skipping djpeewee tests, Django not found.') From 86157f268dc2e421d94efcce3445e4927f50d292 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 13 Nov 2013 18:42:28 -0600 Subject: [PATCH 27/39] Fix for when model is being populated elsewhere in the stack. --- playhouse/djpeewee.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/playhouse/djpeewee.py b/playhouse/djpeewee.py index 690658614..5ac5237bb 100644 --- a/playhouse/djpeewee.py +++ b/playhouse/djpeewee.py @@ -35,7 +35,7 @@ def get_django_field_map(self): def _translate_model(self, model, mapping): from django.db.models import fields as djf options = model._meta - if options.object_name in mapping: + if mapping.get(options.object_name): return attrs = {} From 7a0ee56adc5d89d64b7f8b20bdb7b6a5ac7e5593 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 13 Nov 2013 20:58:45 -0600 Subject: [PATCH 28/39] Adding dotted lookups. 
--- playhouse/djpeewee.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/playhouse/djpeewee.py b/playhouse/djpeewee.py index 5ac5237bb..701326362 100644 --- a/playhouse/djpeewee.py +++ b/playhouse/djpeewee.py @@ -8,6 +8,10 @@ logger = logging.getLogger('peewee.playhouse.djpeewee') +class AttrDict(dict): + def __getattr__(self, attr): + return self[attr] + class DjangoTranslator(object): def __init__(self): self._field_map = self.get_django_field_map() @@ -112,7 +116,7 @@ def translate_models(self, *models): # Django raw query. users = User.objects.raw(*query.sql()) """ - mapping = {} + mapping = AttrDict() for model in models: self._translate_model(model, mapping) return mapping From f664295574bf00884335507d5bf54ba2b8df2227 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 13 Nov 2013 21:14:05 -0600 Subject: [PATCH 29/39] Adding docs on djpeewee --- docs/peewee/playhouse.rst | 102 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/docs/peewee/playhouse.rst b/docs/peewee/playhouse.rst index 0242b622b..5e7f33594 100644 --- a/docs/peewee/playhouse.rst +++ b/docs/peewee/playhouse.rst @@ -16,6 +16,7 @@ specific functionality: Modules which expose higher-level python constructs: +* :ref:`djpeewee` * :ref:`gfk` * :ref:`kv` * :ref:`proxy` @@ -765,6 +766,107 @@ sqlite_ext API notes The above code is exactly what the :py:meth:`match` function provides. +.. _djpeewee: + +Django Integration +------------------ + +The Django ORM provides a very high-level abstraction over SQL and is +`very limited in terms of flexibility or expressiveness `_. I +wrote a `blog post `_ +describing my search for a "missing link" between Django's ORM and the SQL it +generates, concluding that no such layer exists. The ``djpeewee`` module attempts +to provide an easy-to-use, structured layer for generating SQL queries for use +with Django's ORM. 
+ +A couple use-cases might be: + +* Joining on fields that are not related by foreign key (for example UUID fields). +* Performing aggregate queries on calculated values. +* Features that Django does not support such as ``CASE`` statements. +* Utilize SQL functions that Django does not support, such as ``SUBSTR``. +* Replace nearly-identical SQL queries with reusable, composable data-structures. + +Below is an example of how you might use this: + +.. code-block:: python + + # Django model. + class Event(models.Model): + start_time = models.DateTimeField() + end_time = models.DateTimeField() + title = models.CharField(max_length=255) + + # Suppose we want to find all events that are longer than an hour. Django + # does not support this, but we can use peewee. + from playhouse.djpeewee import translate + P = translate(Event) + query = (P.Event + .select() + .where( + (P.Event.end_time - P.Event.start_time) > timedelta(hours=1))) + + # Now feed our peewee query into Django's `raw()` method: + sql, params = query.sql() + Event.objects.raw(sql, params) + +Foreign keys and Many-to-many relationships +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :py:func:`translate` function will recursively traverse the graph of models +and return a dictionary populated with everything it finds. Back-references are +not searched, only explicit foreign keys or many-to-manys. + +Example: + +.. 
code-block:: pycon + + >>> from django.contrib.auth.models import User, Group + >>> from playhouse.djpeewee import translate + >>> translate(User, Group) + {'ContentType': peewee.ContentType, + 'Group': peewee.Group, + 'Group_permissions': peewee.Group_permissions, + 'Permission': peewee.Permission, + 'User': peewee.User, + 'User_groups': peewee.User_groups, + 'User_user_permissions': peewee.User_user_permissions} + +As you can see in the example above, although only `User` and `Group` were passed +in to :py:func:`translate`, several other models which are related by foreign key +were also created. Additionally, the many-to-many "through" tables were created +as separate models since peewee does not abstract away these types of relationships. + +Using the above models it is possible to construct joins. The following example +will get all users who belong to a group that starts with the letter "A": + +.. code-block::pycon + + >>> P = translate(User, Group) + >>> query = P.User.select().join(P.User_groups).join(P.Group).where( + ... fn.Lower(fn.Substr(P.Group.name, 1, 1)) == 'a') + >>> sql, params = query.sql() + >>> print sql # formatted for legibility + SELECT t1."id", t1."password", ... + FROM "auth_user" AS t1 + INNER JOIN "auth_user_groups" AS t2 ON (t1."id" = t2."user_id") + INNER JOIN "auth_group" AS t3 ON (t2."group_id" = t3."id") + WHERE (Lower(Substr(t3."name", ?, ?)) = ?) + +djpeewee API +^^^^^^^^^^^^ + +.. py:function:: translate(*models) + + Translate the given Django models into roughly equivalent peewee models + suitable for use constructing queries. Foreign keys and many-to-many relationships + will be followed and models generated, although back references are not traversed. + + :param models: One or more Django model classes. + :returns: A dict-like object containing the generated models, but which supports + dotted-name style lookups. + + .. 
_gfk: Generic foreign keys From c3d695b265d4adcc6ef9a5d2d0b5c49bedcf6c9a Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 13 Nov 2013 21:16:05 -0600 Subject: [PATCH 30/39] Fix interpolation --- playhouse/djpeewee.py | 1 + playhouse/tests_djpeewee.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/playhouse/djpeewee.py b/playhouse/djpeewee.py index 701326362..034f09bd7 100644 --- a/playhouse/djpeewee.py +++ b/playhouse/djpeewee.py @@ -84,6 +84,7 @@ def _translate_model(self, model, mapping): klass = type(options.object_name, (Model,), attrs) klass._meta.db_table = options.db_table + klass._meta.database.interpolation = '%s' mapping[options.object_name] = klass # Load up many-to-many relationships. diff --git a/playhouse/tests_djpeewee.py b/playhouse/tests_djpeewee.py index 306635d0b..92e0ed109 100644 --- a/playhouse/tests_djpeewee.py +++ b/playhouse/tests_djpeewee.py @@ -127,7 +127,7 @@ def test_fk_query(self): 'INNER JOIN "playhouse_post" AS t2 ' 'ON (t1."id" = t2."author_id") ' 'INNER JOIN "playhouse_comment" AS t3 ' - 'ON (t2."id" = t3."post_id") WHERE (t3."comment" = ?)') + 'ON (t2."id" = t3."post_id") WHERE (t3."comment" = %s)') self.assertEqual(params, ['test']) def test_m2m_query(self): @@ -149,7 +149,7 @@ def test_m2m_query(self): 'INNER JOIN "playhouse_tag_posts" AS t2 ' 'ON (t1."id" = t2."post_id") ' 'INNER JOIN "playhouse_tag" AS t3 ' - 'ON (t2."tag_id" = t3."id") WHERE (t3."tag" = ?)') + 'ON (t2."tag_id" = t3."id") WHERE (t3."tag" = %s)') self.assertEqual(params, ['test']) def test_docs_example(self): @@ -165,7 +165,7 @@ def test_docs_example(self): sql, 'SELECT t1."id", t1."start_time", t1."end_time", t1."title" ' 'FROM "events_tbl" AS t1 ' - 'WHERE ((t1."end_time" - t1."start_time") > ?)') + 'WHERE ((t1."end_time" - t1."start_time") > %s)') self.assertEqual(params, [hour]) From 6421a66f775a557b4384df14537a85f6c96593be Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 13 Nov 2013 21:18:10 -0600 Subject: [PATCH 
31/39] Doc formatting fix. --- docs/peewee/playhouse.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/peewee/playhouse.rst b/docs/peewee/playhouse.rst index 5e7f33594..0f8622ff4 100644 --- a/docs/peewee/playhouse.rst +++ b/docs/peewee/playhouse.rst @@ -840,7 +840,7 @@ as separate models since peewee does not abstract away these types of relationsh Using the above models it is possible to construct joins. The following example will get all users who belong to a group that starts with the letter "A": -.. code-block::pycon +.. code-block:: pycon >>> P = translate(User, Group) >>> query = P.User.select().join(P.User_groups).join(P.Group).where( From 529bd14e694724f0c59202bf2ec9546e35ee0fb6 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 13 Nov 2013 21:19:31 -0600 Subject: [PATCH 32/39] Correct interpolation in docs. --- docs/peewee/playhouse.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/peewee/playhouse.rst b/docs/peewee/playhouse.rst index 0f8622ff4..683973ae6 100644 --- a/docs/peewee/playhouse.rst +++ b/docs/peewee/playhouse.rst @@ -851,7 +851,7 @@ will get all users who belong to a group that starts with the letter "A": FROM "auth_user" AS t1 INNER JOIN "auth_user_groups" AS t2 ON (t1."id" = t2."user_id") INNER JOIN "auth_group" AS t3 ON (t2."group_id" = t3."id") - WHERE (Lower(Substr(t3."name", ?, ?)) = ?) + WHERE (Lower(Substr(t3."name", %s, %s)) = %s) djpeewee API ^^^^^^^^^^^^ From 7ca4beadce8a7d3c99b451c3bd7f18d001ce89bc Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 13 Nov 2013 21:28:37 -0600 Subject: [PATCH 33/39] Better wording. 
--- docs/peewee/playhouse.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/peewee/playhouse.rst b/docs/peewee/playhouse.rst index 683973ae6..9f22f3395 100644 --- a/docs/peewee/playhouse.rst +++ b/docs/peewee/playhouse.rst @@ -771,8 +771,8 @@ sqlite_ext API notes Django Integration ------------------ -The Django ORM provides a very high-level abstraction over SQL and is -`very limited in terms of flexibility or expressiveness `_. I +The Django ORM provides a very high-level abstraction over SQL and as a consequence is in some ways +`limited in terms of flexibility or expressiveness `_. I wrote a `blog post `_ describing my search for a "missing link" between Django's ORM and the SQL it generates, concluding that no such layer exists. The ``djpeewee`` module attempts From 5e4f26c198790994a6409458826748485cf7de8d Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 13 Nov 2013 21:30:03 -0600 Subject: [PATCH 34/39] More docs cleanups. --- docs/peewee/playhouse.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/peewee/playhouse.rst b/docs/peewee/playhouse.rst index 9f22f3395..2a76b4071 100644 --- a/docs/peewee/playhouse.rst +++ b/docs/peewee/playhouse.rst @@ -784,8 +784,8 @@ A couple use-cases might be: * Joining on fields that are not related by foreign key (for example UUID fields). * Performing aggregate queries on calculated values. * Features that Django does not support such as ``CASE`` statements. -* Utilize SQL functions that Django does not support, such as ``SUBSTR``. -* Replace nearly-identical SQL queries with reusable, composable data-structures. +* Utilizing SQL functions that Django does not support, such as ``SUBSTR``. +* Replacing nearly-identical SQL queries with reusable, composable data-structures. 
Below is an example of how you might use this: From f470ef7aaa0b988a4e20ea9591eb4e916ee725c6 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Sun, 17 Nov 2013 07:56:10 -0600 Subject: [PATCH 35/39] Adding more options. --- playhouse/djpeewee.py | 131 +++++++++++++++++++++++++++--------- playhouse/tests_djpeewee.py | 73 +++++++++++++++++++- 2 files changed, 171 insertions(+), 33 deletions(-) diff --git a/playhouse/djpeewee.py b/playhouse/djpeewee.py index 034f09bd7..d4ab8c8c1 100644 --- a/playhouse/djpeewee.py +++ b/playhouse/djpeewee.py @@ -36,48 +36,74 @@ def get_django_field_map(self): (djf.related.ForeignKey, ForeignKeyField), ] - def _translate_model(self, model, mapping): + def convert_field(self, field): + converted = None + for django_field, peewee_field in self._field_map: + if isinstance(field, django_field): + converted = peewee_field + break + return converted + + def _translate_model(self, + model, + mapping, + max_depth=None, + backrefs=False, + exclude=None): + if exclude and model in exclude: + return + + if max_depth is None: + max_depth = -1 + from django.db.models import fields as djf options = model._meta if mapping.get(options.object_name): return + mapping[options.object_name] = None attrs = {} # Sort fields such that nullable fields appear last. field_key = lambda field: (field.null and 1 or 0, field) for model_field in sorted(options.fields, key=field_key): - # Find the appropriate peewee field class. - converted = None - for django_field, peewee_field in self._field_map: - if isinstance(model_field, django_field): - converted = peewee_field - break + # Get peewee equivalent for this field type. + converted = self.convert_field(model_field) # Special-case ForeignKey fields. if converted is ForeignKeyField: - related_model = model_field.rel.to - model_name = related_model._meta.object_name - # If we haven't processed the related model yet, do so now. - if model_name not in mapping: - mapping[model_name] = None # Avoid endless recursion. 
- self._translate_model(related_model, mapping) - if mapping[model_name] is None: - # Put an integer field here. - logger.warn('Cycle detected: %s: %s', - model_field.name, model_name) - mapping[model_name] = IntegerField( - db_column=model_field.get_attname()) - else: - related_name = model_field.related_query_name() - if related_name.endswith('+'): - related_name = '__%s:%s:%s' % ( - options, - model_field.name, - related_name.strip('+')) + if max_depth != 0: + related_model = model_field.rel.to + model_name = related_model._meta.object_name + # If we haven't processed the related model yet, do so now. + if model_name not in mapping: + mapping[model_name] = None # Avoid endless recursion. + self._translate_model( + related_model, + mapping, + max_depth=max_depth - 1, + backrefs=backrefs, + exclude=exclude) + if mapping[model_name] is None: + # Cycle detected, put an integer field here. + logger.warn('Cycle detected: %s: %s', + model_field.name, model_name) + attrs[model_field.name] = IntegerField( + db_column=model_field.get_attname()) + else: + related_name = model_field.related_query_name() + if related_name.endswith('+'): + related_name = '__%s:%s:%s' % ( + options, + model_field.name, + related_name.strip('+')) + + attrs[model_field.name] = ForeignKeyField( + mapping[model_name], + related_name=related_name) - attrs[model_field.name] = ForeignKeyField( - mapping[model_name], - related_name=related_name) + else: + attrs[model_field.name] = IntegerField( + db_column=model_field.get_attname()) elif converted: attrs[model_field.name] = converted() @@ -87,22 +113,46 @@ def _translate_model(self, model, mapping): klass._meta.database.interpolation = '%s' mapping[options.object_name] = klass + if backrefs: + # Follow back-references for foreign keys. 
+ for rel_obj in options.get_all_related_objects(): + if rel_obj.model._meta.object_name in mapping: + continue + self._translate_model( + rel_obj.model, + mapping, + max_depth=max_depth - 1, + backrefs=backrefs, + exclude=exclude) + # Load up many-to-many relationships. for many_to_many in options.many_to_many: if not isinstance(many_to_many, djf.related.ManyToManyField): continue - self._translate_model(many_to_many.rel.through, mapping) + self._translate_model( + many_to_many.rel.through, + mapping, + max_depth=max_depth, # Do not decrement. + backrefs=backrefs, + exclude=exclude) - def translate_models(self, *models): + def translate_models(self, *models, **options): """ Generate a group of peewee models analagous to the provided Django models for the purposes of creating queries. :param model: A Django model class. + :param options: A dictionary of options, see note below. :returns: A dictionary mapping model names to peewee model classes. :rtype: dict + Recognized options: + `recurse`: Follow foreign keys (default: True) + `max_depth`: Max depth to recurse (default: None, unlimited) + `backrefs`: Follow backrefs (default: False) + `exclude`: A list of models to exclude + Example:: # Map Django models to peewee models. 
Foreign keys and M2M will be @@ -118,8 +168,25 @@ def translate_models(self, *models): users = User.objects.raw(*query.sql()) """ mapping = AttrDict() + recurse = options.get('recurse', True) + max_depth = options.get('max_depth', None) + backrefs = options.get('backrefs', False) + exclude = options.get('exclude', None) + if not recurse and max_depth: + raise ValueError('Error, you cannot specify a max_depth when ' + 'recurse=False.') + elif not recurse: + max_depth = 0 + elif recurse and max_depth is None: + max_depth = -1 + for model in models: - self._translate_model(model, mapping) + self._translate_model( + model, + mapping, + max_depth=max_depth, + backrefs=backrefs, + exclude=exclude) return mapping try: diff --git a/playhouse/tests_djpeewee.py b/playhouse/tests_djpeewee.py index 92e0ed109..eeec3970f 100644 --- a/playhouse/tests_djpeewee.py +++ b/playhouse/tests_djpeewee.py @@ -52,12 +52,26 @@ class Event(models.Model): class Meta: db_table = 'events_tbl' + class Category(models.Model): + parent = models.ForeignKey('self', null=True) + + class A(models.Model): + a_field = models.IntegerField() + b = models.ForeignKey('B', null=True, related_name='as') + + class B(models.Model): + a = models.ForeignKey(A, related_name='bs') + + class C(models.Model): + b = models.ForeignKey(B, related_name='cs') + class TestDjPeewee(unittest.TestCase): def assertFields(self, model, expected): + self.assertEqual(len(model._meta.fields), len(expected)) zipped = zip(model._meta.get_fields(), expected) for (model_field, (name, field_type)) in zipped: self.assertEqual(model_field.name, name) - self.assertEqual(type(model_field), field_type) + self.assertTrue(type(model_field) is field_type) def test_simple(self): P = translate(Simple) @@ -168,6 +182,63 @@ def test_docs_example(self): 'WHERE ((t1."end_time" - t1."start_time") > %s)') self.assertEqual(params, [hour]) + def test_self_referential(self): + trans = translate(Category) + self.assertFields(trans['Category'], [ + ('id', 
PrimaryKeyField), + ('parent', IntegerField)]) + + def test_cycle(self): + trans = translate(A) + self.assertFields(trans['A'], [ + ('id', PrimaryKeyField), + ('a_field', IntegerField), + ('b', ForeignKeyField)]) + self.assertFields(trans['B'], [ + ('id', PrimaryKeyField), + ('a', IntegerField)]) + + trans = translate(B) + self.assertFields(trans['A'], [ + ('id', PrimaryKeyField), + ('a_field', IntegerField), + ('b', IntegerField)]) + self.assertFields(trans['B'], [ + ('id', PrimaryKeyField), + ('a', ForeignKeyField)]) + + def test_max_depth(self): + trans = translate(C, max_depth=1) + self.assertFields(trans['C'], [ + ('id', PrimaryKeyField), + ('b', ForeignKeyField)]) + self.assertFields(trans['B'], [ + ('id', PrimaryKeyField), + ('a', IntegerField)]) + + def test_exclude(self): + trans = translate(Comment, exclude=(User,)) + self.assertFields(trans['Post'], [ + ('id', PrimaryKeyField), + ('author', IntegerField), + ('content', TextField)]) + self.assertEqual( + trans['Post'].comments.rel_model, + trans['Comment']) + + self.assertFields(trans['Comment'], [ + ('id', PrimaryKeyField), + ('post', ForeignKeyField), + ('commenter', IntegerField), + ('comment', TextField)]) + + def test_backrefs(self): + trans = translate(User, backrefs=True) + self.assertEqual(sorted(trans.keys()), [ + 'Comment', + 'Post', + 'User']) + else: print_('Skipping djpeewee tests, Django not found.') From 304c054c9927728afa0e3dd0879e650c79372061 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Tue, 19 Nov 2013 04:56:36 -0600 Subject: [PATCH 36/39] Adding parameter to preserve limit/offset when wrapping count. 
--- peewee.py | 5 +++-- tests.py | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/peewee.py b/peewee.py index d41285631..2386dabdd 100644 --- a/peewee.py +++ b/peewee.py @@ -1657,9 +1657,10 @@ def count(self): # defaults to a count() of the primary key return self.aggregate(convert=False) or 0 - def wrapped_count(self): + def wrapped_count(self, clear_limit=True): clone = self.order_by() - clone._limit = clone._offset = None + if clear_limit: + clone._limit = clone._offset = None sql, params = clone.sql() wrapped = 'SELECT COUNT(1) FROM (%s) AS wrapped_select' % sql diff --git a/tests.py b/tests.py index 790f8ee29..dd019b0b6 100644 --- a/tests.py +++ b/tests.py @@ -1754,6 +1754,9 @@ def test_counting(self): uc = User.select().where(User.username == 'u1').join(Blog).distinct().count() self.assertEqual(uc, 1) + self.assertEqual( + User.select().limit(1).wrapped_count(clear_limit=False), 1) + def test_ordering(self): u1 = User.create(username='u1') u2 = User.create(username='u2') From 37d7d2aeb67285c689992356860f90b7094e8b4c Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Tue, 19 Nov 2013 05:33:49 -0600 Subject: [PATCH 37/39] Adding more docs. --- docs/peewee/playhouse.rst | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/peewee/playhouse.rst b/docs/peewee/playhouse.rst index 2a76b4071..399b17ba7 100644 --- a/docs/peewee/playhouse.rst +++ b/docs/peewee/playhouse.rst @@ -815,7 +815,7 @@ Foreign keys and Many-to-many relationships The :py:func:`translate` function will recursively traverse the graph of models and return a dictionary populated with everything it finds. Back-references are -not searched, only explicit foreign keys or many-to-manys. +not searched by default, but can be included by specifying ``backrefs=True``. Example: @@ -856,16 +856,24 @@ will get all users who belong to a group that starts with the letter "A": djpeewee API ^^^^^^^^^^^^ -.. py:function:: translate(*models) +.. 
py:function:: translate(*models, **options) Translate the given Django models into roughly equivalent peewee models suitable for use constructing queries. Foreign keys and many-to-many relationships will be followed and models generated, although back references are not traversed. :param models: One or more Django model classes. + :param options: A dictionary of options, see note below. :returns: A dict-like object containing the generated models, but which supports dotted-name style lookups. + The following are valid options: + + * ``recurse``: Follow foreign keys and many to many (default: ``True``). + * ``max_depth``: Maximum depth to recurse (default: ``None``, unlimited). + * ``backrefs``: Follow backrefs (default: ``False``). + * ``exclude``: A list of models to exclude. + .. _gfk: From 939739c2cc0f52df091d343e492122fd9e716c9b Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Tue, 19 Nov 2013 05:57:40 -0600 Subject: [PATCH 38/39] Small changes. --- playhouse/djpeewee.py | 3 ++- playhouse/tests_djpeewee.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/playhouse/djpeewee.py b/playhouse/djpeewee.py index d4ab8c8c1..24d2d309b 100644 --- a/playhouse/djpeewee.py +++ b/playhouse/djpeewee.py @@ -90,7 +90,8 @@ def _translate_model(self, attrs[model_field.name] = IntegerField( db_column=model_field.get_attname()) else: - related_name = model_field.related_query_name() + related_name = (model_field.rel.related_name or + model_field.related_query_name()) if related_name.endswith('+'): related_name = '__%s:%s:%s' % ( options, diff --git a/playhouse/tests_djpeewee.py b/playhouse/tests_djpeewee.py index eeec3970f..58c531c55 100644 --- a/playhouse/tests_djpeewee.py +++ b/playhouse/tests_djpeewee.py @@ -65,6 +65,13 @@ class B(models.Model): class C(models.Model): b = models.ForeignKey(B, related_name='cs') + class Parent(models.Model): + pass + + class Child(Parent): + pass + + class TestDjPeewee(unittest.TestCase): def 
assertFields(self, model, expected): self.assertEqual(len(model._meta.fields), len(expected)) @@ -239,6 +246,18 @@ def test_backrefs(self): 'Post', 'User']) + def test_inheritance(self): + trans = translate(Parent) + self.assertEqual(trans.keys(), ['Parent']) + self.assertFields(trans['Parent'], [ + ('id', PrimaryKeyField),]) + + trans = translate(Child) + self.assertEqual(sorted(trans.keys()), ['Child', 'Parent']) + self.assertFields(trans['Child'], [ + ('id', PrimaryKeyField), + ('parent_ptr', ForeignKeyField)]) + else: print_('Skipping djpeewee tests, Django not found.') From a7cb984163dbd78f064510aad5f49f3aeac58242 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Tue, 19 Nov 2013 05:58:13 -0600 Subject: [PATCH 39/39] Version bump. --- docs/conf.py | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index a1280467c..3fd57b360 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -50,9 +50,9 @@ # built documents. # # The short X.Y version. -version = '2.1.5' +version = '2.1.6' # The full version, including alpha/beta/rc tags. -release = '2.1.5' +release = '2.1.6' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/setup.py b/setup.py index e97860b7c..c4a3c3223 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( name='peewee', - version="2.1.5", + version='2.1.6', description='a little orm', long_description=readme, author='Charles Leifer',