Skip to content

Commit

Permalink
feat: support for async sessions (#350)
Browse files Browse the repository at this point in the history
* feat(async): add support for async sessions
This PR brings experimental support for async sessions in SQLAlchemyConnectionFields. Batching is not yet supported and will be subject to a later PR.
Co-authored-by: Jendrik <[email protected]>
Co-authored-by: Erik Wrede <[email protected]>
  • Loading branch information
jendrikjoe committed Dec 21, 2022
1 parent 2edeae9 commit 32d0d18
Show file tree
Hide file tree
Showing 18 changed files with 931 additions and 332 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: Tests

on:
on:
push:
branches:
- 'master'
Expand Down
66 changes: 52 additions & 14 deletions docs/inheritance.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Inheritance Examples

Create interfaces from inheritance relationships
------------------------------------------------

.. note:: If you're using `AsyncSession`, please check the chapter `Eager Loading & Using with AsyncSession`_.
SQLAlchemy has excellent support for class inheritance hierarchies.
These hierarchies can be represented in your GraphQL schema by means
of interfaces_. Much like ObjectTypes, Interfaces in
Expand Down Expand Up @@ -40,7 +40,7 @@ from the attributes of their underlying SQLAlchemy model:
__mapper_args__ = {
"polymorphic_identity": "employee",
}
class Customer(Person):
first_purchase_date = Column(Date())
Expand All @@ -56,17 +56,17 @@ from the attributes of their underlying SQLAlchemy model:
class Meta:
model = Employee
interfaces = (relay.Node, PersonType)
class CustomerType(SQLAlchemyObjectType):
class Meta:
model = Customer
interfaces = (relay.Node, PersonType)
Keep in mind that `PersonType` is a `SQLAlchemyInterface`. Interfaces must
be linked to an abstract Model that does not specify a `polymorphic_identity`,
because we cannot return instances of interfaces from a GraphQL query.
If Person specified a `polymorphic_identity`, instances of Person could
be inserted into and returned by the database, potentially causing
Keep in mind that `PersonType` is a `SQLAlchemyInterface`. Interfaces must
be linked to an abstract Model that does not specify a `polymorphic_identity`,
because we cannot return instances of interfaces from a GraphQL query.
If Person specified a `polymorphic_identity`, instances of Person could
be inserted into and returned by the database, potentially causing
Persons to be returned to the resolvers.

When querying on the base type, you can refer directly to common fields,
Expand All @@ -85,15 +85,19 @@ and fields on concrete implementations using the `... on` syntax:
firstPurchaseDate
}
}
.. danger::
When using joined table inheritance, this style of querying may lead to unbatched implicit IO with negative performance implications.
See the chapter `Eager Loading & Using with AsyncSession`_ for more information on eager loading all possible types of a `SQLAlchemyInterface`.

Please note that by default, the "polymorphic_on" column is *not*
generated as a field on types that use polymorphic inheritance, as
this is considered an implentation detail. The idiomatic way to
this is considered an implementation detail. The idiomatic way to
retrieve the concrete GraphQL type of an object is to query for the
`__typename` field.
`__typename` field.
To override this behavior, an `ORMField` needs to be created
for the custom type field on the corresponding `SQLAlchemyInterface`. This is *not recommended*
for the custom type field on the corresponding `SQLAlchemyInterface`. This is *not recommended*
as it promotes abiguous schema design

If your SQLAlchemy model only specifies a relationship to the
Expand All @@ -103,5 +107,39 @@ class to the Schema constructor via the `types=` argument:
.. code:: python
schema = graphene.Schema(..., types=[PersonType, EmployeeType, CustomerType])
See also: `Graphene Interfaces <https://docs.graphene-python.org/en/latest/types/interfaces/>`_

