6
6
7
7
import asyncio
8
8
from collections import deque
9
- from concurrent .futures import ThreadPoolExecutor
10
9
from datetime import datetime
11
10
from functools import partial
12
11
import logging
@@ -114,11 +113,9 @@ def __init__(
114
113
self ._aiozc : AsyncZeroconf | None = None
115
114
self ._fallback_node_scanner_timer : asyncio .TimerHandle | None = None
116
115
self ._fallback_node_scanner_task : asyncio .Task | None = None
117
- self ._sdk_executor = ThreadPoolExecutor (
118
- max_workers = 1 , thread_name_prefix = "SDKExecutor"
119
- )
120
- self ._node_setup_throttle = asyncio .Semaphore (10 )
116
+ self ._node_setup_throttle = asyncio .Semaphore (5 )
121
117
self ._mdns_event_timer : dict [str , asyncio .TimerHandle ] = {}
118
+ self ._node_lock : dict [int , asyncio .Lock ] = {}
122
119
123
120
async def initialize (self ) -> None :
124
121
"""Async initialize of controller."""
@@ -437,14 +434,15 @@ async def open_commissioning_window(
437
434
if discriminator is None :
438
435
discriminator = randint (0 , 4095 ) # noqa: S311
439
436
440
- sdk_result = await self ._call_sdk (
441
- self .chip_controller .OpenCommissioningWindow ,
442
- nodeid = node_id ,
443
- timeout = timeout ,
444
- iteration = iteration ,
445
- discriminator = discriminator ,
446
- option = option ,
447
- )
437
+ async with self ._get_node_lock (node_id ):
438
+ sdk_result = await self ._call_sdk (
439
+ self .chip_controller .OpenCommissioningWindow ,
440
+ nodeid = node_id ,
441
+ timeout = timeout ,
442
+ iteration = iteration ,
443
+ discriminator = discriminator ,
444
+ option = option ,
445
+ )
448
446
self ._known_commissioning_params [node_id ] = params = CommissioningParameters (
449
447
setup_pin_code = sdk_result .setupPinCode ,
450
448
setup_manual_code = sdk_result .setupManualCode ,
@@ -504,13 +502,14 @@ async def interview_node(self, node_id: int) -> None:
504
502
505
503
try :
506
504
LOGGER .info ("Interviewing node: %s" , node_id )
507
- read_response : Attribute .AsyncReadTransaction .ReadResponse = (
508
- await self .chip_controller .Read (
509
- nodeid = node_id ,
510
- attributes = "*" ,
511
- fabricFiltered = False ,
505
+ async with self ._get_node_lock (node_id ):
506
+ read_response : Attribute .AsyncReadTransaction .ReadResponse = (
507
+ await self .chip_controller .Read (
508
+ nodeid = node_id ,
509
+ attributes = "*" ,
510
+ fabricFiltered = False ,
511
+ )
512
512
)
513
- )
514
513
except ChipStackError as err :
515
514
raise NodeInterviewFailed (f"Failed to interview node { node_id } " ) from err
516
515
@@ -566,14 +565,15 @@ async def send_device_command(
566
565
cluster_cls : Cluster = ALL_CLUSTERS [cluster_id ]
567
566
command_cls = getattr (cluster_cls .Commands , command_name )
568
567
command = dataclass_from_dict (command_cls , payload , allow_sdk_types = True )
569
- return await self .chip_controller .SendCommand (
570
- nodeid = node_id ,
571
- endpoint = endpoint_id ,
572
- payload = command ,
573
- responseType = response_type ,
574
- timedRequestTimeoutMs = timed_request_timeout_ms ,
575
- interactionTimeoutMs = interaction_timeout_ms ,
576
- )
568
+ async with self ._get_node_lock (node_id ):
569
+ return await self .chip_controller .SendCommand (
570
+ nodeid = node_id ,
571
+ endpoint = endpoint_id ,
572
+ payload = command ,
573
+ responseType = response_type ,
574
+ timedRequestTimeoutMs = timed_request_timeout_ms ,
575
+ interactionTimeoutMs = interaction_timeout_ms ,
576
+ )
577
577
578
578
@api_command (APICommand .READ_ATTRIBUTE )
579
579
async def read_attribute (
@@ -595,21 +595,22 @@ async def read_attribute(
595
595
596
596
future = self .server .loop .create_future ()
597
597
device = await self ._resolve_node (node_id )
598
- Attribute .Read (
599
- future = future ,
600
- eventLoop = self .server .loop ,
601
- device = device .deviceProxy ,
602
- devCtrl = self .chip_controller ,
603
- attributes = [
604
- Attribute .AttributePath (
605
- EndpointId = endpoint_id ,
606
- ClusterId = cluster_id ,
607
- AttributeId = attribute_id ,
608
- )
609
- ],
610
- fabricFiltered = fabric_filtered ,
611
- ).raise_on_error ()
612
- result : Attribute .AsyncReadTransaction .ReadResponse = await future
598
+ async with self ._get_node_lock (node_id ):
599
+ Attribute .Read (
600
+ future = future ,
601
+ eventLoop = self .server .loop ,
602
+ device = device .deviceProxy ,
603
+ devCtrl = self .chip_controller ,
604
+ attributes = [
605
+ Attribute .AttributePath (
606
+ EndpointId = endpoint_id ,
607
+ ClusterId = cluster_id ,
608
+ AttributeId = attribute_id ,
609
+ )
610
+ ],
611
+ fabricFiltered = fabric_filtered ,
612
+ ).raise_on_error ()
613
+ result : Attribute .AsyncReadTransaction .ReadResponse = await future
613
614
read_atributes = parse_attributes_from_read_result (result .tlvAttributes )
614
615
# update cached info in node attributes
615
616
self ._nodes [node_id ].attributes .update (read_atributes )
@@ -639,10 +640,11 @@ async def write_attribute(
639
640
value_type = attribute .attribute_type .Type ,
640
641
allow_sdk_types = True ,
641
642
)
642
- return await self .chip_controller .WriteAttribute (
643
- nodeid = node_id ,
644
- attributes = [(endpoint_id , attribute )],
645
- )
643
+ async with self ._get_node_lock (node_id ):
644
+ return await self .chip_controller .WriteAttribute (
645
+ nodeid = node_id ,
646
+ attributes = [(endpoint_id , attribute )],
647
+ )
646
648
647
649
@api_command (APICommand .REMOVE_NODE )
648
650
async def remove_node (self , node_id : int ) -> None :
@@ -1001,16 +1003,17 @@ def resubscription_succeeded(
1001
1003
else :
1002
1004
interval_ceiling = NODE_SUBSCRIPTION_CEILING_THREAD
1003
1005
self ._last_subscription_attempt [node_id ] = 0
1004
- sub : Attribute .SubscriptionTransaction = await self .chip_controller .Read (
1005
- node_id ,
1006
- attributes = "*" ,
1007
- events = [("*" , 1 )],
1008
- returnClusterObject = False ,
1009
- reportInterval = (interval_floor , interval_ceiling ),
1010
- fabricFiltered = False ,
1011
- keepSubscriptions = True ,
1012
- autoResubscribe = True ,
1013
- )
1006
+ async with self ._get_node_lock (node_id ):
1007
+ sub : Attribute .SubscriptionTransaction = await self .chip_controller .Read (
1008
+ node_id ,
1009
+ attributes = "*" ,
1010
+ events = [("*" , 1 )],
1011
+ returnClusterObject = False ,
1012
+ reportInterval = (interval_floor , interval_ceiling ),
1013
+ fabricFiltered = False ,
1014
+ keepSubscriptions = True ,
1015
+ autoResubscribe = True ,
1016
+ )
1014
1017
1015
1018
sub .SetAttributeUpdateCallback (attribute_updated_callback )
1016
1019
sub .SetEventUpdateCallback (event_callback )
@@ -1050,7 +1053,7 @@ async def _call_sdk(
1050
1053
return cast (
1051
1054
_T ,
1052
1055
await self .server .loop .run_in_executor (
1053
- self . _sdk_executor ,
1056
+ None ,
1054
1057
partial (target , * args , ** kwargs ),
1055
1058
),
1056
1059
)
@@ -1136,12 +1139,13 @@ async def _resolve_node(
1136
1139
retries ,
1137
1140
)
1138
1141
time_start = time .time ()
1139
- return await self ._call_sdk (
1140
- self .chip_controller .GetConnectedDeviceSync ,
1141
- nodeid = node_id ,
1142
- allowPASE = False ,
1143
- timeoutMs = None ,
1144
- )
1142
+ async with self ._get_node_lock (node_id ):
1143
+ return await self ._call_sdk (
1144
+ self .chip_controller .GetConnectedDeviceSync ,
1145
+ nodeid = node_id ,
1146
+ allowPASE = False ,
1147
+ timeoutMs = None ,
1148
+ )
1145
1149
except ChipStackError as err :
1146
1150
if attempt >= retries :
1147
1151
# when we're out of retries, raise NodeNotResolving
@@ -1352,3 +1356,9 @@ def run_fallback_node_scanner() -> None:
1352
1356
self ._fallback_node_scanner_timer = self .server .loop .call_later (
1353
1357
FALLBACK_NODE_SCANNER_INTERVAL , run_fallback_node_scanner
1354
1358
)
1359
+
1360
+ def _get_node_lock (self , node_id : int ) -> asyncio .Lock :
1361
+ """Return lock for given node."""
1362
+ if node_id not in self ._node_lock :
1363
+ self ._node_lock [node_id ] = asyncio .Lock ()
1364
+ return self ._node_lock [node_id ]
0 commit comments