-
Notifications
You must be signed in to change notification settings - Fork 97
/
Copy pathclient_handler.py
243 lines (193 loc) · 8.66 KB
/
client_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
"""Logic to handle a client connected over WebSockets."""
from __future__ import annotations
import asyncio
from concurrent import futures
from contextlib import suppress
import logging
from typing import TYPE_CHECKING, Any, Callable, Final
from aiohttp import WSMsgType, web
import async_timeout
from chip.exceptions import ChipStackError
from matter_server.common.helpers.json import json_dumps, json_loads
from matter_server.common.models import EventType
from ..common.errors import InvalidArguments, InvalidCommand, MatterError, SDKStackError
from ..common.helpers.api import parse_arguments
from ..common.helpers.util import dataclass_from_dict
from ..common.models import (
APICommand,
CommandMessage,
ErrorResultMessage,
EventMessage,
MessageType,
SuccessResultMessage,
)
if TYPE_CHECKING:
from ..common.helpers.api import APICommandHandler
from .server import MatterServer
# Module-wide logger; per-connection context is added by WebSocketLogAdapter.
LOGGER = logging.getLogger(__name__)
# Capacity of each client's outgoing message queue (see WebsocketClientHandler).
MAX_PENDING_MSG = 512
# Cancellation exception types from both asyncio and concurrent.futures,
# suppressed together by the writer task.
CANCELLATION_ERRORS: Final = (
    asyncio.CancelledError,
    futures.CancelledError,
)
class WebSocketLogAdapter(logging.LoggerAdapter):
    """Logger adapter that tags each message with its websocket connection id.

    The adapter's ``extra`` mapping must provide a ``connid`` key.
    """

    def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
        """Return ``msg`` prefixed with ``[<connid>] `` and ``kwargs`` unchanged."""
        context = self.extra
        assert context is not None
        tag = f"[{context['connid']}]"
        prefixed = f"{tag} {msg}"
        return prefixed, kwargs
