36 changes: 23 additions & 13 deletions cadence/_internal/decision_state_machine.py
@@ -2,7 +2,7 @@

from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Callable
from typing import Dict, List, Optional, Callable, TypedDict, Literal

from cadence.api.v1 import (
decision_pb2 as decision,
@@ -81,12 +81,18 @@ def __str__(self) -> str:
@dataclass
class StateTransition:
"""Represents a state transition with associated actions."""
next_state: DecisionState
next_state: Optional[DecisionState]
action: Optional[Callable[['BaseDecisionStateMachine', history.HistoryEvent], None]] = None
condition: Optional[Callable[['BaseDecisionStateMachine', history.HistoryEvent], bool]] = None


decision_state_transition_map = {
class TransitionInfo(TypedDict):
type: Literal["initiated", "started", "completion", "canceled", "cancel_initiated", "cancel_failed", "initiation_failed"]
decision_type: DecisionType
transition: StateTransition


decision_state_transition_map: Dict[str, TransitionInfo] = {
"activity_task_scheduled_event_attributes": {
"type": "initiated",
"decision_type": DecisionType.ACTIVITY,
@@ -247,6 +253,10 @@ class BaseDecisionStateMachine:
Subclasses are responsible for mapping workflow history events into state
transitions and producing the next set of decisions when queried.
"""

# Common fields that subclasses may use
scheduled_event_id: Optional[int] = None
started_event_id: Optional[int] = None

def get_id(self) -> str:
raise NotImplementedError
@@ -890,12 +900,12 @@ def handle_history_event(self, event: history.HistoryEvent) -> None:
if transition_info:
event_type = transition_info["type"]
# Route to all relevant machines using the new unified handle_event method
for m in list(self.activities.values()):
m.handle_event(event, event_type)
for m in list(self.timers.values()):
m.handle_event(event, event_type)
for m in list(self.children.values()):
m.handle_event(event, event_type)
for activity_machine in list(self.activities.values()):
activity_machine.handle_event(event, event_type)
for timer_machine in list(self.timers.values()):
timer_machine.handle_event(event, event_type)
for child_machine in list(self.children.values()):
child_machine.handle_event(event, event_type)

# ----- Decision aggregation -----

@@ -907,11 +917,11 @@ def collect_pending_decisions(self) -> List[decision.Decision]:
decisions.extend(machine.collect_pending_decisions())

# Timers
for machine in list(self.timers.values()):
decisions.extend(machine.collect_pending_decisions())
for timer_machine in list(self.timers.values()):
decisions.extend(timer_machine.collect_pending_decisions())

# Children
for machine in list(self.children.values()):
decisions.extend(machine.collect_pending_decisions())
for child_machine in list(self.children.values()):
decisions.extend(child_machine.collect_pending_decisions())

return decisions
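The main typing change in this file is the `TransitionInfo` TypedDict with a `Literal` event-kind field on the transition map. Below is a minimal, self-contained sketch of that pattern — illustrative enums and a single map entry, not the PR's full table:

```python
# Sketch: typing a string-keyed transition table with TypedDict + Literal so
# mypy checks both each entry's keys and the allowed values of "type".
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict, Literal, Optional, TypedDict


class DecisionState(Enum):
    CREATED = auto()
    INITIATED = auto()


class DecisionType(Enum):
    ACTIVITY = auto()
    TIMER = auto()


@dataclass
class StateTransition:
    # Optional so "no state change" transitions can be expressed with None.
    next_state: Optional[DecisionState]


class TransitionInfo(TypedDict):
    type: Literal["initiated", "started", "completion"]
    decision_type: DecisionType
    transition: StateTransition


transition_map: Dict[str, TransitionInfo] = {
    "activity_task_scheduled_event_attributes": {
        "type": "initiated",
        "decision_type": DecisionType.ACTIVITY,
        "transition": StateTransition(next_state=DecisionState.INITIATED),
    },
    # A misspelled key or an unknown literal such as "finished" is now a
    # mypy error instead of a silent runtime surprise.
}
```

Annotating the map as `Dict[str, TransitionInfo]` is what lets mypy validate every entry; a plain `dict` value would not be checked.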
8 changes: 7 additions & 1 deletion cadence/_internal/rpc/metadata.py
@@ -1,4 +1,5 @@
import collections
from typing import Any, Callable

from grpc.aio import Metadata
from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails
@@ -16,7 +17,12 @@ class MetadataInterceptor(UnaryUnaryClientInterceptor):
def __init__(self, metadata: Metadata):
self._metadata = metadata

async def intercept_unary_unary(self, continuation, client_call_details: ClientCallDetails, request):
async def intercept_unary_unary(
self,
continuation: Callable[[ClientCallDetails, Any], Any],
client_call_details: ClientCallDetails,
request: Any
) -> Any:
return await continuation(self._replace_details(client_call_details), request)


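For context, a minimal sketch of the typed interceptor hook above, plus how such an interceptor is typically attached to a channel. The header name and target are hypothetical, and the PR's `_replace_details` merging logic is omitted:

```python
from typing import Any, Callable

import grpc.aio
from grpc.aio import ClientCallDetails, Metadata, UnaryUnaryClientInterceptor


class HeaderInterceptor(UnaryUnaryClientInterceptor):
    def __init__(self, metadata: Metadata) -> None:
        self._metadata = metadata

    async def intercept_unary_unary(
        self,
        continuation: Callable[[ClientCallDetails, Any], Any],
        client_call_details: ClientCallDetails,
        request: Any,
    ) -> Any:
        # Real code would merge self._metadata into client_call_details here.
        return await continuation(client_call_details, request)


def make_channel(target: str) -> grpc.aio.Channel:
    # Hypothetical header; Cadence's actual headers may differ.
    interceptor = HeaderInterceptor(Metadata(("cadence-client-name", "cadence-python")))
    return grpc.aio.insecure_channel(target, interceptors=[interceptor])
```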
4 changes: 2 additions & 2 deletions cadence/client.py
@@ -11,7 +11,7 @@ class ClientOptions(TypedDict, total=False):
identity: str

class Client:
def __init__(self, channel: Channel, options: ClientOptions):
def __init__(self, channel: Channel, options: ClientOptions) -> None:
self._channel = channel
self._worker_stub = WorkerAPIStub(channel)
self._options = options
@@ -31,7 +31,7 @@ def worker_stub(self) -> WorkerAPIStub:
return self._worker_stub


async def close(self):
async def close(self) -> None:
await self._channel.close()


6 changes: 3 additions & 3 deletions cadence/data_converter.py
@@ -17,7 +17,7 @@ async def to_data(self, values: List[Any]) -> Payload:
raise NotImplementedError()

class DefaultDataConverter(DataConverter):
def __init__(self):
def __init__(self) -> None:
self._encoder = json.Encoder()
self._decoder = json.Decoder()
self._fallback_decoder = JSONDecoder(strict=False)
@@ -38,7 +38,7 @@ async def from_data(self, payload: Payload, type_hints: List[Type]) -> List[Any]


def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type]) -> List[Any]:
results = []
results: List[Any] = []
start, end = 0, len(payload)
while start < end and len(results) < len(type_hints):
remaining = payload[start:end]
@@ -50,7 +50,7 @@ def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type]) ->

@staticmethod
def _convert_into(values: List[Any], type_hints: List[Type]) -> List[Any]:
results = []
results: List[Any] = []
for i, type_hint in enumerate(type_hints):
if i < len(values):
value = convert(values[i], type_hint)
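The changes here are purely typing: `results` gets an explicit `List[Any]` annotation, which mypy requires for empty-list initializers once the `var-annotated` error code is re-enabled. For readers unfamiliar with `_decode_whitespace_delimited`, here is a rough sketch of the general technique using only the stdlib `JSONDecoder.raw_decode` — the real code decodes with msgspec and keeps the stdlib decoder as a fallback, and the function name and signature below are illustrative:

```python
from json import JSONDecoder
from typing import Any, List


def decode_whitespace_delimited(payload: str, max_values: int) -> List[Any]:
    decoder = JSONDecoder(strict=False)
    results: List[Any] = []  # explicit annotation, as var-annotated requires
    index = 0
    while index < len(payload) and len(results) < max_values:
        # Skip whitespace between values, then decode one JSON value and
        # advance past it using the end offset raw_decode returns.
        while index < len(payload) and payload[index].isspace():
            index += 1
        if index >= len(payload):
            break
        value, end = decoder.raw_decode(payload, index)
        results.append(value)
        index = end
    return results


# Example: '{"foo": "bar"} [1, 2] "baz"' -> [{'foo': 'bar'}, [1, 2], 'baz']
print(decode_whitespace_delimited('{"foo": "bar"} [1, 2] "baz"', 3))
```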
6 changes: 3 additions & 3 deletions cadence/worker/_activity.py
@@ -11,15 +11,15 @@


class ActivityWorker:
def __init__(self, client: Client, task_list: str, options: WorkerOptions):
def __init__(self, client: Client, task_list: str, options: WorkerOptions) -> None:
Member: I thought no value implied -> None, I'm sorry :(

Member (Author): It does imply -> None, but MyPy does not like that.

self._client = client
self._task_list = task_list
self._identity = options["identity"]
permits = asyncio.Semaphore(options["max_concurrent_activity_execution_size"])
self._poller = Poller[PollForActivityTaskResponse](options["activity_task_pollers"], permits, self._poll, self._execute)
# TODO: Local dispatch, local activities, actually running activities, etc

async def run(self):
async def run(self) -> None:
await self._poller.run()

async def _poll(self) -> Optional[PollForActivityTaskResponse]:
@@ -34,7 +34,7 @@ async def _poll(self) -> Optional[PollForActivityTaskResponse]:
else:
return None

async def _execute(self, task: PollForActivityTaskResponse):
async def _execute(self, task: PollForActivityTaskResponse) -> None:
await self._client.worker_stub.RespondActivityTaskFailed(RespondActivityTaskFailedRequest(
task_token=task.task_token,
identity=self._identity,
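Regarding the review thread above: at runtime an omitted return annotation already means the function returns `None`, but mypy only fully type-checks annotated defs, and under strict-style settings (for example `disallow_untyped_defs` / `disallow_incomplete_defs`) a `def` with annotated parameters but no return annotation is reported as incompletely typed. A tiny illustration with hypothetical function names:

```python
def configure(name: str):                 # strict mypy: missing return type annotation
    print(f"configuring {name}")


def configure_typed(name: str) -> None:   # fully annotated, so the body is type-checked
    print(f"configuring {name}")
```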
6 changes: 3 additions & 3 deletions cadence/worker/_decision.py
@@ -12,15 +12,15 @@


class DecisionWorker:
def __init__(self, client: Client, task_list: str, options: WorkerOptions):
def __init__(self, client: Client, task_list: str, options: WorkerOptions) -> None:
self._client = client
self._task_list = task_list
self._identity = options["identity"]
permits = asyncio.Semaphore(options["max_concurrent_decision_task_execution_size"])
self._poller = Poller[PollForDecisionTaskResponse](options["decision_task_pollers"], permits, self._poll, self._execute)
# TODO: Sticky poller, actually running workflows, etc.

async def run(self):
async def run(self) -> None:
await self._poller.run()

async def _poll(self) -> Optional[PollForDecisionTaskResponse]:
@@ -36,7 +36,7 @@ async def _poll(self) -> Optional[PollForDecisionTaskResponse]:
return None


async def _execute(self, task: PollForDecisionTaskResponse):
async def _execute(self, task: PollForDecisionTaskResponse) -> None:
await self._client.worker_stub.RespondDecisionTaskFailed(RespondDecisionTaskFailedRequest(
task_token=task.task_token,
cause=DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION,
13 changes: 6 additions & 7 deletions cadence/worker/_poller.py
@@ -7,24 +7,23 @@
T = TypeVar('T')

class Poller(Generic[T]):
def __init__(self, num_tasks: int, permits: asyncio.Semaphore, poll: Callable[[], Awaitable[Optional[T]]], callback: Callable[[T], Awaitable[None]]):
def __init__(self, num_tasks: int, permits: asyncio.Semaphore, poll: Callable[[], Awaitable[Optional[T]]], callback: Callable[[T], Awaitable[None]]) -> None:
self._num_tasks = num_tasks
self._permits = permits
self._poll = poll
self._callback = callback
self._background_tasks: set[asyncio.Task[None]] = set()
pass

async def run(self):
async def run(self) -> None:
try:
async with asyncio.TaskGroup() as tg:
for i in range(self._num_tasks):
tg.create_task(self._poll_loop())
except asyncio.CancelledError:
pass
pass


async def _poll_loop(self):
async def _poll_loop(self) -> None:
while True:
try:
await self._poll_and_dispatch()
@@ -34,7 +33,7 @@ async def _poll_loop(self):
logger.exception('Exception while polling')


async def _poll_and_dispatch(self):
async def _poll_and_dispatch(self) -> None:
await self._permits.acquire()
try:
task = await self._poll()
@@ -51,7 +50,7 @@ async def _poll_and_dispatch(self):
self._background_tasks.add(scheduled)
scheduled.add_done_callback(self._background_tasks.remove)

async def _execute_callback(self, task: T):
async def _execute_callback(self, task: T) -> None:
try:
await self._callback(task)
except Exception:
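Pulled together, the annotated `Poller` follows the pattern sketched below: a `TaskGroup` of long-lived poll loops, a semaphore capping concurrent work, and a strong-reference set for dispatched tasks. This is a condensed illustration with error handling and logging trimmed, not the file's exact code:

```python
import asyncio
from typing import Awaitable, Callable, Generic, Optional, TypeVar

T = TypeVar("T")


class PollerSketch(Generic[T]):
    def __init__(
        self,
        num_tasks: int,
        permits: asyncio.Semaphore,
        poll: Callable[[], Awaitable[Optional[T]]],
        callback: Callable[[T], Awaitable[None]],
    ) -> None:
        self._num_tasks = num_tasks
        self._permits = permits
        self._poll = poll
        self._callback = callback
        # Keep strong references so dispatched tasks are not garbage collected.
        self._background_tasks: set[asyncio.Task[None]] = set()

    async def run(self) -> None:
        async with asyncio.TaskGroup() as tg:          # Python 3.11+
            for _ in range(self._num_tasks):
                tg.create_task(self._poll_loop())

    async def _poll_loop(self) -> None:
        while True:
            await self._permits.acquire()
            task = await self._poll()
            if task is None:
                self._permits.release()                # nothing to do; free the permit
                continue
            scheduled = asyncio.create_task(self._dispatch(task))
            self._background_tasks.add(scheduled)
            scheduled.add_done_callback(self._background_tasks.discard)

    async def _dispatch(self, task: T) -> None:
        try:
            await self._callback(task)
        finally:
            self._permits.release()                    # permit freed after the callback
```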
2 changes: 1 addition & 1 deletion cadence/worker/_registry.py
@@ -33,7 +33,7 @@ class Registry:
workflows and activities in a Cadence application.
"""

def __init__(self):
def __init__(self) -> None:
"""Initialize the registry."""
self._workflows: Dict[str, Callable] = {}
self._activities: Dict[str, Callable] = {}
14 changes: 7 additions & 7 deletions cadence/worker/_worker.py
@@ -1,6 +1,6 @@
import asyncio
import uuid
from typing import Unpack
from typing import Unpack, cast

from cadence.client import Client
from cadence.worker._activity import ActivityWorker
@@ -10,7 +10,7 @@

class Worker:

def __init__(self, client: Client, task_list: str, **kwargs: Unpack[WorkerOptions]):
def __init__(self, client: Client, task_list: str, **kwargs: Unpack[WorkerOptions]) -> None:
self._client = client
self._task_list = task_list

@@ -21,7 +21,7 @@ def __init__(self, client: Client, task_list: str, **kwargs: Unpack[WorkerOption
self._decision_worker = DecisionWorker(client, task_list, options)


async def run(self):
async def run(self) -> None:
async with asyncio.TaskGroup() as tg:
if not self._options["disable_workflow_worker"]:
tg.create_task(self._decision_worker.run())
@@ -30,13 +30,13 @@ async def run(self):



def _validate_and_copy_defaults(client: Client, task_list: str, options: WorkerOptions):
def _validate_and_copy_defaults(client: Client, task_list: str, options: WorkerOptions) -> None:
if "identity" not in options:
options["identity"] = f"{client.identity}@{task_list}@{uuid.uuid4()}"

# TODO: More validation

for (key, value) in _DEFAULT_WORKER_OPTIONS.items():
# Set default values for missing options
for key, value in _DEFAULT_WORKER_OPTIONS.items():
if key not in options:
# noinspection PyTypedDict
options[key] = value
cast(dict, options)[key] = value
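The `cast(dict, options)[key] = value` line works around mypy's `literal-required` check, which this PR stops disabling: TypedDict subscripts must use string literals, so writing through a variable key needs an escape hatch. A small sketch with illustrative option names:

```python
from typing import TypedDict, cast


class WorkerOptionsSketch(TypedDict, total=False):
    identity: str
    decision_task_pollers: int


_DEFAULTS: dict = {"decision_task_pollers": 2}


def apply_defaults(options: WorkerOptionsSketch) -> None:
    for key, value in _DEFAULTS.items():
        if key not in options:
            # options[key] = value   # mypy: TypedDict key must be a string literal [literal-required]
            cast(dict, options)[key] = value  # drop the TypedDict view just for this write
```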
15 changes: 7 additions & 8 deletions pyproject.toml
@@ -113,17 +113,15 @@ warn_no_return = true
warn_unreachable = true
strict_equality = true
explicit_package_bases = true
disable_error_code = [
"var-annotated",
"arg-type",
"attr-defined",
"assignment",
"literal-required",
]
exclude = [
"cadence/api/*", # Exclude entire api directory with generated proto files
"cadence/api",
"cadence/api/.*",
"cadence/sample",
]

# Reduce recursive module checking
follow_imports = "silent"

[[tool.mypy.overrides]]
module = [
"grpcio.*",
Expand All @@ -133,6 +131,7 @@ module = [
"google.protobuf.*",
"uber.cadence.*",
"msgspec.*",
"cadence.api.*",
]
ignore_missing_imports = true

22 changes: 11 additions & 11 deletions tests/cadence/data_converter_test.py
@@ -1,5 +1,5 @@
import dataclasses
from typing import Any, Type
from typing import Any, Type, Optional

import pytest

@@ -8,10 +8,10 @@
from msgspec import json

@dataclasses.dataclass
class TestDataClass:
class _TestDataClass:
foo: str = "foo"
bar: int = -1
baz: 'TestDataClass' = None
baz: Optional['_TestDataClass'] = None

@pytest.mark.parametrize(
"json,types,expected",
@@ -35,7 +35,7 @@ class TestDataClass:
"[true]", [bool, bool], [True, False], id="bools"
),
pytest.param(
'[{"foo": "hello world", "bar": 42, "baz": {"bar": 43}}]', [TestDataClass, TestDataClass], [TestDataClass("hello world", 42, TestDataClass(bar=43)), None], id="data classes"
'[{"foo": "hello world", "bar": 42, "baz": {"bar": 43}}]', [_TestDataClass, _TestDataClass], [_TestDataClass("hello world", 42, _TestDataClass(bar=43)), None], id="data classes"
),
pytest.param(
'[{"foo": "hello world"}]', [dict, dict], [{"foo": "hello world"}, None], id="dicts"
@@ -53,17 +53,17 @@ class TestDataClass:
'["hello", "world"]', [list[str]], [["hello", "world"]], id="list"
),
pytest.param(
'{"foo": "bar"} {"bar": 100} ["hello"] "world"', [TestDataClass, TestDataClass, list[str], str],
[TestDataClass(foo="bar"), TestDataClass(bar=100), ["hello"], "world"], id="space delimited mix"
'{"foo": "bar"} {"bar": 100} ["hello"] "world"', [_TestDataClass, _TestDataClass, list[str], str],
[_TestDataClass(foo="bar"), _TestDataClass(bar=100), ["hello"], "world"], id="space delimited mix"
),
pytest.param(
'[{"foo": "bar"},{"bar": 100},["hello"],"world"]', [TestDataClass, TestDataClass, list[str], str],
[TestDataClass(foo="bar"), TestDataClass(bar=100), ["hello"], "world"], id="json array mix"
'[{"foo": "bar"},{"bar": 100},["hello"],"world"]', [_TestDataClass, _TestDataClass, list[str], str],
[_TestDataClass(foo="bar"), _TestDataClass(bar=100), ["hello"], "world"], id="json array mix"
),
]
)
@pytest.mark.asyncio
async def test_data_converter_from_data(json: str, types: list[Type], expected: list[Any]):
async def test_data_converter_from_data(json: str, types: list[Type], expected: list[Any]) -> None:
converter = DefaultDataConverter()
actual = await converter.from_data(Payload(data=json.encode()), types)
assert expected == actual
@@ -78,12 +78,12 @@ async def test_data_converter_from_data(json: str, types: list[Type], expected:
["hello", "world"], '["hello", "world"]', id="multiple values"
),
pytest.param(
[TestDataClass()], '{"foo": "foo", "bar": -1, "baz": null}', id="data classes"
[_TestDataClass()], '{"foo": "foo", "bar": -1, "baz": null}', id="data classes"
),
]
)
@pytest.mark.asyncio
async def test_data_converter_to_data(values: list[Any], expected: str):
async def test_data_converter_to_data(values: list[Any], expected: str) -> None:
converter = DefaultDataConverter()
actual = await converter.to_data(values)
# Parse both rather than trying to compare strings
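The helper class is likely renamed because pytest tries to collect `Test*`-prefixed classes as test cases and warns that it cannot, since the dataclass defines `__init__`; the `Optional` annotation is needed for the `None` default to type-check once the `assignment` error code is no longer disabled. Condensed view of the fixed helper:

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class _TestDataClass:                        # leading underscore: not collected by pytest
    foo: str = "foo"
    bar: int = -1
    baz: Optional["_TestDataClass"] = None   # None default requires Optional under strict mypy
```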