Eager Loading & Using with AsyncSession
--------------------
When querying the base type in multi-table inheritance or joined table inheritance, you can only directly refer to polymorphic fields when they are loaded eagerly.
This restricting is in place because AsyncSessions don't allow implicit async operations such as the loads of the joined tables.
To load the polymorphic fields eagerly, you can use the `with_polymorphic` attribute of the mapper args in the base model:

.. code:: python
class Person(Base):
id = Column(Integer(), primary_key=True)
type = Column(String())
name = Column(String())
birth_date = Column(Date())
__tablename__ = "person"
__mapper_args__ = {
"polymorphic_on": type,
"with_polymorphic": "*", # needed for eager loading in async session
}
Alternatively, the specific polymorphic fields can be loaded explicitly in resolvers:

.. code:: python
class Query(graphene.ObjectType):
people = graphene.Field(graphene.List(PersonType))
async def resolve_people(self, _info):
return (await session.scalars(with_polymorphic(Person, [Engineer, Customer]))).all()
Dynamic batching of the types based on the query to avoid eager is currently not supported, but could be implemented in a future PR.

For more information on loading techniques for polymorphic models, please check out the `SQLAlchemy docs <https://docs.sqlalchemy.org/en/20/orm/queryguide/inheritance.html>`_.
13 changes: 6 additions & 7 deletions graphene_sqlalchemy/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlalchemy.orm import Session, strategies
from sqlalchemy.orm.query import QueryContext

from .utils import is_graphene_version_less_than, is_sqlalchemy_version_less_than
from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, is_graphene_version_less_than


def get_data_loader_impl() -> Any: # pragma: no cover
Expand Down Expand Up @@ -71,19 +71,19 @@ async def batch_load_fn(self, parents):

# For our purposes, the query_context will only used to get the session
query_context = None
if is_sqlalchemy_version_less_than("1.4"):
query_context = QueryContext(session.query(parent_mapper.entity))
else:
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
parent_mapper_query = session.query(parent_mapper.entity)
query_context = parent_mapper_query._compile_context()

if is_sqlalchemy_version_less_than("1.4"):
else:
query_context = QueryContext(session.query(parent_mapper.entity))
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
self.selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper,
None,
)
else:
self.selectin_loader._load_for_path(
Expand All @@ -92,7 +92,6 @@ async def batch_load_fn(self, parents):
states,
None,
child_mapper,
None,
)
return [getattr(parent, self.relationship_prop.key) for parent in parents]

Expand Down
50 changes: 47 additions & 3 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from graphql_relay import connection_from_array_slice

from .batching import get_batch_resolver
from .utils import EnumValue, get_query
from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, EnumValue, get_query, get_session

if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
from sqlalchemy.ext.asyncio import AsyncSession


class SQLAlchemyConnectionField(ConnectionField):
Expand Down Expand Up @@ -81,8 +84,49 @@ def get_query(cls, model, info, sort=None, **args):

@classmethod
def resolve_connection(cls, connection_type, model, info, args, resolved):
session = get_session(info.context)
if resolved is None:
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4 and isinstance(session, AsyncSession):

async def get_result():
return await cls.resolve_connection_async(
connection_type, model, info, args, resolved
)

return get_result()

else:
resolved = cls.get_query(model, info, **args)
if isinstance(resolved, Query):
_len = resolved.count()
else:
_len = len(resolved)

def adjusted_connection_adapter(edges, pageInfo):
return connection_adapter(connection_type, edges, pageInfo)

connection = connection_from_array_slice(
array_slice=resolved,
args=args,
slice_start=0,
array_length=_len,
array_slice_length=_len,
connection_type=adjusted_connection_adapter,
edge_type=connection_type.Edge,
page_info_type=page_info_adapter,
)
connection.iterable = resolved
connection.length = _len
return connection

