Skip to content

Commit

Permalink
Merge pull request #24 from autoinvent/fragments
Browse files Browse the repository at this point in the history
handle fragments in relationship loader
  • Loading branch information
davidism authored Aug 2, 2024
2 parents a0c3caa + d4a05a6 commit 854fec2
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 33 deletions.
4 changes: 3 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
Unreleased

- The description for a model's object, and for an attribute's field and
argument, is set from their docstrings. :issue:`19`
argument, is set from their docstrings. {issue}`19`
- Handle fragments when inspecting query to load relationships. {issue}`21`
- Clearer error when SQLAlchemy session is not passed in GraphQL context.


## Version 1.0.0
Expand Down
143 changes: 111 additions & 32 deletions src/magql_sqlalchemy/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ def __init__(self, model: type[t.Any]) -> None:

def _load_relationships(
self,
node: graphql.FieldNode,
info: graphql.GraphQLResolveInfo,
node: graphql.FieldNode
| graphql.FragmentDefinitionNode
| graphql.InlineFragmentNode,
model: t.Any,
load_path: orm.Load | None = None,
) -> list[orm.Load]:
Expand All @@ -44,18 +47,35 @@ def _load_relationships(
efficient by letting SQLAlchemy preload related data rather than issuing
individual queries for every attribute access.
:param node: The AST node representing the GraphQL operation.
:param info: The GraphQL info about the operation, which contains the
fragment references.
:param node: The AST node being inspected.
:param model: The model containing the relationships. Starts as the model for
this resolver, then the relationship's target model during recursion.
:param load_path: During recursion, the SQLAlchemy load that has been performed
to get to this relationship and should be extended.
.. versionchanged:: 1.1
Handle fragments.
"""
if node.selection_set is None:
return []

out = []

for selection in node.selection_set.selections:
if isinstance(selection, graphql.FragmentSpreadNode):
# Fragments are an extra nested level. Find the definition,
# recurse, then continue.
fragment = info.fragments[selection.name.value]
out.extend(self._load_relationships(info, fragment, model, load_path))
continue

if isinstance(selection, graphql.InlineFragmentNode):
# Inline fragments are an extra nested level. Recurse, then continue.
out.extend(self._load_relationships(info, selection, model, load_path))
continue

inner_node = t.cast(graphql.FieldNode, selection)

# Only consider AST nodes for relationships, which are ones with further
Expand Down Expand Up @@ -84,7 +104,7 @@ def _load_relationships(
# Recurse to find any relationship fields selected in the child object.
out.extend(
self._load_relationships(
inner_node, rel_prop.entity.class_, extended_path
info, inner_node, rel_prop.entity.class_, extended_path
)
)

Expand Down Expand Up @@ -126,7 +146,8 @@ def __call__(
) -> t.Any:
"""Build and execute the query, then return the result."""
query = self.build_query(parent, info, **kwargs)
result = info.context["sa_session"].execute(query)
session = _get_sa_session(info)
result = session.execute(query)
return self.transform_result(result)


Expand All @@ -141,7 +162,8 @@ class ItemResolver(QueryResolver):
def build_query(
self, parent: t.Any, info: graphql.GraphQLResolveInfo, **kwargs: t.Any
) -> sql.Select[t.Any]:
load = self._load_relationships(_get_field_node(info), self.model)
field_node = _get_field_node(info)
load = self._load_relationships(info, field_node, self.model)
return (
sa.select(self.model)
.options(*load)
Expand Down Expand Up @@ -266,8 +288,8 @@ def apply_page(
def build_query(
self, parent: t.Any, info: graphql.GraphQLResolveInfo, **kwargs: t.Any
) -> sql.Select[t.Any]:
field_node = _get_field_node(info, nested="items")
load = self._load_relationships(field_node, self.model)
field_node = _get_field_node(info, list_name="items")
load = self._load_relationships(info, field_node, self.model)
query = sa.select(self.model).options(*load)
query = self.apply_filter(query, kwargs.get("filter"))
query = self.apply_sort(query, kwargs.get("sort"))
Expand Down Expand Up @@ -300,7 +322,7 @@ def __call__(
self, parent: t.Any, info: graphql.GraphQLResolveInfo, **kwargs: t.Any
) -> t.Any:
query = self.build_query(parent, info, **kwargs)
session: orm.Session = info.context["sa_session"]
session = _get_sa_session(info)
items = self.get_items(session, query)
total = self.get_count(session, query)
return ListResult(items=items, total=total)
Expand All @@ -321,28 +343,77 @@ class ListResult:


def _get_field_node(
info: graphql.GraphQLResolveInfo, nested: str | None = None
info: graphql.GraphQLResolveInfo, list_name: str | None = None
) -> graphql.FieldNode:
"""Get the node that describes the fields being selected by the current query. This
is used to determine if any of the fields are relationships to load.
"""Get the node that describes the fields being selected by the current
query. The returned node's AST is later scanned to load any relationships.
Assumes a single top-level field.
:param info: The GraphQL info about the operation, which contains the AST.
:param list_name: For a list query, the name of the field containing the
list of results. Should be ``"items"``.
:param info: The GrapQL info about the operation, which contains the AST.
:param nested: For the list query, the name of the field containing the list of
results. Should be ``"items"``.
.. versionchanged:: 1.1
Handle fragments.
"""
node = info.field_nodes[0]
# TODO handle multiple top-level fields

# For a list query, the actual type is nested in the list result type.
if nested is not None:
assert node.selection_set is not None
# For a list query, the items field is nested in the list result type.
if list_name is not None:
return _get_list_root(info, node, list_name)

for selection in node.selection_set.selections:
# Don't need to handle fragments here because top-level fragments like
# `query { ...fragment }` are already dereferenced in info.field_nodes.
return node


def _get_list_root(
info: graphql.GraphQLResolveInfo,
node: graphql.FieldNode
| graphql.FragmentDefinitionNode
| graphql.InlineFragmentNode,
name: str,
) -> graphql.FieldNode:
"""Scan the selected fields within a node to find the items field in a list
result type. Handle fragments by recursively scanning through references.
:param info: The GraphQL info about the operation, which contains the
fragment references.
:param node: The node being scanned.
:param name: The name of the field containing the list of results.
.. versionadded:: 1.1
Added for easier recursion when handling fragments.
"""
assert node.selection_set is not None

for selection in node.selection_set.selections:
if isinstance(selection, graphql.FragmentSpreadNode):
# Fragments are an extra nested level, recurse.
fragment = info.fragments[selection.name.value]
result = _get_list_root(info, fragment, name)

if result is not fragment:
return result

elif isinstance(selection, graphql.InlineFragmentNode):
# Inline fragments are an extra nested level, recurse.
result = _get_list_root(info, selection, name)

if (result := _get_list_root(info, selection, name)) is not selection:
return result

else:
inner_node = t.cast(graphql.FieldNode, selection)

if inner_node.name.value == nested:
if inner_node.name.value == name:
return inner_node

return node
# Don't know how to inspect this node further, return it directly.
# This cast will eventually be right when recursion ends.
return t.cast(graphql.FieldNode, node)


class MutationResolver(ModelResolver):
Expand All @@ -354,15 +425,14 @@ def get_item(
self, info: graphql.GraphQLResolveInfo, kwargs: dict[str, t.Any]
) -> t.Any:
"""Get the model instance by primary key value."""
return (
info.context["sa_session"]
.execute(
sa.select(self.model)
.options(*self._load_relationships(_get_field_node(info), self.model))
.where(self.pk_col == kwargs[self.pk_name])
)
.scalar_one()
)
session = _get_sa_session(info)
field_node = _get_field_node(info)
load = self._load_relationships(info, field_node, self.model)
return session.execute(
sa.select(self.model)
.options(*load)
.where(self.pk_col == kwargs[self.pk_name])
).scalar_one()

def apply_related(self, session: orm.Session, kwargs: dict[str, t.Any]) -> None:
"""For all relationship arguments, replace the id values with their model
Expand Down Expand Up @@ -408,7 +478,7 @@ class CreateResolver(MutationResolver):
def __call__(
self, parent: t.Any, info: graphql.GraphQLResolveInfo, **kwargs: t.Any
) -> t.Any:
session: orm.Session = info.context["sa_session"]
session = _get_sa_session(info)
self.apply_related(session, kwargs)
obj = self.model(**kwargs)
session.add(obj)
Expand All @@ -429,7 +499,7 @@ class UpdateResolver(MutationResolver):
def __call__(
self, parent: t.Any, info: graphql.GraphQLResolveInfo, **kwargs: t.Any
) -> t.Any:
session: orm.Session = info.context["sa_session"]
session = _get_sa_session(info)
self.apply_related(session, kwargs)
item = self.get_item(info, kwargs)

Expand All @@ -454,8 +524,17 @@ class DeleteResolver(MutationResolver):
def __call__(
self, parent: t.Any, info: graphql.GraphQLResolveInfo, **kwargs: t.Any
) -> t.Any:
session: orm.Session = info.context["sa_session"]
session = _get_sa_session(info)
item = self.get_item(info, kwargs)
session.delete(item)
session.commit()
return True


def _get_sa_session(info: graphql.GraphQLResolveInfo) -> orm.Session:
"""Get the SQLAlchemy session from the context."""

try:
return info.context["sa_session"] # type: ignore[no-any-return]
except (TypeError, KeyError) as e:
raise RuntimeError("'sa_session' must be set in execute context dict.") from e
85 changes: 85 additions & 0 deletions tests/test_get_field_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations

import typing as t
from types import SimpleNamespace

from graphql import FieldNode
from graphql import GraphQLResolveInfo
from magql import Object
from magql import Schema

from magql_sqlalchemy.resolvers import _get_field_node

schema = Schema()


@schema.query.field("item", "Boolean!")
def _resolve_one(parent: t.Any, info: GraphQLResolveInfo, **kwargs: t.Any) -> bool:
info.context.node = _get_field_node(info)
return True


@schema.query.field(
"list", Object("ListResult", fields={"items": "[Boolean!]!", "total": "Int!"})
)
def _resolve_list(
parent: t.Any, info: GraphQLResolveInfo, **kwargs: t.Any
) -> SimpleNamespace:
info.context.node = _get_field_node(info, "items")
return SimpleNamespace(items=[], total=0)


def _execute(source: str) -> FieldNode:
"""Execute a query and return the result of _get_field_node. The resolvers
store the result in the context object.
"""
context = SimpleNamespace()
result = schema.execute(source, context=context)

if result.errors:
raise result.errors[0]

return t.cast(FieldNode, context.node)


def test_item_node() -> None:
"""The first node is found."""
node = _execute("query { item }")
assert node.name.value == "item"


def test_list_node() -> None:
"""The items node in a list result is found."""
node = _execute("query { list { total items } }")
assert node.name.value == "items"


def test_list_fragment() -> None:
"""Fragment on list result is handled by recursing into dereferenced nodes."""
node = _execute(
"fragment a on ListResult { items }\n"
"fragment b on ListResult { total ...a }\n"
"query { list { ...b } }"
)
assert node.name.value == "items"


def test_inline_fragment() -> None:
"""Inline fragment on list result is handled by recursing into the node."""
node = _execute(
"fragment a on ListResult { items }\n"
"query { list { ... on ListResult { total ...a } } }"
)
assert node.name.value == "items"


def test_top_fragment() -> None:
"""Fragments on Query are automatically flattened during parsing and don't
need to be handled specially.
"""
node = _execute(
"fragment a on Query { item }\n"
"fragment b on Query { ...a }\n"
"query { ...b }"
)
assert node.name.value == "item"

0 comments on commit 854fec2

Please sign in to comment.