Skip to content

Commit

Permalink
feat: Support Sorting in Batch ConnectionFields & Deprecate UnsortedC…
Browse files Browse the repository at this point in the history
…onnectionField(#355)

* Enable sorting when batching is enabled

* Deprecate UnsortedSQLAlchemyConnectionField and resetting RelationshipLoader between queries

* Use field_name instead of column.key to build sort enum names to ensure the enum will get the actual field_name

* Adjust batching test to honor the different select-in query structure in sqla1.2

* Ensure that UnsortedSQLAlchemyConnectionField skips sort argument if it gets passed.

* add test for batch sorting with custom ormfield

Co-authored-by: Sabar Dasgupta <[email protected]>
  • Loading branch information
PaulSchweizer and sabard committed Sep 9, 2022
1 parent bb7af4b commit 43df4eb
Show file tree
Hide file tree
Showing 6 changed files with 534 additions and 257 deletions.
178 changes: 103 additions & 75 deletions graphene_sqlalchemy/batching.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The dataloader uses "select in loading" strategy to load related entities."""
from typing import Any
from asyncio import get_event_loop
from typing import Any, Dict

import aiodataloader
import sqlalchemy
Expand All @@ -10,6 +11,90 @@
is_sqlalchemy_version_less_than)


class RelationshipLoader(aiodataloader.DataLoader):
    """Dataloader that loads one relationship for a whole batch of parents."""

    # Class-level `cache` setting of the DataLoader base class, kept switched
    # off here (NOTE(review): exact semantics are defined by aiodataloader).
    cache = False

    def __init__(self, relationship_prop, selectin_loader):
        """Remember the relationship property and the `selectin` loader for it.

        Args:
            relationship_prop: The relationship property whose targets are
                loaded for each parent.
            selectin_loader: The loader whose `_load_for_path` is invoked by
                `batch_load_fn`.
        """
        super().__init__()
        self.relationship_prop = relationship_prop
        self.selectin_loader = selectin_loader

    async def batch_load_fn(self, parents):
        """
        Batch load the relationship of all the parents as one SQL statement.

        SQLAlchemy offers no out-of-the-box way to do this, so we piggyback
        on some internal APIs of the `selectin` eager loading strategy.
        It is a bit hacky, but preferable to re-implementing and maintaining
        a big chunk of the `selectin` loader logic ourselves.

        The idea is to behave like a regular query that selects the parent
        and `selectin` loads the relationship. Such a query would emit two
        `SELECT` statements when calling `all()`; here the first `SELECT`
        is skipped and we jump in right before the `selectin` loader is
        called. To do so, the objects that are normally built during the
        first part of the query are constructed by hand so that
        `SelectInLoader._load_for_path` can be called directly.

        TODO Move this logic to a util in the SQLAlchemy repo as per
        SQLAlchemy's main maintainer suggestion.
        See https://git.io/JewQ7

        Args:
            parents: The parent instances whose relationship is loaded.
                The session is taken from the first one.

        Returns:
            A list holding, for each parent and in the same order, the value
            of the relationship attribute.
        """
        prop = self.relationship_prop
        child_mapper = prop.mapper
        parent_mapper = prop.parent
        session = Session.object_session(parents[0])

        # Sanity checks; these issues are very unlikely to happen in practice.
        # (A check that `parent.__mapper__ is parent_mapper` is deliberately
        # not enforced.)
        for parent in parents:
            # Every instance has to belong to the same session.
            assert session is Session.object_session(parent)
            # `selectin` behavior is undefined for dirty parents.
            assert parent not in session.dirty

        # Open question kept from the original author: should the boolean be
        # False instead, and does it matter for our purposes?
        states = [(sqlalchemy.inspect(instance), True) for instance in parents]

        # The query context is only needed so the loader can reach the session.
        # How it is built, and the arity of `_load_for_path`, depend on the
        # installed SQLAlchemy version.
        legacy_sqlalchemy = is_sqlalchemy_version_less_than('1.4')
        parent_query = session.query(parent_mapper.entity)
        if legacy_sqlalchemy:
            query_context = QueryContext(parent_query)
            trailing_args = ()
        else:
            query_context = parent_query._compile_context()
            trailing_args = (None,)

        self.selectin_loader._load_for_path(
            query_context,
            parent_mapper._path_registry,
            states,
            None,
            child_mapper,
            *trailing_args,
        )
        return [getattr(instance, prop.key) for instance in parents]


# Module-level cache holding one RelationshipLoader per relationship
# property, kept across `batch_load_fn` calls.
# This is so SQL string generation is cached under-the-hood via `bakery`.
RELATIONSHIP_LOADERS_CACHE: Dict[
    sqlalchemy.orm.relationships.RelationshipProperty, RelationshipLoader
] = {}


def get_data_loader_impl() -> Any: # pragma: no cover
"""Graphene >= 3.1.1 ships a copy of aiodataloader with minor fixes. To preserve backward-compatibility,
aiodataloader is used in conjunction with older versions of graphene"""
Expand All @@ -25,80 +110,23 @@ def get_data_loader_impl() -> Any: # pragma: no cover


def get_batch_resolver(relationship_prop):
# Cache this across `batch_load_fn` calls
# This is so SQL string generation is cached under-the-hood via `bakery`
selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),))

class RelationshipLoader(aiodataloader.DataLoader):
cache = False

async def batch_load_fn(self, parents):
"""
Batch loads the relationships of all the parents as one SQL statement.
There is no way to do this out-of-the-box with SQLAlchemy but
we can piggyback on some internal APIs of the `selectin`
eager loading strategy. It's a bit hacky but it's preferable
than re-implementing and maintainnig a big chunk of the `selectin`
loader logic ourselves.
The approach here is to build a regular query that
selects the parent and `selectin` load the relationship.
But instead of having the query emits 2 `SELECT` statements
when callling `all()`, we skip the first `SELECT` statement
and jump right before the `selectin` loader is called.
To accomplish this, we have to construct objects that are
normally built in the first part of the query in order
to call directly `SelectInLoader._load_for_path`.
TODO Move this logic to a util in the SQLAlchemy repo as per
SQLAlchemy's main maitainer suggestion.
See https://git.io/JewQ7
"""
child_mapper = relationship_prop.mapper
parent_mapper = relationship_prop.parent
session = Session.object_session(parents[0])

# These issues are very unlikely to happen in practice...
for parent in parents:
# assert parent.__mapper__ is parent_mapper
# All instances must share the same session
assert session is Session.object_session(parent)
# The behavior of `selectin` is undefined if the parent is dirty
assert parent not in session.dirty

# Should the boolean be set to False? Does it matter for our purposes?
states = [(sqlalchemy.inspect(parent), True) for parent in parents]

# For our purposes, the query_context will only used to get the session
query_context = None
if is_sqlalchemy_version_less_than('1.4'):
query_context = QueryContext(session.query(parent_mapper.entity))
else:
parent_mapper_query = session.query(parent_mapper.entity)
query_context = parent_mapper_query._compile_context()

if is_sqlalchemy_version_less_than('1.4'):
selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper
)
else:
selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper,
None
)

return [getattr(parent, relationship_prop.key) for parent in parents]

loader = RelationshipLoader()
"""Get the resolve function for the given relationship."""

def _get_loader(relationship_prop):
    """Return the loader cached for the given relationship, creating it if needed.

    A fresh loader is built and stored in `RELATIONSHIP_LOADERS_CACHE` when
    nothing is cached for `relationship_prop` yet, or when the cached
    loader's `loop` differs from the loop returned by `get_event_loop()`.
    """
    cached = RELATIONSHIP_LOADERS_CACHE.get(relationship_prop)
    needs_new_loader = cached is None or cached.loop != get_event_loop()
    if not needs_new_loader:
        return cached

    fresh = RelationshipLoader(
        relationship_prop=relationship_prop,
        selectin_loader=strategies.SelectInLoader(
            relationship_prop, (('lazy', 'selectin'),)
        ),
    )
    RELATIONSHIP_LOADERS_CACHE[relationship_prop] = fresh
    return fresh

loader = _get_loader(relationship_prop)

async def resolve(root, info, **args):
return await loader.load(root)
Expand Down
4 changes: 2 additions & 2 deletions graphene_sqlalchemy/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ def sort_enum_for_object_type(
column = orm_field.columns[0]
if only_indexed and not (column.primary_key or column.index):
continue
asc_name = get_name(column.key, True)
asc_name = get_name(field_name, True)
asc_value = EnumValue(asc_name, column.asc())
desc_name = get_name(column.key, False)
desc_name = get_name(field_name, False)
desc_value = EnumValue(desc_name, column.desc())
if column.primary_key:
default.append(asc_value)
Expand Down
116 changes: 69 additions & 47 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .utils import EnumValue, get_query


class UnsortedSQLAlchemyConnectionField(ConnectionField):
class SQLAlchemyConnectionField(ConnectionField):
@property
def type(self):
from .types import SQLAlchemyObjectType
Expand All @@ -37,13 +37,45 @@ def type(self):
)
return nullable_type.connection

def __init__(self, type_, *args, **kwargs):
    """Set up the field, adding a default `sort` argument when none is given.

    * `sort` passed as None: the key is dropped, so no sort argument is
      created.
    * `sort` not passed and `type_` resolves to a `Connection` subclass:
      the sort argument is taken from the connection's node type.
    * Anything else is handed to the super class untouched.

    Raises:
        TypeError: If the default sort argument cannot be derived from the
            connection's node type.
    """
    nullable_type = get_nullable_type(type_)
    if "sort" in kwargs:
        if kwargs["sort"] is None:
            # An explicit None disables the sort argument altogether.
            del kwargs["sort"]
    elif nullable_type and issubclass(nullable_type, Connection):
        # When the type is not a Connection, the super class is left to raise.
        try:
            kwargs["sort"] = nullable_type.Edge.node._type.sort_argument()
        except (AttributeError, TypeError):
            raise TypeError(
                'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
                " to None to disabling the creation of the sort query argument".format(
                    nullable_type.__name__
                )
            )
    super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs)

@property
def model(self):
    """Model attached to the node type of this field's connection type."""
    connection_type = get_nullable_type(self.type)
    node_type = connection_type._meta.node
    return node_type._meta.model

@classmethod
def get_query(cls, model, info, **args):
return get_query(model, info.context)
def get_query(cls, model, info, sort=None, **args):
    """Build the query for `model`, ordered by `sort` when one is given.

    Args:
        model: Model to query.
        info: Resolve info; its `context` is used to obtain the query.
        sort: A single sort item or a list of them. Enum members and
            `EnumValue` instances are unwrapped; any other item is passed
            to `order_by` as it is.
        **args: Remaining arguments, unused here.

    Returns:
        The query, with `order_by` applied unless `sort` is None.
    """
    query = get_query(model, info.context)
    if sort is None:
        return query

    sort_items = sort if isinstance(sort, list) else [sort]

    def _unwrap(item):
        # Ensure consistent handling of graphene Enums, enum values and
        # plain strings.
        if isinstance(item, enum.Enum):
            return item.value.value
        if isinstance(item, EnumValue):
            return item.value
        return item

    order_clauses = [_unwrap(item) for item in sort_items]
    return query.order_by(*order_clauses)

@classmethod
def resolve_connection(cls, connection_type, model, info, args, resolved):
Expand Down Expand Up @@ -90,59 +122,49 @@ def wrap_resolve(self, parent_resolver):
)