@classmethod
async def resolve_connection_async(
cls, connection_type, model, info, args, resolved
):
session = get_session(info.context)
if resolved is None:
resolved = cls.get_query(model, info, **args)
query = cls.get_query(model, info, **args)
resolved = (await session.scalars(query)).all()
if isinstance(resolved, Query):
_len = resolved.count()
else:
Expand Down Expand Up @@ -179,7 +223,7 @@ def from_relationship(cls, relationship, registry, **field_kwargs):
return cls(
model_type.connection,
resolver=get_batch_resolver(relationship),
**field_kwargs
**field_kwargs,
)


Expand Down
48 changes: 41 additions & 7 deletions graphene_sqlalchemy/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import pytest
import pytest_asyncio
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

import graphene
from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4

from ..converter import convert_sqlalchemy_composite
from ..registry import reset_global_registry
from .models import Base, CompositeFullName

test_db_url = "sqlite:https://" # use in-memory database for tests
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine


@pytest.fixture(autouse=True)
Expand All @@ -22,18 +25,49 @@ def convert_composite_class(composite, registry):
return graphene.Field(graphene.Int)


@pytest.fixture(scope="function")
def session_factory():
engine = create_engine(test_db_url)
Base.metadata.create_all(engine)
@pytest.fixture(params=[False, True])
def async_session(request):
return request.param


@pytest.fixture
def test_db_url(async_session: bool):
if async_session:
return "sqlite+aiosqlite:https://"
else:
return "sqlite:https://"

yield sessionmaker(bind=engine)

@pytest.mark.asyncio
@pytest_asyncio.fixture(scope="function")
async def session_factory(async_session: bool, test_db_url: str):
if async_session:
if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
pytest.skip("Async Sessions only work in sql alchemy 1.4 and above")
engine = create_async_engine(test_db_url)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)
await engine.dispose()
else:
engine = create_engine(test_db_url)
Base.metadata.create_all(engine)
yield sessionmaker(bind=engine, expire_on_commit=False)
# SQLite in-memory db is deleted when its connection is closed.
# https://www.sqlite.org/inmemorydb.html
engine.dispose()


@pytest_asyncio.fixture(scope="function")
async def sync_session_factory():
engine = create_engine("sqlite:https://")
Base.metadata.create_all(engine)
yield sessionmaker(bind=engine, expire_on_commit=False)
# SQLite in-memory db is deleted when its connection is closed.
# https://www.sqlite.org/inmemorydb.html
engine.dispose()


@pytest.fixture(scope="function")
@pytest_asyncio.fixture(scope="function")
def session(session_factory):
return session_factory()
17 changes: 13 additions & 4 deletions graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import column_property, composite, mapper, relationship
from sqlalchemy.orm import backref, column_property, composite, mapper, relationship

PetKind = Enum("cat", "dog", name="pet_kind")

Expand Down Expand Up @@ -76,10 +76,16 @@ class Reporter(Base):
email = Column(String(), doc="Email")
favorite_pet_kind = Column(PetKind)
pets = relationship(
"Pet", secondary=association_table, backref="reporters", order_by="Pet.id"
"Pet",
secondary=association_table,
backref="reporters",
order_by="Pet.id",
lazy="selectin",
)
articles = relationship("Article", backref="reporter")
favorite_article = relationship("Article", uselist=False)
articles = relationship(
"Article", backref=backref("reporter", lazy="selectin"), lazy="selectin"
)
favorite_article = relationship("Article", uselist=False, lazy="selectin")

@hybrid_property
def hybrid_prop_with_doc(self):
Expand Down Expand Up @@ -304,8 +310,10 @@ class Person(Base):
__tablename__ = "person"
__mapper_args__ = {
"polymorphic_on": type,
"with_polymorphic": "*", # needed for eager loading in async session
}


class NonAbstractPerson(Base):
id = Column(Integer(), primary_key=True)
type = Column(String())
Expand All @@ -318,6 +326,7 @@ class NonAbstractPerson(Base):
"polymorphic_identity": "person",
}


class Employee(Person):
hire_date = Column(Date())

Expand Down
Loading

0 comments on commit 32d0d18

Please sign in to comment.