Skip to content

Commit

Permalink
refactor: Move BaseMiddleware to _internal module.
Browse files Browse the repository at this point in the history
  • Loading branch information
DABND19 committed Jan 11, 2025
1 parent 4650b4b commit 816affa
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 125 deletions.
117 changes: 117 additions & 0 deletions faststream/_internal/middlewares.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from collections.abc import Awaitable
from typing import TYPE_CHECKING, Any, Callable, Generic, Optional

from typing_extensions import Self

from faststream._internal.types import PublishCommandType

if TYPE_CHECKING:
from types import TracebackType

from faststream._internal.basic_types import AsyncFuncAny
from faststream._internal.context.repository import ContextRepo
from faststream.message import StreamMessage


class BaseMiddleware(Generic[PublishCommandType]):
"""A base middleware class."""

def __init__(
self,
msg: Optional[Any],
/,
*,
context: "ContextRepo",
) -> None:
self.msg = msg
self.context = context

async def on_receive(self) -> None:
"""Hook to call on message receive."""

async def after_processed(
self,
exc_type: Optional[type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional["TracebackType"] = None,
) -> Optional[bool]:
"""Asynchronously called after processing."""
return False

async def __aenter__(self) -> Self:
await self.on_receive()
return self

async def __aexit__(
self,
exc_type: Optional[type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional["TracebackType"] = None,
) -> Optional[bool]:
"""Exit the asynchronous context manager."""
return await self.after_processed(exc_type, exc_val, exc_tb)

async def on_consume(
self,
msg: "StreamMessage[Any]",
) -> "StreamMessage[Any]":
"""This option was deprecated and will be removed in 0.7.0. Please, use `consume_scope` instead."""
return msg

async def after_consume(self, err: Optional[Exception]) -> None:
"""This option was deprecated and will be removed in 0.7.0. Please, use `consume_scope` instead."""
if err is not None:
raise err

async def consume_scope(
self,
call_next: "AsyncFuncAny",
msg: "StreamMessage[Any]",
) -> Any:
"""Asynchronously consumes a message and returns an asynchronous iterator of decoded messages."""
err: Optional[Exception] = None
try:
result = await call_next(await self.on_consume(msg))

except Exception as e:
err = e

else:
return result

finally:
await self.after_consume(err)

async def on_publish(
self,
msg: PublishCommandType,
) -> PublishCommandType:
"""This option was deprecated and will be removed in 0.7.0. Please, use `publish_scope` instead."""
return msg

async def after_publish(
self,
err: Optional[Exception],
) -> None:
"""This option was deprecated and will be removed in 0.7.0. Please, use `publish_scope` instead."""
if err is not None:
raise err

async def publish_scope(
self,
call_next: Callable[[PublishCommandType], Awaitable[Any]],
cmd: PublishCommandType,
) -> Any:
"""Publish a message and return an async iterator."""
err: Optional[Exception] = None
try:
result = await call_next(await self.on_publish(cmd))

except Exception as e:
err = e

else:
return result

finally:
await self.after_publish(err)
6 changes: 6 additions & 0 deletions faststream/_internal/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing_extensions import (
ParamSpec,
TypeAlias,
TypeVar as TypeVar313,
)

from faststream._internal.basic_types import AsyncFuncAny
Expand All @@ -23,6 +24,11 @@
Msg_contra = TypeVar("Msg_contra", contravariant=True)
StreamMsg = TypeVar("StreamMsg", bound=StreamMessage[Any])
ConnectionType = TypeVar("ConnectionType")
PublishCommandType = TypeVar313(
"PublishCommandType",
bound=PublishCommand,
default=PublishCommand,
)

SyncFilter: TypeAlias = Callable[[StreamMsg], bool]
AsyncFilter: TypeAlias = Callable[[StreamMsg], Awaitable[bool]]
Expand Down
2 changes: 1 addition & 1 deletion faststream/middlewares/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from faststream._internal.middlewares import BaseMiddleware
from faststream.middlewares.acknowledgement.conf import AckPolicy
from faststream.middlewares.acknowledgement.middleware import AcknowledgementMiddleware
from faststream.middlewares.base import BaseMiddleware
from faststream.middlewares.exception import ExceptionMiddleware

__all__ = (
Expand Down
129 changes: 5 additions & 124 deletions faststream/middlewares/base.py
Original file line number Diff line number Diff line change
@@ -1,127 +1,8 @@
from collections.abc import Awaitable
from typing import TYPE_CHECKING, Any, Callable, Generic, Optional
# TODO: Remove this file
from faststream._internal.middlewares import BaseMiddleware
from faststream._internal.types import PublishCommandType

from typing_extensions import (
Self,
TypeVar as TypeVar313,
)

from faststream.response import PublishCommand

if TYPE_CHECKING:
from types import TracebackType

from faststream._internal.basic_types import AsyncFuncAny
from faststream._internal.context.repository import ContextRepo
from faststream.message import StreamMessage


PublishCommandType = TypeVar313(
__all__ = (
"BaseMiddleware",
"PublishCommandType",
bound=PublishCommand,
default=PublishCommand,
)


class BaseMiddleware(Generic[PublishCommandType]):
"""A base middleware class."""

def __init__(
self,
msg: Optional[Any],
/,
*,
context: "ContextRepo",
) -> None:
self.msg = msg
self.context = context

async def on_receive(self) -> None:
"""Hook to call on message receive."""

async def after_processed(
self,
exc_type: Optional[type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional["TracebackType"] = None,
) -> Optional[bool]:
"""Asynchronously called after processing."""
return False

async def __aenter__(self) -> Self:
await self.on_receive()
return self

async def __aexit__(
self,
exc_type: Optional[type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional["TracebackType"] = None,
) -> Optional[bool]:
"""Exit the asynchronous context manager."""
return await self.after_processed(exc_type, exc_val, exc_tb)

async def on_consume(
self,
msg: "StreamMessage[Any]",
) -> "StreamMessage[Any]":
"""This option was deprecated and will be removed in 0.7.0. Please, use `consume_scope` instead."""
return msg

async def after_consume(self, err: Optional[Exception]) -> None:
"""This option was deprecated and will be removed in 0.7.0. Please, use `consume_scope` instead."""
if err is not None:
raise err

async def consume_scope(
self,
call_next: "AsyncFuncAny",
msg: "StreamMessage[Any]",
) -> Any:
"""Asynchronously consumes a message and returns an asynchronous iterator of decoded messages."""
err: Optional[Exception] = None
try:
result = await call_next(await self.on_consume(msg))

except Exception as e:
err = e

else:
return result

finally:
await self.after_consume(err)

async def on_publish(
self,
msg: PublishCommandType,
) -> PublishCommandType:
"""This option was deprecated and will be removed in 0.7.0. Please, use `publish_scope` instead."""
return msg

async def after_publish(
self,
err: Optional[Exception],
) -> None:
"""This option was deprecated and will be removed in 0.7.0. Please, use `publish_scope` instead."""
if err is not None:
raise err

async def publish_scope(
self,
call_next: Callable[[PublishCommandType], Awaitable[Any]],
cmd: PublishCommandType,
) -> Any:
"""Publish a message and return an async iterator."""
err: Optional[Exception] = None
try:
result = await call_next(await self.on_publish(cmd))

except Exception as e:
err = e

else:
return result

finally:
await self.after_publish(err)

0 comments on commit 816affa

Please sign in to comment.