class WebsocketClientHandler:
    """Handle an active websocket client connection.

    One instance serves one websocket connection. It reads ``CommandMessage``
    JSON text frames from the client and dispatches them to the server's
    command handlers. Results and events are written back through a bounded
    outgoing queue that a dedicated writer task drains.
    """

    def __init__(self, server: MatterServer, request: web.Request) -> None:
        """Initialize an active connection.

        :param server: the owning MatterServer (command handlers, events, info).
        :param request: the aiohttp request being upgraded to a websocket.
        """
        self.server = server
        self.request = request
        self.wsock = web.WebSocketResponse(heartbeat=55)
        # Outgoing items: JSON strings, or zero-arg callables returning one.
        # ``None`` is the sentinel that tells the writer task to stop.
        self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
        self._handle_task: asyncio.Task | None = None
        self._writer_task: asyncio.Task | None = None
        self._logger = WebSocketLogAdapter(LOGGER, {"connid": id(self)})
        # Set once the client issued START_LISTENING; calling it unsubscribes.
        self._unsub_callback: Callable | None = None
        # Strong references to in-flight command handler tasks. The event loop
        # only keeps weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._handler_tasks: set[asyncio.Task] = set()

    async def disconnect(self) -> None:
        """Disconnect client.

        Cancels the handler and writer tasks, then waits for the writer task.
        """
        self._cancel()
        if self._writer_task is not None:
            await self._writer_task

    async def handle_client(self) -> web.WebSocketResponse:
        """Handle a websocket response.

        Prepares the websocket, giving up after 10 seconds. It then starts the
        writer task, sends the server info, and reads and dispatches client
        commands. The loop ends when the socket closes, an error frame arrives,
        or a frame cannot be parsed. On exit it unsubscribes from server
        events, lets the writer flush queued messages, and closes the socket.

        :return: the (prepared) websocket response.
        """
        # pylint: disable=too-many-branches,too-many-statements
        request = self.request
        wsock = self.wsock
        try:
            async with async_timeout.timeout(10):
                await wsock.prepare(request)
        except asyncio.TimeoutError:
            self._logger.warning("Timeout preparing request from %s", request.remote)
            return wsock

        self._logger.debug("Connected from %s", request.remote)
        # Remember our own task so _cancel() can stop this read loop.
        self._handle_task = asyncio.current_task()
        self._writer_task = asyncio.create_task(self._writer())

        # send server(version) info when client connects
        self._send_message(self.server.get_info())

        disconnect_warn = None

        try:
            while not wsock.closed:
                msg = await wsock.receive()

                if msg.type in (WSMsgType.CLOSED, WSMsgType.CLOSE, WSMsgType.CLOSING):
                    break

                if msg.type == WSMsgType.ERROR:
                    disconnect_warn = f"Received error message: {msg.data}"
                    break

                if msg.type != WSMsgType.TEXT:
                    self._logger.warning("Received non-Text message: %s", msg.data)
                    continue

                self._logger.debug("Received: %s", msg.data)

                try:
                    command_msg = dataclass_from_dict(
                        CommandMessage, json_loads(msg.data)
                    )
                except ValueError:
                    disconnect_warn = f"Received invalid JSON: {msg.data}"
                    break

                self._logger.debug("Received %s", command_msg)
                self._handle_command(command_msg)

        except asyncio.CancelledError:
            # NOTE(review): this task is also cancelled server-side, by _cancel()
            # from disconnect() or when the outgoing queue overflows, so
            # "closed by client" is not always accurate.
            self._logger.info("Connection closed by client")

        except Exception:  # pylint: disable=broad-except
            self._logger.exception("Unexpected error inside websocket API")

        finally:
            # Handle connection shutting down.
            if self._unsub_callback:
                self._logger.debug("Unsubscribed from events")
                self._unsub_callback()

            try:
                # Sentinel: the writer drains everything queued before it, then stops.
                self._to_write.put_nowait(None)
                # Make sure all error messages are written before closing
                await self._writer_task
                await wsock.close()
            except asyncio.QueueFull:  # can be raised by put_nowait
                self._writer_task.cancel()

            finally:
                if disconnect_warn is None:
                    self._logger.debug("Disconnected")
                else:
                    self._logger.warning("Disconnected: %s", disconnect_warn)

        return wsock

    def _handle_command(self, msg: CommandMessage) -> None:
        """Handle an incoming command from the client.

        START_LISTENING is handled inline. Unknown commands get an
        ``InvalidCommand`` error result. Every other command runs in a
        background task via ``_run_handler``.
        """
        self._logger.debug("Handling command %s", msg.command)

        # work out handler for the given path/command
        if msg.command == APICommand.START_LISTENING:
            self._handle_start_listening_command(msg)
            return

        handler = self.server.command_handlers.get(msg.command)

        if handler is None:
            self._send_message(
                ErrorResultMessage(
                    msg.message_id,
                    InvalidCommand.error_code,
                    f"Invalid command: {msg.command}",
                )
            )
            self._logger.warning("Invalid command: %s", msg.command)
            return

        # Schedule a task to handle the command, and keep a strong reference
        # until it completes so it cannot be garbage-collected mid-execution.
        task = asyncio.create_task(self._run_handler(handler, msg))
        self._handler_tasks.add(task)
        task.add_done_callback(self._handler_tasks.discard)

    def _handle_start_listening_command(self, msg: CommandMessage) -> None:
        """Send a full dump of all nodes once and start receiving events.

        A repeated START_LISTENING on the same connection gets an
        ``InvalidCommand`` error result and keeps the existing subscription.
        """
        if self._unsub_callback is not None:
            # Explicit check rather than ``assert``, so the guard survives
            # ``python -O``. Subscribing twice would overwrite, and thereby
            # leak, the first subscription's unsubscribe callback.
            self._logger.warning("Listen command already called!")
            self._send_message(
                ErrorResultMessage(
                    msg.message_id,
                    InvalidCommand.error_code,
                    "Listen command already called!",
                )
            )
            return

        all_nodes = self.server.device_controller.get_nodes()
        self._send_message(SuccessResultMessage(msg.message_id, all_nodes))

        def handle_event(evt: EventType, data: Any) -> None:
            # Forward every server event to this client.
            self._send_message(EventMessage(event=evt, data=data))

        self._unsub_callback = self.server.subscribe(handle_event)

    async def _run_handler(
        self, handler: APICommandHandler, msg: CommandMessage
    ) -> None:
        """Run a command handler and send its result (or an error) to the client.

        Argument-parsing failures surface as ``InvalidArguments``.
        ``ChipStackError`` maps to ``SDKStackError.error_code``. Any other
        exception uses its ``error_code`` attribute, falling back to
        ``MatterError.error_code``.
        """
        try:
            try:
                args = parse_arguments(handler.signature, handler.type_hints, msg.args)
            except (TypeError, KeyError, ValueError) as err:
                raise InvalidArguments() from err

            result = handler.target(**args)
            # Handlers may be plain functions or coroutine functions.
            if asyncio.iscoroutine(result):
                result = await result
            self._send_message(SuccessResultMessage(msg.message_id, result))
        except ChipStackError as err:
            self._logger.exception("SDK Error during handling message: %s", msg)
            self._send_message(
                ErrorResultMessage(msg.message_id, SDKStackError.error_code, str(err))
            )
        except Exception as err:  # pylint: disable=broad-except
            self._logger.exception("Error handling message: %s", msg)
            error_code = getattr(err, "error_code", MatterError.error_code)
            self._send_message(ErrorResultMessage(msg.message_id, error_code, str(err)))

    async def _writer(self) -> None:
        """Write outgoing messages.

        Drains the outgoing queue into the websocket. It stops on the ``None``
        sentinel or once the socket is closed. Queued callables are invoked
        to produce the string to send.
        """
        # Exceptions if Socket disconnected or cancelled by connection handler
        with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
            while not self.wsock.closed:
                if (process := await self._to_write.get()) is None:
                    break

                if not isinstance(process, str):
                    message: str = process()
                else:
                    message = process
                await self.wsock.send_str(message)

    def _send_message(self, message: MessageType) -> None:
        """
        Send a message to the client.

        Serializes ``message`` to JSON and enqueues it for the writer task.
        Closes connection if the client is not reading the messages, i.e. the
        outgoing queue already holds MAX_PENDING_MSG items.
        Async friendly.
        """
        _message = json_dumps(message)

        try:
            self._to_write.put_nowait(_message)
        except asyncio.QueueFull:
            self._logger.error(
                "Client exceeded max pending messages: %s", MAX_PENDING_MSG
            )

            self._cancel()

    def _cancel(self) -> None:
        """Cancel the connection by cancelling the read-loop and writer tasks."""
        if self._handle_task is not None:
            self._handle_task.cancel()
        if self._writer_task is not None:
            self._writer_task.cancel()