# TODO Rename this to SortableSQLAlchemyConnectionField
class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
# TODO Remove in next major version
class UnsortedSQLAlchemyConnectionField(SQLAlchemyConnectionField):
def __init__(self, type_, *args, **kwargs):
nullable_type = get_nullable_type(type_)
if "sort" not in kwargs and issubclass(nullable_type, Connection):
# Let super class raise if type is not a Connection
try:
kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument())
except (AttributeError, TypeError):
raise TypeError(
'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
" to None to disabling the creation of the sort query argument".format(
nullable_type.__name__
)
)
elif "sort" in kwargs and kwargs["sort"] is None:
del kwargs["sort"]
super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs)

@classmethod
def get_query(cls, model, info, sort=None, **args):
query = get_query(model, info.context)
if sort is not None:
if not isinstance(sort, list):
sort = [sort]
sort_args = []
# ensure consistent handling of graphene Enums, enum values and
# plain strings
for item in sort:
if isinstance(item, enum.Enum):
sort_args.append(item.value.value)
elif isinstance(item, EnumValue):
sort_args.append(item.value)
else:
sort_args.append(item)
query = query.order_by(*sort_args)
return query
if "sort" in kwargs and kwargs["sort"] is not None:
warnings.warn(
"UnsortedSQLAlchemyConnectionField does not support sorting. "
"All sorting arguments will be ignored."
)
kwargs["sort"] = None
warnings.warn(
"UnsortedSQLAlchemyConnectionField is deprecated and will be removed in the next "
"major version. Use SQLAlchemyConnectionField instead and either don't "
"provide the `sort` argument or set it to None if you do not want sorting.",
DeprecationWarning,
)
super(UnsortedSQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs)


