Skip to content

Commit

Permalink
[Regression Fix] Call custom resolve functions if provided (#241)
Browse files Browse the repository at this point in the history
Fixes issue #234
  • Loading branch information
jnak committed Aug 14, 2019
1 parent c89cf80 commit a361c52
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 27 deletions.
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .fields import SQLAlchemyConnectionField
from .utils import get_query, get_session

__version__ = "2.2.1"
__version__ = "2.2.2"

__all__ = [
"__version__",
Expand Down
18 changes: 7 additions & 11 deletions graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@
ChoiceType = JSONType = ScalarListType = TSVectorType = object


def _get_attr_resolver(attr_name):
return lambda root, _info: getattr(root, attr_name, None)


def get_column_doc(column):
return getattr(column, "doc", None)

Expand All @@ -28,7 +24,7 @@ def is_column_nullable(column):
return bool(getattr(column, "nullable", True))


def convert_sqlalchemy_relationship(relationship_prop, registry, connection_field_factory, **field_kwargs):
def convert_sqlalchemy_relationship(relationship_prop, registry, connection_field_factory, resolver, **field_kwargs):
direction = relationship_prop.direction
model = relationship_prop.mapper.entity

Expand All @@ -40,7 +36,7 @@ def dynamic_type():
if direction == interfaces.MANYTOONE or not relationship_prop.uselist:
return Field(
_type,
resolver=_get_attr_resolver(relationship_prop.key),
resolver=resolver,
**field_kwargs
)
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
Expand All @@ -55,18 +51,18 @@ def dynamic_type():
return Dynamic(dynamic_type)


def convert_sqlalchemy_hybrid_method(hybrid_prop, prop_name, **field_kwargs):
def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs):
if 'type' not in field_kwargs:
# TODO The default type should be dependent on the type of the property propety.
field_kwargs['type'] = String

return Field(
resolver=_get_attr_resolver(prop_name),
resolver=resolver,
**field_kwargs
)


def convert_sqlalchemy_composite(composite_prop, registry):
def convert_sqlalchemy_composite(composite_prop, registry, resolver):
converter = registry.get_converter_for_composite(composite_prop.composite_class)
if not converter:
try:
Expand Down Expand Up @@ -100,14 +96,14 @@ def inner(fn):
convert_sqlalchemy_composite.register = _register_composite_class


def convert_sqlalchemy_column(column_prop, registry, **field_kwargs):
def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs):
column = column_prop.columns[0]
field_kwargs.setdefault('type', convert_sqlalchemy_type(getattr(column, "type", None), column, registry))
field_kwargs.setdefault('required', not is_column_nullable(column))
field_kwargs.setdefault('description', get_column_doc(column))

return Field(
resolver=_get_attr_resolver(column_prop.key),
resolver=resolver,
**field_kwargs
)

Expand Down
4 changes: 3 additions & 1 deletion graphene_sqlalchemy/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

import graphene

from ..converter import convert_sqlalchemy_composite
from ..registry import reset_global_registry
from .models import Base, CompositeFullName
Expand All @@ -17,7 +19,7 @@ def reset_registry():
# Tests that explicitly depend on this behavior should re-register a converter
@convert_sqlalchemy_composite.register(CompositeFullName)
def convert_composite_class(composite, registry):
pass
return graphene.Field(graphene.Int)


@pytest.yield_fixture(scope="function")
Expand Down
32 changes: 24 additions & 8 deletions graphene_sqlalchemy/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,18 @@
from .models import Article, CompositeFullName, Pet, Reporter


def mock_resolver():
pass


def get_field(sqlalchemy_type, **column_kwargs):
class Model(declarative_base()):
__tablename__ = 'model'
id_ = Column(types.Integer, primary_key=True)
column = Column(sqlalchemy_type, doc="Custom Help Text", **column_kwargs)

column_prop = inspect(Model).column_attrs['column']
return convert_sqlalchemy_column(column_prop, get_global_registry())
return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver)


def get_field_from_column(column_):
Expand All @@ -40,7 +44,7 @@ class Model(declarative_base()):
column = column_

column_prop = inspect(Model).column_attrs['column']
return convert_sqlalchemy_column(column_prop, get_global_registry())
return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver)


