Skip to content

Commit 5d6533a

Browse files
authored
Enforce type safe check and fix code with type safe issues (#18)
1 parent 6ff1869 commit 5d6533a

File tree

11 files changed

+73
-59
lines changed

11 files changed

+73
-59
lines changed

cadence/_internal/decision_state_machine.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from dataclasses import dataclass, field
44
from enum import Enum
5-
from typing import Dict, List, Optional, Callable
5+
from typing import Dict, List, Optional, Callable, TypedDict, Literal
66

77
from cadence.api.v1 import (
88
decision_pb2 as decision,
@@ -81,12 +81,18 @@ def __str__(self) -> str:
8181
@dataclass
8282
class StateTransition:
8383
"""Represents a state transition with associated actions."""
84-
next_state: DecisionState
84+
next_state: Optional[DecisionState]
8585
action: Optional[Callable[['BaseDecisionStateMachine', history.HistoryEvent], None]] = None
8686
condition: Optional[Callable[['BaseDecisionStateMachine', history.HistoryEvent], bool]] = None
8787

8888

89-
decision_state_transition_map = {
89+
class TransitionInfo(TypedDict):
90+
type: Literal["initiated", "started", "completion", "canceled", "cancel_initiated", "cancel_failed", "initiation_failed"]
91+
decision_type: DecisionType
92+
transition: StateTransition
93+
94+
95+
decision_state_transition_map: Dict[str, TransitionInfo] = {
9096
"activity_task_scheduled_event_attributes": {
9197
"type": "initiated",
9298
"decision_type": DecisionType.ACTIVITY,
@@ -247,6 +253,10 @@ class BaseDecisionStateMachine:
247253
Subclasses are responsible for mapping workflow history events into state
248254
transitions and producing the next set of decisions when queried.
249255
"""
256+
257+
# Common fields that subclasses may use
258+
scheduled_event_id: Optional[int] = None
259+
started_event_id: Optional[int] = None
250260

251261
def get_id(self) -> str:
252262
raise NotImplementedError
@@ -890,12 +900,12 @@ def handle_history_event(self, event: history.HistoryEvent) -> None:
890900
if transition_info:
891901
event_type = transition_info["type"]
892902
# Route to all relevant machines using the new unified handle_event method
893-
for m in list(self.activities.values()):
894-
m.handle_event(event, event_type)
895-
for m in list(self.timers.values()):
896-
m.handle_event(event, event_type)
897-
for m in list(self.children.values()):
898-
m.handle_event(event, event_type)
903+
for activity_machine in list(self.activities.values()):
904+
activity_machine.handle_event(event, event_type)
905+
for timer_machine in list(self.timers.values()):
906+
timer_machine.handle_event(event, event_type)
907+
for child_machine in list(self.children.values()):
908+
child_machine.handle_event(event, event_type)
899909

900910
# ----- Decision aggregation -----
901911

@@ -907,11 +917,11 @@ def collect_pending_decisions(self) -> List[decision.Decision]:
907917
decisions.extend(machine.collect_pending_decisions())
908918

909919
# Timers
910-
for machine in list(self.timers.values()):
911-
decisions.extend(machine.collect_pending_decisions())
920+
for timer_machine in list(self.timers.values()):
921+
decisions.extend(timer_machine.collect_pending_decisions())
912922

913923
# Children
914-
for machine in list(self.children.values()):
915-
decisions.extend(machine.collect_pending_decisions())
924+
for child_machine in list(self.children.values()):
925+
decisions.extend(child_machine.collect_pending_decisions())
916926

917927
return decisions

cadence/_internal/rpc/metadata.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import collections
2+
from typing import Any, Callable
23

34
from grpc.aio import Metadata
45
from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails
@@ -16,7 +17,12 @@ class MetadataInterceptor(UnaryUnaryClientInterceptor):
1617
def __init__(self, metadata: Metadata):
1718
self._metadata = metadata
1819

19-
async def intercept_unary_unary(self, continuation, client_call_details: ClientCallDetails, request):
20+
async def intercept_unary_unary(
21+
self,
22+
continuation: Callable[[ClientCallDetails, Any], Any],
23+
client_call_details: ClientCallDetails,
24+
request: Any
25+
) -> Any:
2026
return await continuation(self._replace_details(client_call_details), request)
2127

2228

cadence/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class ClientOptions(TypedDict, total=False):
1111
identity: str
1212

1313
class Client:
14-
def __init__(self, channel: Channel, options: ClientOptions):
14+
def __init__(self, channel: Channel, options: ClientOptions) -> None:
1515
self._channel = channel
1616
self._worker_stub = WorkerAPIStub(channel)
1717
self._options = options
@@ -31,7 +31,7 @@ def worker_stub(self) -> WorkerAPIStub:
3131
return self._worker_stub
3232

3333

34-
async def close(self):
34+
async def close(self) -> None:
3535
await self._channel.close()
3636

3737

cadence/data_converter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ async def to_data(self, values: List[Any]) -> Payload:
1717
raise NotImplementedError()
1818

1919
class DefaultDataConverter(DataConverter):
20-
def __init__(self):
20+
def __init__(self) -> None:
2121
self._encoder = json.Encoder()
2222
self._decoder = json.Decoder()
2323
self._fallback_decoder = JSONDecoder(strict=False)
@@ -38,7 +38,7 @@ async def from_data(self, payload: Payload, type_hints: List[Type]) -> List[Any]
3838

3939

4040
def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type]) -> List[Any]:
41-
results = []
41+
results: List[Any] = []
4242
start, end = 0, len(payload)
4343
while start < end and len(results) < len(type_hints):
4444
remaining = payload[start:end]
@@ -50,7 +50,7 @@ def _decode_whitespace_delimited(self, payload: str, type_hints: List[Type]) ->
5050

5151
@staticmethod
5252
def _convert_into(values: List[Any], type_hints: List[Type]) -> List[Any]:
53-
results = []
53+
results: List[Any] = []
5454
for i, type_hint in enumerate(type_hints):
5555
if i < len(values):
5656
value = convert(values[i], type_hint)

cadence/worker/_activity.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111

1212

1313
class ActivityWorker:
14-
def __init__(self, client: Client, task_list: str, options: WorkerOptions):
14+
def __init__(self, client: Client, task_list: str, options: WorkerOptions) -> None:
1515
self._client = client
1616
self._task_list = task_list
1717
self._identity = options["identity"]
1818
permits = asyncio.Semaphore(options["max_concurrent_activity_execution_size"])
1919
self._poller = Poller[PollForActivityTaskResponse](options["activity_task_pollers"], permits, self._poll, self._execute)
2020
# TODO: Local dispatch, local activities, actually running activities, etc
2121

22-
async def run(self):
22+
async def run(self) -> None:
2323
await self._poller.run()
2424

2525
async def _poll(self) -> Optional[PollForActivityTaskResponse]:
@@ -34,7 +34,7 @@ async def _poll(self) -> Optional[PollForActivityTaskResponse]:
3434
else:
3535
return None
3636

37-
async def _execute(self, task: PollForActivityTaskResponse):
37+
async def _execute(self, task: PollForActivityTaskResponse) -> None:
3838
await self._client.worker_stub.RespondActivityTaskFailed(RespondActivityTaskFailedRequest(
3939
task_token=task.task_token,
4040
identity=self._identity,

cadence/worker/_decision.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212

1313

1414
class DecisionWorker:
15-
def __init__(self, client: Client, task_list: str, options: WorkerOptions):
15+
def __init__(self, client: Client, task_list: str, options: WorkerOptions) -> None:
1616
self._client = client
1717
self._task_list = task_list
1818
self._identity = options["identity"]
1919
permits = asyncio.Semaphore(options["max_concurrent_decision_task_execution_size"])
2020
self._poller = Poller[PollForDecisionTaskResponse](options["decision_task_pollers"], permits, self._poll, self._execute)
2121
# TODO: Sticky poller, actually running workflows, etc.
2222

23-
async def run(self):
23+
async def run(self) -> None:
2424
await self._poller.run()
2525

2626
async def _poll(self) -> Optional[PollForDecisionTaskResponse]:
@@ -36,7 +36,7 @@ async def _poll(self) -> Optional[PollForDecisionTaskResponse]:
3636
return None
3737

3838

39-
async def _execute(self, task: PollForDecisionTaskResponse):
39+
async def _execute(self, task: PollForDecisionTaskResponse) -> None:
4040
await self._client.worker_stub.RespondDecisionTaskFailed(RespondDecisionTaskFailedRequest(
4141
task_token=task.task_token,
4242
cause=DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION,

cadence/worker/_poller.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,23 @@
77
T = TypeVar('T')
88

99
class Poller(Generic[T]):
10-
def __init__(self, num_tasks: int, permits: asyncio.Semaphore, poll: Callable[[], Awaitable[Optional[T]]], callback: Callable[[T], Awaitable[None]]):
10+
def __init__(self, num_tasks: int, permits: asyncio.Semaphore, poll: Callable[[], Awaitable[Optional[T]]], callback: Callable[[T], Awaitable[None]]) -> None:
1111
self._num_tasks = num_tasks
1212
self._permits = permits
1313
self._poll = poll
1414
self._callback = callback
1515
self._background_tasks: set[asyncio.Task[None]] = set()
16-
pass
1716

18-
async def run(self):
17+
async def run(self) -> None:
1918
try:
2019
async with asyncio.TaskGroup() as tg:
2120
for i in range(self._num_tasks):
2221
tg.create_task(self._poll_loop())
2322
except asyncio.CancelledError:
24-
pass
23+
pass
2524

2625

27-
async def _poll_loop(self):
26+
async def _poll_loop(self) -> None:
2827
while True:
2928
try:
3029
await self._poll_and_dispatch()
@@ -34,7 +33,7 @@ async def _poll_loop(self):
3433
logger.exception('Exception while polling')
3534

3635

37-
async def _poll_and_dispatch(self):
36+
async def _poll_and_dispatch(self) -> None:
3837
await self._permits.acquire()
3938
try:
4039
task = await self._poll()
@@ -51,7 +50,7 @@ async def _poll_and_dispatch(self):
5150
self._background_tasks.add(scheduled)
5251
scheduled.add_done_callback(self._background_tasks.remove)
5352

54-
async def _execute_callback(self, task: T):
53+
async def _execute_callback(self, task: T) -> None:
5554
try:
5655
await self._callback(task)
5756
except Exception:

cadence/worker/_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class Registry:
3333
workflows and activities in a Cadence application.
3434
"""
3535

36-
def __init__(self):
36+
def __init__(self) -> None:
3737
"""Initialize the registry."""
3838
self._workflows: Dict[str, Callable] = {}
3939
self._activities: Dict[str, Callable] = {}

cadence/worker/_worker.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import uuid
3-
from typing import Unpack
3+
from typing import Unpack, cast
44

55
from cadence.client import Client
66
from cadence.worker._activity import ActivityWorker
@@ -10,7 +10,7 @@
1010

1111
class Worker:
1212

13-
def __init__(self, client: Client, task_list: str, **kwargs: Unpack[WorkerOptions]):
13+
def __init__(self, client: Client, task_list: str, **kwargs: Unpack[WorkerOptions]) -> None:
1414
self._client = client
1515
self._task_list = task_list
1616

@@ -21,7 +21,7 @@ def __init__(self, client: Client, task_list: str, **kwargs: Unpack[WorkerOption
2121
self._decision_worker = DecisionWorker(client, task_list, options)
2222

2323

24-
async def run(self):
24+
async def run(self) -> None:
2525
async with asyncio.TaskGroup() as tg:
2626
if not self._options["disable_workflow_worker"]:
2727
tg.create_task(self._decision_worker.run())
@@ -30,13 +30,13 @@ async def run(self):
3030

3131

3232

33-
def _validate_and_copy_defaults(client: Client, task_list: str, options: WorkerOptions):
33+
def _validate_and_copy_defaults(client: Client, task_list: str, options: WorkerOptions) -> None:
3434
if "identity" not in options:
3535
options["identity"] = f"{client.identity}@{task_list}@{uuid.uuid4()}"
3636

3737
# TODO: More validation
3838

39-
for (key, value) in _DEFAULT_WORKER_OPTIONS.items():
39+
# Set default values for missing options
40+
for key, value in _DEFAULT_WORKER_OPTIONS.items():
4041
if key not in options:
41-
# noinspection PyTypedDict
42-
options[key] = value
42+
cast(dict, options)[key] = value

pyproject.toml

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,15 @@ warn_no_return = true
113113
warn_unreachable = true
114114
strict_equality = true
115115
explicit_package_bases = true
116-
disable_error_code = [
117-
"var-annotated",
118-
"arg-type",
119-
"attr-defined",
120-
"assignment",
121-
"literal-required",
122-
]
123116
exclude = [
124-
"cadence/api/*", # Exclude entire api directory with generated proto files
117+
"cadence/api",
118+
"cadence/api/.*",
119+
"cadence/sample",
125120
]
126121

122+
# Reduce recursive module checking
123+
follow_imports = "silent"
124+
127125
[[tool.mypy.overrides]]
128126
module = [
129127
"grpcio.*",
@@ -133,6 +131,7 @@ module = [
133131
"google.protobuf.*",
134132
"uber.cadence.*",
135133
"msgspec.*",
134+
"cadence.api.*",
136135
]
137136
ignore_missing_imports = true
138137

0 commit comments

Comments
 (0)