Skip to content

Commit 39c5906

Browse files
authored
Remove executor limit (#621)
1 parent ec4194c commit 39c5906

File tree

1 file changed

+73
-63
lines changed

1 file changed

+73
-63
lines changed

matter_server/server/device_controller.py

+73-63
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import asyncio
88
from collections import deque
9-
from concurrent.futures import ThreadPoolExecutor
109
from datetime import datetime
1110
from functools import partial
1211
import logging
@@ -114,11 +113,9 @@ def __init__(
114113
self._aiozc: AsyncZeroconf | None = None
115114
self._fallback_node_scanner_timer: asyncio.TimerHandle | None = None
116115
self._fallback_node_scanner_task: asyncio.Task | None = None
117-
self._sdk_executor = ThreadPoolExecutor(
118-
max_workers=1, thread_name_prefix="SDKExecutor"
119-
)
120-
self._node_setup_throttle = asyncio.Semaphore(10)
116+
self._node_setup_throttle = asyncio.Semaphore(5)
121117
self._mdns_event_timer: dict[str, asyncio.TimerHandle] = {}
118+
self._node_lock: dict[int, asyncio.Lock] = {}
122119

123120
async def initialize(self) -> None:
124121
"""Async initialize of controller."""
@@ -437,14 +434,15 @@ async def open_commissioning_window(
437434
if discriminator is None:
438435
discriminator = randint(0, 4095) # noqa: S311
439436

440-
sdk_result = await self._call_sdk(
441-
self.chip_controller.OpenCommissioningWindow,
442-
nodeid=node_id,
443-
timeout=timeout,
444-
iteration=iteration,
445-
discriminator=discriminator,
446-
option=option,
447-
)
437+
async with self._get_node_lock(node_id):
438+
sdk_result = await self._call_sdk(
439+
self.chip_controller.OpenCommissioningWindow,
440+
nodeid=node_id,
441+
timeout=timeout,
442+
iteration=iteration,
443+
discriminator=discriminator,
444+
option=option,
445+
)
448446
self._known_commissioning_params[node_id] = params = CommissioningParameters(
449447
setup_pin_code=sdk_result.setupPinCode,
450448
setup_manual_code=sdk_result.setupManualCode,
@@ -504,13 +502,14 @@ async def interview_node(self, node_id: int) -> None:
504502

505503
try:
506504
LOGGER.info("Interviewing node: %s", node_id)
507-
read_response: Attribute.AsyncReadTransaction.ReadResponse = (
508-
await self.chip_controller.Read(
509-
nodeid=node_id,
510-
attributes="*",
511-
fabricFiltered=False,
505+
async with self._get_node_lock(node_id):
506+
read_response: Attribute.AsyncReadTransaction.ReadResponse = (
507+
await self.chip_controller.Read(
508+
nodeid=node_id,
509+
attributes="*",
510+
fabricFiltered=False,
511+
)
512512
)
513-
)
514513
except ChipStackError as err:
515514
raise NodeInterviewFailed(f"Failed to interview node {node_id}") from err
516515

@@ -566,14 +565,15 @@ async def send_device_command(
566565
cluster_cls: Cluster = ALL_CLUSTERS[cluster_id]
567566
command_cls = getattr(cluster_cls.Commands, command_name)
568567
command = dataclass_from_dict(command_cls, payload, allow_sdk_types=True)
569-
return await self.chip_controller.SendCommand(
570-
nodeid=node_id,
571-
endpoint=endpoint_id,
572-
payload=command,
573-
responseType=response_type,
574-
timedRequestTimeoutMs=timed_request_timeout_ms,
575-
interactionTimeoutMs=interaction_timeout_ms,
576-
)
568+
async with self._get_node_lock(node_id):
569+
return await self.chip_controller.SendCommand(
570+
nodeid=node_id,
571+
endpoint=endpoint_id,
572+
payload=command,
573+
responseType=response_type,
574+
timedRequestTimeoutMs=timed_request_timeout_ms,
575+
interactionTimeoutMs=interaction_timeout_ms,
576+
)
577577

578578
@api_command(APICommand.READ_ATTRIBUTE)
579579
async def read_attribute(
@@ -595,21 +595,22 @@ async def read_attribute(
595595

596596
future = self.server.loop.create_future()
597597
device = await self._resolve_node(node_id)
598-
Attribute.Read(
599-
future=future,
600-
eventLoop=self.server.loop,
601-
device=device.deviceProxy,
602-
devCtrl=self.chip_controller,
603-
attributes=[
604-
Attribute.AttributePath(
605-
EndpointId=endpoint_id,
606-
ClusterId=cluster_id,
607-
AttributeId=attribute_id,
608-
)
609-
],
610-
fabricFiltered=fabric_filtered,
611-
).raise_on_error()
612-
result: Attribute.AsyncReadTransaction.ReadResponse = await future
598+
async with self._get_node_lock(node_id):
599+
Attribute.Read(
600+
future=future,
601+
eventLoop=self.server.loop,
602+
device=device.deviceProxy,
603+
devCtrl=self.chip_controller,
604+
attributes=[
605+
Attribute.AttributePath(
606+
EndpointId=endpoint_id,
607+
ClusterId=cluster_id,
608+
AttributeId=attribute_id,
609+
)
610+
],
611+
fabricFiltered=fabric_filtered,
612+
).raise_on_error()
613+
result: Attribute.AsyncReadTransaction.ReadResponse = await future
613614
read_atributes = parse_attributes_from_read_result(result.tlvAttributes)
614615
# update cached info in node attributes
615616
self._nodes[node_id].attributes.update(read_atributes)
@@ -639,10 +640,11 @@ async def write_attribute(
639640
value_type=attribute.attribute_type.Type,
640641
allow_sdk_types=True,
641642
)
642-
return await self.chip_controller.WriteAttribute(
643-
nodeid=node_id,
644-
attributes=[(endpoint_id, attribute)],
645-
)
643+
async with self._get_node_lock(node_id):
644+
return await self.chip_controller.WriteAttribute(
645+
nodeid=node_id,
646+
attributes=[(endpoint_id, attribute)],
647+
)
646648

647649
@api_command(APICommand.REMOVE_NODE)
648650
async def remove_node(self, node_id: int) -> None:
@@ -1001,16 +1003,17 @@ def resubscription_succeeded(
10011003
else:
10021004
interval_ceiling = NODE_SUBSCRIPTION_CEILING_THREAD
10031005
self._last_subscription_attempt[node_id] = 0
1004-
sub: Attribute.SubscriptionTransaction = await self.chip_controller.Read(
1005-
node_id,
1006-
attributes="*",
1007-
events=[("*", 1)],
1008-
returnClusterObject=False,
1009-
reportInterval=(interval_floor, interval_ceiling),
1010-
fabricFiltered=False,
1011-
keepSubscriptions=True,
1012-
autoResubscribe=True,
1013-
)
1006+
async with self._get_node_lock(node_id):
1007+
sub: Attribute.SubscriptionTransaction = await self.chip_controller.Read(
1008+
node_id,
1009+
attributes="*",
1010+
events=[("*", 1)],
1011+
returnClusterObject=False,
1012+
reportInterval=(interval_floor, interval_ceiling),
1013+
fabricFiltered=False,
1014+
keepSubscriptions=True,
1015+
autoResubscribe=True,
1016+
)
10141017

10151018
sub.SetAttributeUpdateCallback(attribute_updated_callback)
10161019
sub.SetEventUpdateCallback(event_callback)
@@ -1050,7 +1053,7 @@ async def _call_sdk(
10501053
return cast(
10511054
_T,
10521055
await self.server.loop.run_in_executor(
1053-
self._sdk_executor,
1056+
None,
10541057
partial(target, *args, **kwargs),
10551058
),
10561059
)
@@ -1136,12 +1139,13 @@ async def _resolve_node(
11361139
retries,
11371140
)
11381141
time_start = time.time()
1139-
return await self._call_sdk(
1140-
self.chip_controller.GetConnectedDeviceSync,
1141-
nodeid=node_id,
1142-
allowPASE=False,
1143-
timeoutMs=None,
1144-
)
1142+
async with self._get_node_lock(node_id):
1143+
return await self._call_sdk(
1144+
self.chip_controller.GetConnectedDeviceSync,
1145+
nodeid=node_id,
1146+
allowPASE=False,
1147+
timeoutMs=None,
1148+
)
11451149
except ChipStackError as err:
11461150
if attempt >= retries:
11471151
# when we're out of retries, raise NodeNotResolving
@@ -1352,3 +1356,9 @@ def run_fallback_node_scanner() -> None:
13521356
self._fallback_node_scanner_timer = self.server.loop.call_later(
13531357
FALLBACK_NODE_SCANNER_INTERVAL, run_fallback_node_scanner
13541358
)
1359+
1360+
def _get_node_lock(self, node_id: int) -> asyncio.Lock:
1361+
"""Return lock for given node."""
1362+
if node_id not in self._node_lock:
1363+
self._node_lock[node_id] = asyncio.Lock()
1364+
return self._node_lock[node_id]

0 commit comments

Comments
 (0)