class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
class BatchSQLAlchemyConnectionField(SQLAlchemyConnectionField):
"""
This is currently experimental.
The API and behavior may change in future versions.
Use at your own risk.
"""

def wrap_resolve(self, parent_resolver):
return partial(
self.connection_resolver,
self.resolver,
get_nullable_type(self.type),
self.model,
)
@classmethod
def connection_resolver(cls, resolver, connection_type, model, root, info, **args):
    """Resolve the connection, going through the batch resolver when there is a root.

    Without a root object the plain `resolver` is called. Otherwise the
    first relationship of the root's mapper that targets `model` is looked
    up and resolved through `get_batch_resolver`. In both cases the result
    is passed on to `cls.resolve_connection`, after the result settles
    if it is thenable.
    """
    if root is None:
        resolved = resolver(root, info, **args)
        on_resolve = partial(cls.resolve_connection, connection_type, model, info, args)
    else:
        # First relationship pointing at `model`; None when there is none.
        relationship_prop = next(
            (
                candidate
                for candidate in root.__class__.__mapper__.relationships
                if candidate.mapper.class_ == model
            ),
            None,
        )
        resolved = get_batch_resolver(relationship_prop)(root, info, **args)
        # Note that `root` (not `model`) is forwarded in this branch.
        on_resolve = partial(cls.resolve_connection, connection_type, root, info, args)

    if not is_thenable(resolved):
        return on_resolve(resolved)
    return Promise.resolve(resolved).then(on_resolve)

@classmethod
def from_relationship(cls, relationship, registry, **field_kwargs):
Expand Down
18 changes: 18 additions & 0 deletions graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,24 @@ class Article(Base):
headline = Column(String(100))
pub_date = Column(Date())
reporter_id = Column(Integer(), ForeignKey("reporters.id"))
readers = relationship(
"Reader", secondary="articles_readers", back_populates="articles"
)


class Reader(Base):
    """Test model mapped to the `readers` table.

    Related to `Article` through the `articles_readers` secondary table.
    """
    __tablename__ = "readers"
    id = Column(Integer(), primary_key=True)
    name = Column(String(100))
    # Counterpart of `Article.readers` (paired via `back_populates`).
    articles = relationship(
        "Article", secondary="articles_readers", back_populates="readers"
    )


class ArticleReader(Base):
    """Test model mapped to the `articles_readers` table.

    Each row references one article and one reader; both foreign-key
    columns are flagged as primary key.
    """
    __tablename__ = "articles_readers"
    article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True)
    reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True)


class ReflectedEditor(type):
Expand Down
Loading

0 comments on commit 43df4eb

Please sign in to comment.