1
1
from __future__ import annotations
2
2
3
3
import asyncio
4
+ from concurrent import futures
4
5
from contextlib import suppress
6
+ from functools import partial
5
7
import json
6
8
import typing
7
9
from typing import TYPE_CHECKING , Coroutine , Final
8
- from concurrent import futures
9
10
10
11
from aiohttp import WSCloseCode , WSMsgType , web
11
12
from chip .exceptions import ChipStackError
12
13
14
+ from ..backports .enum import StrEnum
13
15
from ..common .json_utils import CHIPJSONDecoder , CHIPJSONEncoder
14
16
from ..common .model .message import (
15
17
CommandMessage ,
18
20
SuccessResultMessage ,
19
21
)
20
22
from ..common .model .version import VersionInfo
21
- from ..backports .enum import StrEnum
22
23
23
24
if TYPE_CHECKING :
24
- from .server import MatterServer
25
25
from chip .clusters import Attribute , ClusterObjects
26
26
27
+ from .server import MatterServer
28
+
27
29
MAX_PENDING_MSG = 512
28
30
CANCELLATION_ERRORS : Final = (asyncio .CancelledError , futures .CancelledError )
29
31
@@ -42,6 +44,9 @@ class DeviceControllerCommands(StrEnum):
42
44
SET_THREAD_OPERATIONAL_DATASET = "SetThreadOperationalDataset"
43
45
SET_WIFI_CREDENTIALS = "SetWiFiCredentials"
44
46
47
+ # Custom commands
48
+ UNSUBSCRIBE = "Unsubscribe"
49
+
45
50
46
51
PROTOCOL = {
47
52
InstanceCommands .DEVICE_CONTROLLER : {
@@ -178,7 +183,12 @@ def _handle_message(self, msg: CommandMessage):
178
183
self ._send_message (ErrorResultMessage (msg .messageId , "Unknown error" ))
179
184
180
185
def _handle_device_controller_message (self , msg : CommandMessage , command : str ):
181
- if command == "Read" and msg .args .get ("reportInterval" ) is not None :
186
+ if command == DeviceControllerCommands .UNSUBSCRIBE :
187
+ method = partial (self ._handle_device_controller_unsubscribe_message , msg )
188
+ elif (
189
+ command == DeviceControllerCommands .READ
190
+ and msg .args .get ("reportInterval" ) is not None
191
+ ):
182
192
method = self ._handle_device_controller_subscribe_message
183
193
else :
184
194
method = getattr (self .server .stack .device_controller , command )
@@ -307,13 +317,25 @@ def subscription_callback(
307
317
self ._subscriptions [subscription_id ] = subscription
308
318
return subscription_id
309
319
320
+ def _handle_device_controller_unsubscribe_message (
321
+ self , msg : CommandMessage , subscription_id : int
322
+ ):
323
+ """Unsubscribe."""
324
+ subscription = self ._subscriptions .pop (subscription_id , None )
325
+ if subscription is None :
326
+ self ._send_message (ErrorResultMessage (msg .messageId , "not_found" ))
327
+ return
328
+
329
+ subscription .shutdown ()
330
+ self ._send_message (SuccessResultMessage (msg .messageId , None ))
331
+
310
332
async def _handle_coroutine_command (self , msg : CommandMessage , action : Coroutine ):
311
333
try :
312
334
result = await action
313
335
except ChipStackError as ex :
314
336
self ._send_message (ErrorResultMessage (msg .messageId , str (ex )))
315
337
except Exception :
316
338
self .logger .exception ("Error handling message: %s" , msg .data )
317
- self ._send_message (ErrorResultMessage (msg .messageId , "Unknown error " ))
339
+ self ._send_message (ErrorResultMessage (msg .messageId , "unknown_error " ))
318
340
319
341
self ._send_message (SuccessResultMessage (msg .messageId , result ))
0 commit comments