Skip to content

Commit 2a69b60

Browse files
authored
Optimize gateway transport (#1898)
1 parent 1f50bb1 commit 2a69b60

File tree

4 files changed

+64
-91
lines changed

4 files changed

+64
-91
lines changed

changes/1898.feature.md

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Optimize gateway transport
2+
- Merge the cold path for zlib compression into the main path to avoid an additional call
3+
- Handle data as `bytes` rather than `str` to make good use of speedups (similar to `RESTClient`)

hikari/errors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ class GatewayTransportError(GatewayError):
182182
"""An exception thrown if an issue occurs at the transport layer."""
183183

184184
def __str__(self) -> str:
185-
return f"Gateway transport error: {self.reason!r}"
185+
return f"Gateway transport error: {self.reason}"
186186

187187

188188
@attrs.define(auto_exc=True, repr=False, slots=False)

hikari/impl/shard.py

+24-23
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,9 @@
117117
_CUSTOM_STATUS_NAME = "Custom Status"
118118

119119

120-
def _log_filterer(token: str) -> typing.Callable[[str], str]:
121-
def filterer(entry: str) -> str:
122-
return entry.replace(token, "**REDACTED TOKEN**")
120+
def _log_filterer(token: bytes) -> typing.Callable[[bytes], bytes]:
121+
def filterer(entry: bytes) -> bytes:
122+
return entry.replace(token, b"**REDACTED TOKEN**")
123123

124124
return filterer
125125

@@ -153,7 +153,7 @@ def __init__(
153153
transport_compression: bool,
154154
exit_stack: contextlib.AsyncExitStack,
155155
logger: logging.Logger,
156-
log_filterer: typing.Callable[[str], str],
156+
log_filterer: typing.Callable[[bytes], bytes],
157157
dumps: data_binding.JSONEncoder,
158158
loads: data_binding.JSONDecoder,
159159
) -> None:
@@ -203,7 +203,7 @@ async def receive_json(self) -> typing.Any:
203203
async def send_json(self, data: data_binding.JSONObject) -> None:
204204
pl = self._dumps(data)
205205
if self._logger.isEnabledFor(ux.TRACE):
206-
filtered = self._log_filterer(pl.decode("utf-8"))
206+
filtered = self._log_filterer(pl)
207207
self._logger.log(ux.TRACE, "sending payload with size %s\n %s", len(pl), filtered)
208208

209209
await self._ws.send_bytes(pl)
@@ -232,39 +232,40 @@ def _handle_other_message(self, message: aiohttp.WSMessage, /) -> typing.NoRetur
232232
reason = f"{message.data!r} [extra={message.extra!r}, type={message.type}]"
233233
raise errors.GatewayTransportError(reason) from self._ws.exception()
234234

235-
async def _receive_and_check_text(self) -> str:
235+
async def _receive_and_check_text(self) -> bytes:
236236
message = await self._ws.receive()
237237

238238
if message.type == aiohttp.WSMsgType.TEXT:
239239
assert isinstance(message.data, str)
240-
return message.data
240+
return message.data.encode()
241241

242242
self._handle_other_message(message)
243243

244-
async def _receive_and_check_zlib(self) -> str:
244+
async def _receive_and_check_zlib(self) -> bytes:
245245
message = await self._ws.receive()
246246

247247
if message.type == aiohttp.WSMsgType.BINARY:
248248
if message.data.endswith(_ZLIB_SUFFIX):
249-
return self._zlib.decompress(message.data).decode("utf-8")
250-
251-
return await self._receive_and_check_complete_zlib_package(message.data)
249+
# Hot and fast path: we already have the full message
250+
# in a single frame
251+
return self._zlib.decompress(message.data)
252252

253-
self._handle_other_message(message)
253+
# Cold and slow path: we need to keep receiving frames to complete
254+
# the whole message. Only then do we create a buffer
255+
buff = bytearray(message.data)
254256

255-
async def _receive_and_check_complete_zlib_package(self, initial_data: bytes, /) -> str:
256-
buff = bytearray(initial_data)
257+
while not buff.endswith(_ZLIB_SUFFIX):
258+
message = await self._ws.receive()
257259

258-
while not buff.endswith(_ZLIB_SUFFIX):
259-
message = await self._ws.receive()
260+
if message.type == aiohttp.WSMsgType.BINARY:
261+
buff.extend(message.data)
262+
continue
260263

261-
if message.type == aiohttp.WSMsgType.BINARY:
262-
buff.extend(message.data)
263-
continue
264+
self._handle_other_message(message)
264265

265-
self._handle_other_message(message)
266+
return self._zlib.decompress(buff)
266267

267-
return self._zlib.decompress(buff).decode("utf-8")
268+
self._handle_other_message(message)
268269

269270
@classmethod
270271
async def connect(
@@ -273,7 +274,7 @@ async def connect(
273274
http_settings: config.HTTPSettings,
274275
logger: logging.Logger,
275276
proxy_settings: config.ProxySettings,
276-
log_filterer: typing.Callable[[str], str],
277+
log_filterer: typing.Callable[[bytes], bytes],
277278
dumps: data_binding.JSONEncoder,
278279
loads: data_binding.JSONDecoder,
279280
transport_compression: bool,
@@ -810,7 +811,7 @@ async def _connect(self) -> typing.Tuple[asyncio.Task[None], ...]:
810811

811812
self._ws = await _GatewayTransport.connect(
812813
http_settings=self._http_settings,
813-
log_filterer=_log_filterer(self._token),
814+
log_filterer=_log_filterer(self._token.encode()),
814815
logger=self._logger,
815816
proxy_settings=self._proxy_settings,
816817
transport_compression=self._transport_compression,

tests/hikari/impl/test_shard.py

+36-67
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@
4444

4545

4646
def test_log_filterer():
47-
filterer = shard._log_filterer("TOKEN")
47+
filterer = shard._log_filterer(b"TOKEN")
4848

49-
returned = filterer("this log contains the TOKEN and it should get removed and the TOKEN here too")
49+
returned = filterer(b"this log contains the TOKEN and it should get removed and the TOKEN here too")
5050
assert returned == (
51-
"this log contains the **REDACTED TOKEN** and it should get removed and the **REDACTED TOKEN** here too"
51+
b"this log contains the **REDACTED TOKEN** and it should get removed and the **REDACTED TOKEN** here too"
5252
)
5353

5454

@@ -275,100 +275,69 @@ def test__handle_other_message_when_message_type_is_unknown(self, transport_impl
275275
assert exc_info.value.__cause__ is exception
276276

277277
@pytest.mark.asyncio
278-
async def test__receive_and_check_text_when_message_type_is_TEXT(self, transport_impl):
278+
async def test__receive_and_check_text(self, transport_impl):
279279
transport_impl._ws.receive = mock.AsyncMock(
280280
return_value=StubResponse(type=aiohttp.WSMsgType.TEXT, data="some text")
281281
)
282282

283-
assert await transport_impl._receive_and_check_text() == "some text"
283+
assert await transport_impl._receive_and_check_text() == b"some text"
284284

285285
transport_impl._ws.receive.assert_awaited_once_with()
286286

287287
@pytest.mark.asyncio
288288
async def test__receive_and_check_text_when_message_type_is_unknown(self, transport_impl):
289-
mock_exception = errors.GatewayError("aye")
290289
transport_impl._ws.receive = mock.AsyncMock(return_value=StubResponse(type=aiohttp.WSMsgType.BINARY))
291290

292-
with mock.patch.object(
293-
shard._GatewayTransport, "_handle_other_message", side_effect=mock_exception
294-
) as handle_other_message:
295-
with pytest.raises(errors.GatewayError) as exc_info:
296-
await transport_impl._receive_and_check_text()
291+
with pytest.raises(
292+
errors.GatewayTransportError,
293+
match="Gateway transport error: Unexpected message type received BINARY, expected TEXT",
294+
):
295+
await transport_impl._receive_and_check_text()
297296

298-
assert exc_info.value is mock_exception
299297
transport_impl._ws.receive.assert_awaited_once_with()
300-
handle_other_message.assert_called_once_with(transport_impl._ws.receive.return_value)
301298

302299
@pytest.mark.asyncio
303-
async def test__receive_and_check_zlib_when_message_type_is_BINARY(self, transport_impl):
304-
response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some initial data")
305-
transport_impl._ws.receive = mock.AsyncMock(return_value=response)
300+
async def test__receive_and_check_zlib_when_payload_split_across_frames(self, transport_impl):
301+
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xda\xf2H\xcd\xc9")
302+
response2 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\xc9W(\xcf/\xcaIQ\x04\x00\x00")
303+
response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\xff\xff")
304+
transport_impl._ws.receive = mock.AsyncMock(side_effect=[response1, response2, response3])
306305

307-
with mock.patch.object(
308-
shard._GatewayTransport, "_receive_and_check_complete_zlib_package"
309-
) as receive_and_check_complete_zlib_package:
310-
assert (
311-
await transport_impl._receive_and_check_zlib() is receive_and_check_complete_zlib_package.return_value
312-
)
306+
assert await transport_impl._receive_and_check_zlib() == b"Hello world!"
313307

314-
transport_impl._ws.receive.assert_awaited_once_with()
315-
receive_and_check_complete_zlib_package.assert_awaited_once_with(b"some initial data")
308+
assert transport_impl._ws.receive.call_count == 3
316309

317310
@pytest.mark.asyncio
318-
async def test__receive_and_check_zlib_when_message_type_is_BINARY_and_the_full_payload(self, transport_impl):
319-
response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some initial data\x00\x00\xff\xff")
311+
async def test__receive_and_check_zlib_when_full_payload_in_one_frame(self, transport_impl):
312+
response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xdaJLD\x07\x00\x00\x00\x00\xff\xff")
320313
transport_impl._ws.receive = mock.AsyncMock(return_value=response)
321-
transport_impl._zlib = mock.Mock(decompress=mock.Mock(return_value=b"aaaaaaaaaaaaaaaaaa"))
322314

323-
assert await transport_impl._receive_and_check_zlib() == "aaaaaaaaaaaaaaaaaa"
315+
assert await transport_impl._receive_and_check_zlib() == b"aaaaaaaaaaaaaaaaaa"
324316

325317
transport_impl._ws.receive.assert_awaited_once_with()
326-
transport_impl._zlib.decompress.assert_called_once_with(response.data)
327318

328319
@pytest.mark.asyncio
329320
async def test__receive_and_check_zlib_when_message_type_is_unknown(self, transport_impl):
330-
mock_exception = errors.GatewayError("aye")
331321
transport_impl._ws.receive = mock.AsyncMock(return_value=StubResponse(type=aiohttp.WSMsgType.TEXT))
332322

333-
with mock.patch.object(
334-
shard._GatewayTransport, "_handle_other_message", side_effect=mock_exception
335-
) as handle_other_message:
336-
with pytest.raises(errors.GatewayError) as exc_info:
337-
await transport_impl._receive_and_check_zlib()
338-
339-
assert exc_info.value is mock_exception
340-
transport_impl._ws.receive.assert_awaited_once_with()
341-
handle_other_message.assert_called_once_with(transport_impl._ws.receive.return_value)
342-
343-
@pytest.mark.asyncio
344-
async def test__receive_and_check_complete_zlib_package_for_unexpected_message_type(self, transport_impl):
345-
mock_exception = errors.GatewayError("aye")
346-
response = StubResponse(type=aiohttp.WSMsgType.TEXT)
347-
transport_impl._ws.receive = mock.AsyncMock(return_value=response)
348-
349-
with mock.patch.object(
350-
shard._GatewayTransport, "_handle_other_message", side_effect=mock_exception
351-
) as handle_other_message:
352-
with pytest.raises(errors.GatewayError) as exc_info:
353-
await transport_impl._receive_and_check_complete_zlib_package(b"some")
354-
355-
assert exc_info.value is mock_exception
356-
transport_impl._ws.receive.assert_awaited_with()
357-
handle_other_message.assert_called_once_with(response)
323+
with pytest.raises(
324+
errors.GatewayTransportError,
325+
match="Gateway transport error: Unexpected message type received TEXT, expected BINARY",
326+
):
327+
await transport_impl._receive_and_check_zlib()
358328

359329
@pytest.mark.asyncio
360-
async def test__receive_and_check_complete_zlib_package(self, transport_impl):
361-
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"more")
362-
response2 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"data")
363-
response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\x00\xff\xff")
330+
async def test__receive_and_check_zlib_when_issue_during_reception_of_multiple_frames(self, transport_impl):
331+
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xda\xf2H\xcd\xc9")
332+
response2 = StubResponse(type=aiohttp.WSMsgType.ERROR, data="Something broke!")
333+
response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\xff\xff")
364334
transport_impl._ws.receive = mock.AsyncMock(side_effect=[response1, response2, response3])
365-
transport_impl._zlib = mock.Mock(decompress=mock.Mock(return_value=b"decoded utf-8 encoded bytes"))
366-
367-
assert await transport_impl._receive_and_check_complete_zlib_package(b"some") == "decoded utf-8 encoded bytes"
335+
transport_impl._ws.exception = mock.Mock(return_value=None)
368336

369-
assert transport_impl._ws.receive.call_count == 3
370-
transport_impl._ws.receive.assert_has_awaits([mock.call(), mock.call(), mock.call()])
371-
transport_impl._zlib.decompress.assert_called_once_with(bytearray(b"somemoredata\x00\x00\xff\xff"))
337+
with pytest.raises(
338+
errors.GatewayTransportError, match=r"Gateway transport error: 'Something broke!' \[extra=None, type=258\]"
339+
):
340+
await transport_impl._receive_and_check_zlib()
372341

373342
@pytest.mark.parametrize("transport_compression", [True, False])
374343
@pytest.mark.asyncio
@@ -1002,7 +971,7 @@ async def test__connect_when_not_reconnecting(self, client, http_settings, proxy
1002971
with stack:
1003972
assert await client._connect() == (heartbeat_task, poll_events_task)
1004973

1005-
log_filterer.assert_called_once_with("sometoken")
974+
log_filterer.assert_called_once_with(b"sometoken")
1006975
gateway_transport_connect.assert_called_once_with(
1007976
http_settings=http_settings,
1008977
log_filterer=log_filterer.return_value,
@@ -1087,7 +1056,7 @@ async def test__connect_when_reconnecting(self, client, http_settings, proxy_set
10871056
with stack:
10881057
assert await client._connect() == (heartbeat_task, poll_events_task)
10891058

1090-
log_filterer.assert_called_once_with("sometoken")
1059+
log_filterer.assert_called_once_with(b"sometoken")
10911060
gateway_transport_connect.assert_called_once_with(
10921061
http_settings=http_settings,
10931062
log_filterer=log_filterer.return_value,

0 commit comments

Comments
 (0)