diff --git a/cadence/_internal/decision_state_machine.py b/cadence/_internal/decision_state_machine.py
index 245ba4e..8b721dc 100644
--- a/cadence/_internal/decision_state_machine.py
+++ b/cadence/_internal/decision_state_machine.py
@@ -2,7 +2,7 @@
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Dict, List, Optional, Callable
+from typing import Dict, List, Optional, Callable, TypedDict, Literal
 
 from cadence.api.v1 import (
     decision_pb2 as decision,
@@ -81,12 +81,18 @@ def __str__(self) -> str:
 @dataclass
 class StateTransition:
     """Represents a state transition with associated actions."""
-    next_state: DecisionState
+    next_state: Optional[DecisionState]
     action: Optional[Callable[['BaseDecisionStateMachine', history.HistoryEvent], None]] = None
     condition: Optional[Callable[['BaseDecisionStateMachine', history.HistoryEvent], bool]] = None
 
 
-decision_state_transition_map = {
+class TransitionInfo(TypedDict):
+    type: Literal["initiated", "started", "completion", "canceled", "cancel_initiated", "cancel_failed", "initiation_failed"]
+    decision_type: DecisionType
+    transition: StateTransition
+
+
+decision_state_transition_map: Dict[str, TransitionInfo] = {
     "activity_task_scheduled_event_attributes": {
         "type": "initiated",
         "decision_type": DecisionType.ACTIVITY,
@@ -247,6 +253,10 @@ class BaseDecisionStateMachine:
     Subclasses are responsible for mapping workflow history events into state
     transitions and producing the next set of decisions when queried.
     """
+
+    # Common fields that subclasses may use
+    scheduled_event_id: Optional[int] = None
+    started_event_id: Optional[int] = None
 
     def get_id(self) -> str:
         raise NotImplementedError
@@ -890,12 +900,12 @@ def handle_history_event(self, event: history.HistoryEvent) -> None:
         if transition_info:
             event_type = transition_info["type"]
             # Route to all relevant machines using the new unified handle_event method
-            for m in list(self.activities.values()):
-                m.handle_event(event, event_type)
-            for m in list(self.timers.values()):
-                m.handle_event(event, event_type)
-            for m in list(self.children.values()):
-                m.handle_event(event, event_type)
+            for activity_machine in list(self.activities.values()):
+                activity_machine.handle_event(event, event_type)
+            for timer_machine in list(self.timers.values()):
+                timer_machine.handle_event(event, event_type)
+            for child_machine in list(self.children.values()):
+                child_machine.handle_event(event, event_type)
 
     # ----- Decision aggregation -----
 
@@ -907,11 +917,11 @@ def collect_pending_decisions(self) -> List[decision.Decision]:
             decisions.extend(machine.collect_pending_decisions())
 
         # Timers
-        for machine in list(self.timers.values()):
-            decisions.extend(machine.collect_pending_decisions())
+        for timer_machine in list(self.timers.values()):
+            decisions.extend(timer_machine.collect_pending_decisions())
 
         # Children
-        for machine in list(self.children.values()):
-            decisions.extend(machine.collect_pending_decisions())
+        for child_machine in list(self.children.values()):
+            decisions.extend(child_machine.collect_pending_decisions())
 
         return decisions
diff --git a/cadence/_internal/rpc/metadata.py b/cadence/_internal/rpc/metadata.py
index e4c1fe3..c46b909 100644
--- a/cadence/_internal/rpc/metadata.py
+++ b/cadence/_internal/rpc/metadata.py
@@ -1,4 +1,5 @@
 import collections
+from typing import Any, Callable
 
 from grpc.aio import Metadata
 from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails
@@ -16,7 +17,12 @@ class MetadataInterceptor(UnaryUnaryClientInterceptor):
     def __init__(self, metadata: Metadata):
         self._metadata = metadata
 
-    async def intercept_unary_unary(self, continuation, client_call_details: ClientCallDetails, request):
+    async def intercept_unary_unary(
+        self,
+        continuation: Callable[[ClientCallDetails, Any], Any],
+        client_call_details: ClientCallDetails,
+        request: Any
+    ) -> Any:
         return await continuation(self._replace_details(client_call_details), request)
diff --git a/cadence/client.py b/cadence/client.py
index 0eccd17..bfca66b 100644
--- a/cadence/client.py
+++ b/cadence/client.py
@@ -11,7 +11,7 @@ class ClientOptions(TypedDict, total=False):
     identity: str
 
 class Client:
-    def __init__(self, channel: Channel, options: ClientOptions):
+    def __init__(self, channel: Channel, options: ClientOptions) -> None:
         self._channel = channel
         self._worker_stub = WorkerAPIStub(channel)
         self._options = options
@@ -31,7 +31,7 @@ def worker_stub(self) -> WorkerAPIStub:
         return self._worker_stub
 
-    async def close(self):
+    async def close(self) -> None:
         await self._channel.close()
diff --git a/cadence/data_converter.py b/cadence/data_converter.py
index 819fd90..ca54712 100644
--- a/cadence/data_converter.py
+++ b/cadence/data_converter.py
@@ -17,7 +17,7 @@ async def to_data(self, values: List[Any]) -> Payload:
         raise NotImplementedError()
 
 class DefaultDataConverter(DataConverter):
-    def __init__(self):
+    def __init__(self) -> None:
         self._encoder = json.Encoder()
         self._decoder = json.Decoder()
         self._fallback_decoder = JSONDecoder(strict=False)
@@ -38,7 +38,7 @@ async def from_data(self, payload: Payload, type_hints: List[Type]) -> List[Any]
 
     def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type]) -> List[Any]:
-        results = []
+        results: List[Any] = []
         start, end = 0, len(payload)
         while start < end and len(results) < len(type_hints):
             remaining = payload[start:end]
@@ -50,7 +50,7 @@ def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type]) ->
 
     @staticmethod
     def _convert_into(values: List[Any], type_hints: List[Type]) -> List[Any]:
-        results = []
+        results: List[Any] = []
         for i, type_hint in enumerate(type_hints):
             if i < len(values):
                 value = convert(values[i], type_hint)
diff --git a/cadence/worker/_activity.py b/cadence/worker/_activity.py
index 0ae24e1..ec1ef7e 100644
--- a/cadence/worker/_activity.py
+++ b/cadence/worker/_activity.py
@@ -11,7 +11,7 @@
 
 class ActivityWorker:
-    def __init__(self, client: Client, task_list: str, options: WorkerOptions):
+    def __init__(self, client: Client, task_list: str, options: WorkerOptions) -> None:
         self._client = client
         self._task_list = task_list
         self._identity = options["identity"]
@@ -19,7 +19,7 @@ def __init__(self, client: Client, task_list: str, options: WorkerOptions):
         self._poller = Poller[PollForActivityTaskResponse](options["activity_task_pollers"], permits, self._poll, self._execute)
         # TODO: Local dispatch, local activities, actually running activities, etc
 
-    async def run(self):
+    async def run(self) -> None:
         await self._poller.run()
 
     async def _poll(self) -> Optional[PollForActivityTaskResponse]:
@@ -34,7 +34,7 @@ async def _poll(self) -> Optional[PollForActivityTaskResponse]:
         else:
             return None
 
-    async def _execute(self, task: PollForActivityTaskResponse):
+    async def _execute(self, task: PollForActivityTaskResponse) -> None:
         await self._client.worker_stub.RespondActivityTaskFailed(RespondActivityTaskFailedRequest(
             task_token=task.task_token,
             identity=self._identity,
diff --git a/cadence/worker/_decision.py b/cadence/worker/_decision.py
index 0510f61..47e0817 100644
--- a/cadence/worker/_decision.py
+++ b/cadence/worker/_decision.py
@@ -12,7 +12,7 @@
 
 class DecisionWorker:
-    def __init__(self, client: Client, task_list: str, options: WorkerOptions):
+    def __init__(self, client: Client, task_list: str, options: WorkerOptions) -> None:
         self._client = client
         self._task_list = task_list
         self._identity = options["identity"]
@@ -20,7 +20,7 @@ def __init__(self, client: Client, task_list: str, options: WorkerOptions):
         self._poller = Poller[PollForDecisionTaskResponse](options["decision_task_pollers"], permits, self._poll, self._execute)
         # TODO: Sticky poller, actually running workflows, etc.
 
-    async def run(self):
+    async def run(self) -> None:
         await self._poller.run()
 
     async def _poll(self) -> Optional[PollForDecisionTaskResponse]:
@@ -36,7 +36,7 @@ async def _poll(self) -> Optional[PollForDecisionTaskResponse]:
         return None
 
-    async def _execute(self, task: PollForDecisionTaskResponse):
+    async def _execute(self, task: PollForDecisionTaskResponse) -> None:
         await self._client.worker_stub.RespondDecisionTaskFailed(RespondDecisionTaskFailedRequest(
             task_token=task.task_token,
             cause=DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION,
diff --git a/cadence/worker/_poller.py b/cadence/worker/_poller.py
index 3b2889b..a185d27 100644
--- a/cadence/worker/_poller.py
+++ b/cadence/worker/_poller.py
@@ -7,24 +7,23 @@
 T = TypeVar('T')
 
 class Poller(Generic[T]):
-    def __init__(self, num_tasks: int, permits: asyncio.Semaphore, poll: Callable[[], Awaitable[Optional[T]]], callback: Callable[[T], Awaitable[None]]):
+    def __init__(self, num_tasks: int, permits: asyncio.Semaphore, poll: Callable[[], Awaitable[Optional[T]]], callback: Callable[[T], Awaitable[None]]) -> None:
         self._num_tasks = num_tasks
         self._permits = permits
         self._poll = poll
         self._callback = callback
         self._background_tasks: set[asyncio.Task[None]] = set()
-        pass
 
-    async def run(self):
+    async def run(self) -> None:
         try:
             async with asyncio.TaskGroup() as tg:
                 for i in range(self._num_tasks):
                     tg.create_task(self._poll_loop())
         except asyncio.CancelledError:
-            pass
+            pass
 
-    async def _poll_loop(self):
+    async def _poll_loop(self) -> None:
         while True:
             try:
                 await self._poll_and_dispatch()
@@ -34,7 +33,7 @@ async def _poll_loop(self) -> None:
                 logger.exception('Exception while polling')
 
-    async def _poll_and_dispatch(self):
+    async def _poll_and_dispatch(self) -> None:
         await self._permits.acquire()
         try:
             task = await self._poll()
@@ -51,7 +50,7 @@ async def _poll_and_dispatch(self) -> None:
             self._background_tasks.add(scheduled)
             scheduled.add_done_callback(self._background_tasks.remove)
 
-    async def _execute_callback(self, task: T):
+    async def _execute_callback(self, task: T) -> None:
         try:
             await self._callback(task)
         except Exception:
diff --git a/cadence/worker/_registry.py b/cadence/worker/_registry.py
index 6822351..4ba0972 100644
--- a/cadence/worker/_registry.py
+++ b/cadence/worker/_registry.py
@@ -33,7 +33,7 @@ class Registry:
     workflows and activities in a Cadence application.
""" - def __init__(self): + def __init__(self) -> None: """Initialize the registry.""" self._workflows: Dict[str, Callable] = {} self._activities: Dict[str, Callable] = {} diff --git a/cadence/worker/_worker.py b/cadence/worker/_worker.py index bb3ccc3..8d8932a 100644 --- a/cadence/worker/_worker.py +++ b/cadence/worker/_worker.py @@ -1,6 +1,6 @@ import asyncio import uuid -from typing import Unpack +from typing import Unpack, cast from cadence.client import Client from cadence.worker._activity import ActivityWorker @@ -10,7 +10,7 @@ class Worker: - def __init__(self, client: Client, task_list: str, **kwargs: Unpack[WorkerOptions]): + def __init__(self, client: Client, task_list: str, **kwargs: Unpack[WorkerOptions]) -> None: self._client = client self._task_list = task_list @@ -21,7 +21,7 @@ def __init__(self, client: Client, task_list: str, **kwargs: Unpack[WorkerOption self._decision_worker = DecisionWorker(client, task_list, options) - async def run(self): + async def run(self) -> None: async with asyncio.TaskGroup() as tg: if not self._options["disable_workflow_worker"]: tg.create_task(self._decision_worker.run()) @@ -30,13 +30,13 @@ async def run(self): -def _validate_and_copy_defaults(client: Client, task_list: str, options: WorkerOptions): +def _validate_and_copy_defaults(client: Client, task_list: str, options: WorkerOptions) -> None: if "identity" not in options: options["identity"] = f"{client.identity}@{task_list}@{uuid.uuid4()}" # TODO: More validation - for (key, value) in _DEFAULT_WORKER_OPTIONS.items(): + # Set default values for missing options + for key, value in _DEFAULT_WORKER_OPTIONS.items(): if key not in options: - # noinspection PyTypedDict - options[key] = value + cast(dict, options)[key] = value diff --git a/pyproject.toml b/pyproject.toml index 7eb7aed..c154d3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,17 +113,15 @@ warn_no_return = true warn_unreachable = true strict_equality = true explicit_package_bases = true -disable_error_code = [ - "var-annotated", - "arg-type", - "attr-defined", - "assignment", - "literal-required", -] exclude = [ - "cadence/api/*", # Exclude entire api directory with generated proto files + "cadence/api", + "cadence/api/.*", + "cadence/sample", ] +# Reduce recursive module checking +follow_imports = "silent" + [[tool.mypy.overrides]] module = [ "grpcio.*", @@ -133,6 +131,7 @@ module = [ "google.protobuf.*", "uber.cadence.*", "msgspec.*", + "cadence.api.*", ] ignore_missing_imports = true diff --git a/tests/cadence/data_converter_test.py b/tests/cadence/data_converter_test.py index 2f55395..88aecd3 100644 --- a/tests/cadence/data_converter_test.py +++ b/tests/cadence/data_converter_test.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Type +from typing import Any, Type, Optional import pytest @@ -8,10 +8,10 @@ from msgspec import json @dataclasses.dataclass -class TestDataClass: +class _TestDataClass: foo: str = "foo" bar: int = -1 - baz: 'TestDataClass' = None + baz: Optional['_TestDataClass'] = None @pytest.mark.parametrize( "json,types,expected", @@ -35,7 +35,7 @@ class TestDataClass: "[true]", [bool, bool], [True, False], id="bools" ), pytest.param( - '[{"foo": "hello world", "bar": 42, "baz": {"bar": 43}}]', [TestDataClass, TestDataClass], [TestDataClass("hello world", 42, TestDataClass(bar=43)), None], id="data classes" + '[{"foo": "hello world", "bar": 42, "baz": {"bar": 43}}]', [_TestDataClass, _TestDataClass], [_TestDataClass("hello world", 42, _TestDataClass(bar=43)), None], id="data classes" ), 
        pytest.param(
            '[{"foo": "hello world"}]', [dict, dict], [{"foo": "hello world"}, None], id="dicts"
@@ -53,17 +53,17 @@ class TestDataClass:
            '["hello", "world"]', [list[str]], [["hello", "world"]], id="list"
        ),
        pytest.param(
-            '{"foo": "bar"} {"bar": 100} ["hello"] "world"', [TestDataClass, TestDataClass, list[str], str],
-            [TestDataClass(foo="bar"), TestDataClass(bar=100), ["hello"], "world"], id="space delimited mix"
+            '{"foo": "bar"} {"bar": 100} ["hello"] "world"', [_TestDataClass, _TestDataClass, list[str], str],
+            [_TestDataClass(foo="bar"), _TestDataClass(bar=100), ["hello"], "world"], id="space delimited mix"
        ),
        pytest.param(
-            '[{"foo": "bar"},{"bar": 100},["hello"],"world"]', [TestDataClass, TestDataClass, list[str], str],
-            [TestDataClass(foo="bar"), TestDataClass(bar=100), ["hello"], "world"], id="json array mix"
+            '[{"foo": "bar"},{"bar": 100},["hello"],"world"]', [_TestDataClass, _TestDataClass, list[str], str],
+            [_TestDataClass(foo="bar"), _TestDataClass(bar=100), ["hello"], "world"], id="json array mix"
        ),
    ]
 )
 @pytest.mark.asyncio
-async def test_data_converter_from_data(json: str, types: list[Type], expected: list[Any]):
+async def test_data_converter_from_data(json: str, types: list[Type], expected: list[Any]) -> None:
     converter = DefaultDataConverter()
     actual = await converter.from_data(Payload(data=json.encode()), types)
     assert expected == actual
@@ -78,12 +78,12 @@ async def test_data_converter_from_data(json: str, types: list[Type], expected:
            ["hello", "world"], '["hello", "world"]', id="multiple values"
        ),
        pytest.param(
-            [TestDataClass()], '{"foo": "foo", "bar": -1, "baz": null}', id="data classes"
+            [_TestDataClass()], '{"foo": "foo", "bar": -1, "baz": null}', id="data classes"
        ),
    ]
 )
 @pytest.mark.asyncio
-async def test_data_converter_to_data(values: list[Any], expected: str):
+async def test_data_converter_to_data(values: list[Any], expected: str) -> None:
     converter = DefaultDataConverter()
     actual = await converter.to_data(values)
     # Parse both rather than trying to compare strings
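A minimal sketch, not part of the patch above, of the typed transition-map pattern that the decision_state_machine.py change enables: a TypedDict value type combined with a Literal key set lets mypy check both the map literal and every lookup site. The names below (EventKind, transition_map, classify) are illustrative placeholders, not symbols from the Cadence SDK.

    from typing import Dict, Literal, Optional, TypedDict

    # Hypothetical, trimmed-down event kinds; the real map in the patch also
    # covers cancellation and failure variants.
    EventKind = Literal["initiated", "started", "completion"]


    class TransitionInfo(TypedDict):
        # Illustrative subset of the TransitionInfo shape introduced in the patch.
        type: EventKind


    # Keys mirror history-event attribute names, as in the patch.
    transition_map: Dict[str, TransitionInfo] = {
        "activity_task_scheduled_event_attributes": {"type": "initiated"},
        "activity_task_started_event_attributes": {"type": "started"},
    }


    def classify(attr_name: str) -> Optional[EventKind]:
        # dict.get() keeps the lookup total; mypy narrows info["type"] to EventKind.
        info = transition_map.get(attr_name)
        return info["type"] if info is not None else None


    assert classify("activity_task_started_event_attributes") == "started"
    assert classify("unknown_event_attributes") is None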