7
7
from datetime import datetime
8
8
from functools import partial
9
9
import logging
10
- from typing import TYPE_CHECKING , Any , Callable , Deque , Type , TypeVar , cast
10
+ import time
11
+ from typing import TYPE_CHECKING , Any , Callable , Coroutine , Deque , Type , TypeVar , cast
11
12
12
13
from chip .ChipDeviceCtrl import CommissionableNode
13
14
from chip .clusters import Attribute , Objects as Clusters
42
43
DATA_KEY_LAST_NODE_ID = "last_node_id"
43
44
44
45
LOGGER = logging .getLogger (__name__ )
46
+ INTERVIEW_TASK_LIMIT = 10
45
47
46
48
47
49
class MatterDeviceController :
@@ -527,38 +529,75 @@ async def _call_sdk(self, func: Callable[..., _T], *args: Any, **kwargs: Any) ->
527
529
528
530
async def _check_subscriptions_and_interviews (self ) -> None :
529
531
"""Run subscriptions (and interviews) for known nodes."""
532
+ # Set default resubscribe interval to 1 hour
533
+ reschedule_interval = 3600
534
+ start_time = time .time ()
535
+ tasks : list [Coroutine [Any , Any , None ]] = []
536
+ task_limit : asyncio .Semaphore = asyncio .Semaphore (INTERVIEW_TASK_LIMIT )
537
+
530
538
for node_id , node in self ._nodes .items ():
531
539
# (re)interview node (only) if needed
532
540
if (
533
541
node is None
534
542
or node .interview_version < SCHEMA_VERSION
535
543
or (datetime .utcnow () - node .last_interview ).days > 30
536
544
):
545
+
546
async def _interview_node(node_id: int) -> None:
    """Interview a single node; log a warning and re-raise on failure."""
    try:
        await self.interview_node(node_id)
    except NodeInterviewFailed as exc:
        # Log with the full exception details, then propagate the failure.
        message = "Unable to interview Node %s, we will retry later in the background."
        LOGGER.warning(message, node_id, exc_info=exc)
        raise exc
557
+
558
+ tasks .append (_interview_node (node_id ))
559
+ continue
560
+
561
+ # setup subscriptions for the node
562
+ if node_id in self ._subscriptions :
563
+ continue
564
+
565
async def _subscribe_node(node_id: int) -> None:
    """Set up the subscription for a single node; log and re-raise on failure."""
    try:
        await self.subscribe_node(node_id)
    except NodeNotResolving as exc:
        # Log with the full exception details, then propagate the failure.
        message = (
            "Unable to subscribe to Node %s, "
            "we will retry later in the background."
        )
        LOGGER.warning(message, node_id, exc_info=exc)
        raise exc
546
577
547
- # setup subscriptions for the node
548
- if node_id in self ._subscriptions :
549
- continue
550
- try :
551
- await self .subscribe_node (node_id )
552
- except NodeNotResolving as err :
553
- # If the node is unreachable on the network now,
554
- # it will throw a NodeNotResolving exception, catch this,
555
- # log this and just try to resolve this node in the next run.
556
- LOGGER .warning (
557
- "Unable to contact Node %s,"
558
- " we will retry later in the background." ,
559
- node_id ,
560
- exc_info = err ,
561
- )
578
+ tasks .append (_subscribe_node (node_id ))
579
+
580
async def _run_task(task: Coroutine[Any, Any, None]) -> None:
    """Await the coroutine while holding one slot of the task limiter."""
    await task_limit.acquire()
    try:
        await task
    finally:
        # Always hand the slot back, even when the coroutine raised.
        task_limit.release()
584
+
585
+ LOGGER .debug ("Running %s tasks" , len (tasks ))
586
+ # wait for all tasks to finish
587
+ results : list [Exception | None ] = await asyncio .gather (
588
+ * (_run_task (task ) for task in tasks ), return_exceptions = True
589
+ )
590
+ LOGGER .debug (
591
+ "Done running %s tasks in %s seconds" ,
592
+ len (results ),
593
+ start_time - time .time (),
594
+ )
595
+ # check if any of the tasks failed
596
+ for result in results :
597
+ if isinstance (result , Exception ):
598
+ # if any of the tasks failed, reschedule in 5 minutes
599
+ reschedule_interval = 300
600
+ break
562
601
563
602
# reschedule self: hourly by default, or after 5 minutes if any task failed
564
603
def _schedule () -> None :
@@ -567,8 +606,9 @@ def _schedule() -> None:
567
606
self ._check_subscriptions_and_interviews ()
568
607
)
569
608
609
+ LOGGER .debug ("Rescheduling interviews in %s seconds" , reschedule_interval )
570
610
loop = cast (asyncio .AbstractEventLoop , self .server .loop )
571
- loop .call_later (3600 , _schedule )
611
+ loop .call_later (reschedule_interval , _schedule )
572
612
573
613
@staticmethod
574
614
def _parse_attributes_from_read_result (
0 commit comments