def test_should_unknown_sqlalchemy_field_raise_exception():
Expand Down Expand Up @@ -162,7 +166,7 @@ def test_should_jsontype_convert_jsonstring():
def test_should_manytomany_convert_connectionorlist():
registry = Registry()
dynamic_field = convert_sqlalchemy_relationship(
Reporter.pets.property, registry, default_connection_field_factory
Reporter.pets.property, registry, default_connection_field_factory, mock_resolver,
)
assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type()
Expand All @@ -174,7 +178,7 @@ class Meta:
model = Pet

dynamic_field = convert_sqlalchemy_relationship(
Reporter.pets.property, A._meta.registry, default_connection_field_factory
Reporter.pets.property, A._meta.registry, default_connection_field_factory, mock_resolver,
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand All @@ -190,7 +194,7 @@ class Meta:
interfaces = (Node,)

dynamic_field = convert_sqlalchemy_relationship(
Reporter.pets.property, A._meta.registry, default_connection_field_factory
Reporter.pets.property, A._meta.registry, default_connection_field_factory, mock_resolver
)
assert isinstance(dynamic_field, graphene.Dynamic)
assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField)
Expand All @@ -199,7 +203,10 @@ class Meta:
def test_should_manytoone_convert_connectionorlist():
registry = Registry()
dynamic_field = convert_sqlalchemy_relationship(
Article.reporter.property, registry, default_connection_field_factory
Article.reporter.property,
registry,
default_connection_field_factory,
mock_resolver,
)
assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type()
Expand All @@ -211,7 +218,10 @@ class Meta:
model = Reporter

