Skip to content

Commit 9c6c543

Browse files
authored
Add ice servers to async WebRTC offer (#4)
* Add ice servers to async WebRTC offer * Update to reflect upstream changes * Fix code and test * Guard against unexpected messages * Fix typing * Increase coverage
1 parent 140d2cb commit 9c6c543

File tree

6 files changed

+141
-31
lines changed

6 files changed

+141
-31
lines changed

.pre-commit-config.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ repos:
77
name: 🐶 Ruff lint
88
args:
99
- --fix
10+
# - --unsafe-fixes
1011

1112
- id: ruff-format
1213
name: 🐶 Ruff format

go2rtc_client/ws/client.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from go2rtc_client.exceptions import handle_error
1212

13-
from .messages import BaseMessage
13+
from .messages import BaseMessage, ReceiveMessages, SendMessages, WebRTC, WsMessage
1414

1515
_LOGGER = logging.getLogger(__name__)
1616

@@ -43,7 +43,7 @@ def __init__(
4343
self._params = params
4444
self._client: ClientWebSocketResponse | None = None
4545
self._rx_task: asyncio.Task[None] | None = None
46-
self._subscribers: list[Callable[[BaseMessage], None]] = []
46+
self._subscribers: list[Callable[[ReceiveMessages], None]] = []
4747
self._connect_lock = asyncio.Lock()
4848

4949
@property
@@ -77,7 +77,7 @@ async def close(self) -> None:
7777
await client.close()
7878

7979
@handle_error
80-
async def send(self, message: BaseMessage) -> None:
80+
async def send(self, message: SendMessages) -> None:
8181
"""Send a message."""
8282
if not self.connected:
8383
await self.connect()
@@ -90,10 +90,15 @@ async def send(self, message: BaseMessage) -> None:
9090
def _process_text_message(self, data: Any) -> None:
9191
"""Process text message."""
9292
try:
93-
message = BaseMessage.from_json(data)
93+
message: WsMessage = BaseMessage.from_json(data)
9494
except Exception: # pylint: disable=broad-except
9595
_LOGGER.exception("Invalid message received: %s", data)
9696
else:
97+
if isinstance(message, WebRTC):
98+
message = message.value
99+
if not isinstance(message, ReceiveMessages):
100+
_LOGGER.error("Received unexpected message: %s", message)
101+
return
97102
for subscriber in self._subscribers:
98103
try:
99104
subscriber(message)
@@ -134,7 +139,9 @@ async def _receive_messages(self) -> None:
134139
if self.connected:
135140
await self.close()
136141

137-
def subscribe(self, callback: Callable[[BaseMessage], None]) -> Callable[[], None]:
142+
def subscribe(
143+
self, callback: Callable[[ReceiveMessages], None]
144+
) -> Callable[[], None]:
138145
"""Subscribe to messages."""
139146

140147
def _unsubscribe() -> None:

go2rtc_client/ws/messages.py

+59-15
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,34 @@
33
from __future__ import annotations
44

55
from dataclasses import dataclass, field
6-
from typing import Any, ClassVar
6+
from typing import Annotated, Any, ClassVar
77

88
from mashumaro import field_options
99
from mashumaro.config import BaseConfig
1010
from mashumaro.mixins.orjson import DataClassORJSONMixin
1111
from mashumaro.types import Discriminator
12+
from webrtc_models import (
13+
RTCIceServer, # noqa: TCH002 # Mashumaro needs the import to generate the correct code
14+
)
1215

1316

1417
@dataclass(frozen=True)
15-
class BaseMessage(DataClassORJSONMixin):
16-
"""Base message class."""
18+
class WsMessage:
19+
"""Websocket message."""
1720

1821
TYPE: ClassVar[str]
1922

23+
def __post_serialize__(self, d: dict[Any, Any]) -> dict[Any, Any]:
24+
"""Add type to serialized dict."""
25+
# ClassVar will not serialize by default
26+
d["type"] = self.TYPE
27+
return d
28+
29+
30+
@dataclass(frozen=True)
31+
class BaseMessage(WsMessage, DataClassORJSONMixin):
32+
"""Base message class."""
33+
2034
class Config(BaseConfig):
2135
"""Config for BaseMessage."""
2236

@@ -27,12 +41,6 @@ class Config(BaseConfig):
2741
variant_tagger_fn=lambda cls: cls.TYPE,
2842
)
2943

30-
def __post_serialize__(self, d: dict[Any, Any]) -> dict[Any, Any]:
31-
"""Add type to serialized dict."""
32-
# ClassVar will not serialize by default
33-
d["type"] = self.TYPE
34-
return d
35-
3644

3745
@dataclass(frozen=True)
3846
class WebRTCCandidate(BaseMessage):
@@ -43,19 +51,55 @@ class WebRTCCandidate(BaseMessage):
4351

4452

4553
@dataclass(frozen=True)
46-
class WebRTCOffer(BaseMessage):
54+
class WebRTC(BaseMessage):
55+
"""WebRTC message."""
56+
57+
TYPE = "webrtc"
58+
value: Annotated[
59+
WebRTCOffer | WebRTCValue,
60+
Discriminator(
61+
field="type",
62+
include_subtypes=True,
63+
variant_tagger_fn=lambda cls: cls.TYPE,
64+
),
65+
]
66+
67+
68+
@dataclass(frozen=True)
69+
class WebRTCValue(WsMessage):
70+
"""WebRTC value for WebRTC message."""
71+
72+
sdp: str
73+
74+
75+
@dataclass(frozen=True)
76+
class WebRTCOffer(WebRTCValue):
4777
"""WebRTC offer message."""
4878

49-
TYPE = "webrtc/offer"
50-
offer: str = field(metadata=field_options(alias="value"))
79+
TYPE = "offer"
80+
ice_servers: list[RTCIceServer]
81+
82+
def __pre_serialize__(self) -> WebRTCOffer:
83+
"""Pre serialize.
84+
85+
Go2rtc supports only ice_servers with urls as list of strings.
86+
"""
87+
for server in self.ice_servers:
88+
if isinstance(server.urls, str):
89+
server.urls = [server.urls]
90+
91+
return self
92+
93+
def to_json(self, **kwargs: Any) -> str:
94+
"""Convert to json."""
95+
return WebRTC(self).to_json(**kwargs)
5196

5297

5398
@dataclass(frozen=True)
54-
class WebRTCAnswer(BaseMessage):
99+
class WebRTCAnswer(WebRTCValue):
55100
"""WebRTC answer message."""
56101

57-
TYPE = "webrtc/answer"
58-
answer: str = field(metadata=field_options(alias="value"))
102+
TYPE = "answer"
59103

60104

61105
@dataclass(frozen=True)

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dependencies = [
2222
"awesomeversion>=24.6.0",
2323
"mashumaro~=3.13",
2424
"orjson>=3.10.7",
25+
"webrtc-models>=0.1.0",
2526
]
2627
version = "0.0.0"
2728

tests/ws/test_client.py

+53-11
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,18 @@
1717
from aiohttp.web import WebSocketResponse
1818
from multidict import CIMultiDict, CIMultiDictProxy
1919
import pytest
20+
from webrtc_models import RTCIceServer
2021
from yarl import URL
2122

2223
from go2rtc_client.exceptions import Go2RtcClientError
23-
from go2rtc_client.ws.client import Go2RtcWsClient
24-
from go2rtc_client.ws.messages import BaseMessage, WebRTCAnswer, WebRTCCandidate
24+
from go2rtc_client.ws import (
25+
Go2RtcWsClient,
26+
ReceiveMessages,
27+
SendMessages,
28+
WebRTCAnswer,
29+
WebRTCCandidate,
30+
WebRTCOffer,
31+
)
2532

2633

2734
class TestServer:
@@ -111,7 +118,27 @@ async def test_connect_parallel(server: TestServer) -> None:
111118
assert client.connected
112119

113120

114-
async def test_send(ws_client: Go2RtcWsClient, server: TestServer) -> None:
121+
@pytest.mark.parametrize(
122+
("message", "expected"),
123+
[
124+
(WebRTCCandidate("test"), '{"value":"test","type":"webrtc/candidate"}'),
125+
(
126+
WebRTCOffer("test", []),
127+
'{"value":{"sdp":"test","ice_servers":[],"type":"offer"},"type":"webrtc"}',
128+
),
129+
(
130+
WebRTCOffer("test", [RTCIceServer("url")]),
131+
'{"value":{"sdp":"test","ice_servers":[{"urls":["url"]}],"type":"offer"},"type":"webrtc"}',
132+
),
133+
(
134+
WebRTCOffer("test", [RTCIceServer(["url1", "url2"])]),
135+
'{"value":{"sdp":"test","ice_servers":[{"urls":["url1","url2"]}],"type":"offer"},"type":"webrtc"}',
136+
),
137+
],
138+
)
139+
async def test_send(
140+
ws_client: Go2RtcWsClient, server: TestServer, message: SendMessages, expected: str
141+
) -> None:
115142
"""Test sending a message through the WebSocket."""
116143
received_message = None
117144

@@ -121,28 +148,31 @@ def on_message(msg: WSMessage) -> None:
121148

122149
server.on_message = on_message
123150

124-
await ws_client.send(WebRTCCandidate("test"))
151+
await ws_client.send(message)
125152
await asyncio.sleep(0.1)
126-
assert received_message == '{"value":"test","type":"webrtc/candidate"}'
153+
assert received_message == expected
127154

128155

129156
@pytest.mark.parametrize(
130157
("message", "expected"),
131158
[
132159
('{"value":"test","type":"webrtc/candidate"}', WebRTCCandidate("test")),
133-
('{"value":"test","type":"webrtc/answer"}', WebRTCAnswer("test")),
160+
(
161+
'{"value":{"type":"answer", "sdp":"test"},"type":"webrtc"}',
162+
WebRTCAnswer("test"),
163+
),
134164
],
135165
)
136166
async def test_receive(
137167
ws_client_connected: Go2RtcWsClient,
138168
server: TestServer,
139169
message: str,
140-
expected: BaseMessage,
170+
expected: ReceiveMessages,
141171
) -> None:
142172
"""Test receiving a message through the WebSocket."""
143173
received_message = None
144174

145-
def on_message(message: BaseMessage) -> None:
175+
def on_message(message: ReceiveMessages) -> None:
146176
nonlocal received_message
147177
received_message = message
148178

@@ -230,7 +260,7 @@ async def test_subscribe_unsubscribe(ws_client: Go2RtcWsClient) -> None:
230260
# pylint: disable=protected-access
231261
assert ws_client._subscribers == []
232262

233-
def on_message(_: BaseMessage) -> None:
263+
def on_message(_: ReceiveMessages) -> None:
234264
pass
235265

236266
unsub = ws_client.subscribe(on_message)
@@ -249,14 +279,14 @@ async def test_subscriber_raised(
249279
) -> None:
250280
"""Test any exception raised by any subscriber will be handled."""
251281

252-
def on_message_raise(_: BaseMessage) -> None:
282+
def on_message_raise(_: ReceiveMessages) -> None:
253283
raise ValueError
254284

255285
ws_client_connected.subscribe(on_message_raise)
256286

257287
received_message = None
258288

259-
def on_message(message: BaseMessage) -> None:
289+
def on_message(message: ReceiveMessages) -> None:
260290
nonlocal received_message
261291
received_message = message
262292

@@ -294,6 +324,18 @@ def on_message(message: BaseMessage) -> None:
294324
WSMessage(WSMsgType.ERROR, "error", None),
295325
("go2rtc_client.ws.client", logging.ERROR, "Error received: error"),
296326
),
327+
(
328+
WSMessage(
329+
WSMsgType.TEXT,
330+
'{"value":{"sdp":"test","ice_servers":[],"type":"offer"},"type":"webrtc"}',
331+
None,
332+
),
333+
(
334+
"go2rtc_client.ws.client",
335+
logging.ERROR,
336+
"Received unexpected message: WebRTCOffer(sdp='test', ice_servers=[])",
337+
),
338+
),
297339
],
298340
)
299341
async def test_unexpected_messages(

uv.lock

+15
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)