Skip to content

Commit

Permalink
Merge pull request #33 from autoinvent/override-resolvers
Browse files Browse the repository at this point in the history
ovverride resolver factories
  • Loading branch information
davidism authored Aug 3, 2024
2 parents c6a8fa9 + 87c7431 commit e389fbb
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 6 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ Unreleased
extra behavior. {issue}`25`
- Resolver classes and `ModelManager`, and their methods, are generic on the
model class passed to them.
- `ModelManager` has class attributes to override the
item/list/create/update/delete resolver factories. `ModelGroup` has a class
attribute to override the manager class. This can be used to customize the
default behaviors. {issue}`26`


## Version 1.0.0
Expand Down
8 changes: 7 additions & 1 deletion src/magql_sqlalchemy/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ class ModelGroup:
:param managers: The model managers that are part of this group.
"""

manager_class: t.ClassVar[type[ModelManager[t.Any]]] = ModelManager
"""The manager class to use for each model.
.. versionadded:: 1.1
"""

def __init__(self, managers: list[ModelManager[t.Any]] | None = None) -> None:
self.managers: dict[str, ModelManager[t.Any]] = {}
"""Maps SQLAlchemy model names to their :class:`ModelManager` instance. Use
Expand Down Expand Up @@ -61,7 +67,7 @@ def from_declarative_base(
else:
model_search = search

managers.append(ModelManager(model, search=model_search))
managers.append(cls.manager_class(model, search=model_search))

return cls(managers)

Expand Down
51 changes: 46 additions & 5 deletions src/magql_sqlalchemy/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
M = t.TypeVar("M", bound=sa_orm.DeclarativeBase)


class ResolverFactory(t.Protocol):
def __call__(
self, model: type[sa_orm.DeclarativeBase]
) -> magql.nodes.ResolverCallable: ...


class ModelManager(t.Generic[M]):
"""The API for a single SQLAlchemy model class. Generates Magql types, fields,
resolvers, etc. These are exposed as attributes on this manager, and can be further
Expand All @@ -39,6 +45,41 @@ class ModelManager(t.Generic[M]):
:param search: Whether this model will provide results in global search.
"""

item_factory: t.ClassVar[ResolverFactory] = ItemResolver
"""Callable that takes the model class and creates the resolver callable for
:attr:`item_field`.
.. versionadded:: 1.1
"""

list_factory: t.ClassVar[ResolverFactory] = ListResolver
"""Callable that takes the model class and creates the resolver callable for
:attr:`list_field`.
.. versionadded:: 1.1
"""

create_factory: t.ClassVar[ResolverFactory] = CreateResolver
"""Callable that takes the model class and creates the resolver callable for
:attr:`create_field`.
.. versionadded:: 1.1
"""

update_factory: t.ClassVar[ResolverFactory] = UpdateResolver
"""Callable that takes the model class and creates the resolver callable for
:attr:`update_field`.
.. versionadded:: 1.1
"""

delete_factory: t.ClassVar[ResolverFactory] = DeleteResolver
"""Callable that takes the model class and creates the resolver callable for
:attr:`delete_field`.
.. versionadded:: 1.1
"""

model: type[M]
"""The SQLAlchemy model class."""

Expand Down Expand Up @@ -243,7 +284,7 @@ def __init__(self, model: type[M], search: bool = False) -> None:
self.item_field = magql.Field(
object,
args={"id": magql.Argument(pk_type.non_null)},
resolve=ItemResolver(model),
resolve=self.item_factory(model),
)
self.list_result = magql.Object(
f"{model_name}ListResult",
Expand All @@ -260,7 +301,7 @@ def __init__(self, model: type[M], search: bool = False) -> None:
"page": magql.Argument(magql.Int, validators=[validate_page]),
"per_page": magql.Argument(magql.Int, validators=[PerPageValidator()]),
},
resolve=ListResolver(model),
resolve=self.list_factory(model),
)
unique_validators = []
local_table = t.cast(sa.Table, mapper.local_table)
Expand All @@ -281,19 +322,19 @@ def __init__(self, model: type[M], search: bool = False) -> None:
self.create_field = magql.Field(
self.object.non_null,
args=create_args, # type: ignore[arg-type]
resolve=CreateResolver(model),
resolve=self.create_factory(model),
validators=[*unique_validators],
)
self.update_field = magql.Field(
self.object.non_null,
args=update_args, # type: ignore[arg-type]
resolve=UpdateResolver(model),
resolve=self.update_factory(model),
validators=[*unique_validators],
)
self.delete_field = magql.Field(
magql.Boolean.non_null,
args={pk_name: magql.Argument(pk_type.non_null, validators=[item_exists])},
resolve=DeleteResolver(model),
resolve=self.delete_factory(model),
)

if search:
Expand Down

0 comments on commit e389fbb

Please sign in to comment.