dynamic_field = convert_sqlalchemy_relationship(
Article.reporter.property, A._meta.registry, default_connection_field_factory
Article.reporter.property,
A._meta.registry,
default_connection_field_factory,
mock_resolver,
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand All @@ -226,7 +236,10 @@ class Meta:
interfaces = (Node,)

dynamic_field = convert_sqlalchemy_relationship(
Article.reporter.property, A._meta.registry, default_connection_field_factory
Article.reporter.property,
A._meta.registry,
default_connection_field_factory,
mock_resolver,
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand All @@ -244,6 +257,7 @@ class Meta:
Reporter.favorite_article.property,
A._meta.registry,
default_connection_field_factory,
mock_resolver,
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand Down Expand Up @@ -310,6 +324,7 @@ def convert_composite_class(composite, registry):
field = convert_sqlalchemy_composite(
composite(CompositeClass, (Column(types.Unicode(50)), Column(types.Unicode(50))), doc="Custom Help Text"),
registry,
mock_resolver,
)
assert isinstance(field, graphene.String)

Expand All @@ -325,4 +340,5 @@ def __init__(self, col1, col2):
convert_sqlalchemy_composite(
composite(CompositeFullName, (Column(types.Unicode(50)), Column(types.Unicode(50)))),
Registry(),
mock_resolver,
)
72 changes: 70 additions & 2 deletions graphene_sqlalchemy/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import six # noqa F401

from graphene import (Dynamic, Field, GlobalID, Int, List, Node, NonNull,
ObjectType, String)
ObjectType, Schema, String)

from ..converter import convert_sqlalchemy_composite
from ..fields import (SQLAlchemyConnectionField,
Expand Down Expand Up @@ -264,6 +264,7 @@ class Meta:
"column_prop",
"email",
"favorite_pet_kind",
"composite_prop",
"hybrid_prop",
"pets",
"articles",
Expand Down Expand Up @@ -293,6 +294,73 @@ class Meta:
assert first_name_field.type == Int


def test_resolvers(session):
"""Test that the correct resolver functions are called"""

class ReporterMixin(object):
def resolve_id(root, _info):
return 'ID'

class ReporterType(ReporterMixin, SQLAlchemyObjectType):
class Meta:
model = Reporter

email = ORMField()
email_v2 = ORMField(model_attr='email')
favorite_pet_kind = Field(String)
favorite_pet_kind_v2 = Field(String)

def resolve_last_name(root, _info):
return root.last_name.upper()

def resolve_email_v2(root, _info):
return root.email + '_V2'

def resolve_favorite_pet_kind_v2(root, _info):
return str(root.favorite_pet_kind) + '_V2'

class Query(ObjectType):
reporter = Field(ReporterType)

def resolve_reporter(self, _info):
return session.query(Reporter).first()

reporter = Reporter(first_name='first_name', last_name='last_name', email='email', favorite_pet_kind='cat')
session.add(reporter)
session.commit()

schema = Schema(query=Query)
result = schema.execute("""
query {
reporter {
id
firstName
lastName
email
emailV2
favoritePetKind
favoritePetKindV2
}
}
""")

assert not result.errors
# Custom resolver on a base class
assert result.data['reporter']['id'] == 'ID'
# Default field + default resolver
assert result.data['reporter']['firstName'] == 'first_name'
# Default field + custom resolver
assert result.data['reporter']['lastName'] == 'LAST_NAME'
# ORMField + default resolver
assert result.data['reporter']['email'] == 'email'
# ORMField + custom resolver
assert result.data['reporter']['emailV2'] == 'email_V2'
# Field + default resolver
assert result.data['reporter']['favoritePetKind'] == 'cat'
# Field + custom resolver
assert result.data['reporter']['favoritePetKindV2'] == 'cat_V2'


# Test Custom SQLAlchemyObjectType Implementation

def test_custom_objecttype_registered():
Expand All @@ -306,7 +374,7 @@ class Meta:

assert issubclass(CustomReporterType, ObjectType)
assert CustomReporterType._meta.model == Reporter
assert len(CustomReporterType._meta.fields) == 10
assert len(CustomReporterType._meta.fields) == 11


# Test Custom SQLAlchemyObjectType with Custom Options
Expand Down
30 changes: 26 additions & 4 deletions graphene_sqlalchemy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from graphene.relay import Connection, Node
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
from graphene.types.utils import yank_fields_from_attrs
from graphene.utils.get_unbound_function import get_unbound_function
from graphene.utils.orderedtype import OrderedType

from .converter import (convert_sqlalchemy_column,
Expand Down Expand Up @@ -151,20 +152,22 @@ def construct_fields(
for orm_field_name, orm_field in orm_fields.items():
attr_name = orm_field.kwargs.pop('model_attr')
attr = all_model_attrs[attr_name]
resolver = _get_field_resolver(obj_type, orm_field_name, attr_name)

if isinstance(attr, ColumnProperty):
field = convert_sqlalchemy_column(attr, registry, **orm_field.kwargs)
field = convert_sqlalchemy_column(attr, registry, resolver, **orm_field.kwargs)
elif isinstance(attr, RelationshipProperty):
field = convert_sqlalchemy_relationship(attr, registry, connection_field_factory, **orm_field.kwargs)
field = convert_sqlalchemy_relationship(attr, registry, connection_field_factory, resolver,
**orm_field.kwargs)
elif isinstance(attr, CompositeProperty):
if attr_name != orm_field_name or orm_field.kwargs:
# TODO Add a way to override composite property fields
raise ValueError(
"ORMField kwargs for composite fields must be empty. "
"Field: {}.{}".format(obj_type.__name__, orm_field_name))
field = convert_sqlalchemy_composite(attr, registry)
field = convert_sqlalchemy_composite(attr, registry, resolver)
elif isinstance(attr, hybrid_property):
field = convert_sqlalchemy_hybrid_method(attr, attr_name, **orm_field.kwargs)
field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs)
else:
raise Exception('Property type is not supported') # Should never happen

Expand All @@ -174,6 +177,25 @@ def construct_fields(
return fields


def _get_field_resolver(obj_type, orm_field_name, model_attr):
"""
In order to support field renaming via `ORMField.model_attr`,
we need to define resolver functions for each field.
:param SQLAlchemyObjectType obj_type:
:param model: the SQLAlchemy model
:param str model_attr: the name of SQLAlchemy of the attribute used to resolve the field
:rtype: Callable
"""
# Since `graphene` will call `resolve_<field_name>` on a field only if it
# does not have a `resolver`, we need to re-implement that logic here.
resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None)
if resolver:
return get_unbound_function(resolver)

return lambda root, _info: getattr(root, model_attr, None)


class SQLAlchemyObjectTypeOptions(ObjectTypeOptions):
model = None # type: sqlalchemy.Model
registry = None # type: sqlalchemy.Registry
Expand Down

0 comments on commit a361c52

Please sign in to comment.