Skip to content

Commit e492a19

Browse files
Enforce WebSocket message size limit at reception
1 parent ed5cd14 commit e492a19

File tree

3 files changed

+62
-28
lines changed

3 files changed

+62
-28
lines changed

src/impl/websocket.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -443,8 +443,7 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
443443
}
444444
};
445445

446-
auto maxOutstandingPings = config.maxOutstandingPings.value_or(0);
447-
auto transport = std::make_shared<WsTransport>(lower, mWsHandshake, maxOutstandingPings,
446+
auto transport = std::make_shared<WsTransport>(lower, mWsHandshake, config,
448447
weak_bind(&WebSocket::incoming, this, _1),
449448
stateChangeCallback);
450449

src/impl/wstransport.cpp

+51-20
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,18 @@ using std::to_integer;
4343
using std::to_string;
4444
using std::chrono::system_clock;
4545

46-
WsTransport::WsTransport(
47-
variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>
48-
lower,
49-
shared_ptr<WsHandshake> handshake, int maxOutstandingPings, message_callback recvCallback,
50-
state_callback stateCallback)
46+
WsTransport::WsTransport(LowerTransport lower, shared_ptr<WsHandshake> handshake,
47+
const WebSocketConfiguration &config, message_callback recvCallback,
48+
state_callback stateCallback)
5149
: Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
5250
std::move(stateCallback)),
5351
mHandshake(std::move(handshake)),
5452
mIsClient(
5553
std::visit(rtc::overloaded{[](auto l) { return l->isActive(); },
5654
[](shared_ptr<TlsTransport> l) { return l->isClient(); }},
5755
lower)),
58-
mMaxOutstandingPings(maxOutstandingPings) {
56+
mMaxMessageSize(config.maxMessageSize.value_or(DEFAULT_MAX_MESSAGE_SIZE)),
57+
mMaxOutstandingPings(config.maxOutstandingPings.value_or(0)) {
5958

6059
onRecv(std::move(recvCallback));
6160

@@ -75,7 +74,10 @@ void WsTransport::start() {
7574
void WsTransport::stop() { close(); }
7675

7776
bool WsTransport::send(message_ptr message) {
78-
if (!message || state() != State::Connected)
77+
if (state() != State::Connected)
78+
throw std::runtime_error("WebSocket is not open");
79+
80+
if (!message)
7981
return false;
8082

8183
PLOG_VERBOSE << "Send size=" << message->size();
@@ -146,10 +148,22 @@ void WsTransport::incoming(message_ptr message) {
146148
sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, mIsClient});
147149
addOutstandingPing();
148150
} else {
149-
Frame frame;
150-
while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
151-
recvFrame(frame);
151+
if (mIgnoreLength > 0) {
152+
size_t len = std::min(mIgnoreLength, mBuffer.size());
152153
mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
154+
mIgnoreLength -= len;
155+
}
156+
if (mIgnoreLength == 0) {
157+
Frame frame;
158+
while (size_t len = parseFrame(mBuffer.data(), mBuffer.size(), frame)) {
159+
recvFrame(frame);
160+
if (len > mBuffer.size()) {
161+
mIgnoreLength = len - mBuffer.size();
162+
mBuffer.clear();
163+
break;
164+
}
165+
mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
166+
}
153167
}
154168
}
155169
}
@@ -229,7 +243,7 @@ bool WsTransport::sendHttpError(int code) {
229243
// | Payload Data continued ... |
230244
// +---------------------------------------------------------------+
231245

232-
size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
246+
size_t WsTransport::parseFrame(byte *buffer, size_t size, Frame &frame) {
233247
const byte *end = buffer + size;
234248
if (end - buffer < 2)
235249
return 0;
@@ -263,16 +277,25 @@ size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
263277
cur += 4;
264278
}
265279

266-
if (size_t(end - cur) < frame.length)
280+
const size_t maxControlFrameLength = 125;
281+
const size_t maxFrameLength = std::max(maxControlFrameLength, mMaxMessageSize);
282+
if (size_t(end - cur) < std::min(frame.length, maxFrameLength))
267283
return 0;
268284

285+
size_t length = frame.length;
286+
if (frame.length > maxFrameLength) {
287+
PLOG_WARNING << "WebSocket frame is too large (length=" << frame.length
288+
<< "), truncating it";
289+
frame.length = maxFrameLength;
290+
}
291+
269292
frame.payload = cur;
293+
270294
if (maskingKey)
271295
for (size_t i = 0; i < frame.length; ++i)
272296
frame.payload[i] ^= maskingKey[i % 4];
273-
cur += frame.length;
274297

275-
return size_t(cur - buffer);
298+
return frame.payload + length - buffer; // can be more than buffer size
276299
}
277300

278301
void WsTransport::recvFrame(const Frame &frame) {
@@ -282,32 +305,40 @@ void WsTransport::recvFrame(const Frame &frame) {
282305
switch (frame.opcode) {
283306
case TEXT_FRAME:
284307
case BINARY_FRAME: {
308+
size_t size = frame.length;
309+
if (size > mMaxMessageSize) {
310+
PLOG_WARNING << "WebSocket message is too large, truncating it";
311+
size = mMaxMessageSize;
312+
}
285313
if (!mPartial.empty()) {
286314
PLOG_WARNING << "WebSocket unfinished message: type="
287315
<< (mPartialOpcode == TEXT_FRAME ? "text" : "binary")
288-
<< ", length=" << mPartial.size();
316+
<< ", size=" << mPartial.size();
289317
auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
290318
recv(make_message(mPartial.begin(), mPartial.end(), type));
291319
mPartial.clear();
292320
}
293321
mPartialOpcode = frame.opcode;
294322
if (frame.fin) {
295323
PLOG_DEBUG << "WebSocket finished message: type="
296-
<< (frame.opcode == TEXT_FRAME ? "text" : "binary")
297-
<< ", length=" << frame.length;
324+
<< (frame.opcode == TEXT_FRAME ? "text" : "binary") << ", size=" << size;
298325
auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary;
299-
recv(make_message(frame.payload, frame.payload + frame.length, type));
326+
recv(make_message(frame.payload, frame.payload + size, type));
300327
} else {
301-
mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
328+
mPartial.insert(mPartial.end(), frame.payload, frame.payload + size);
302329
}
303330
break;
304331
}
305332
case CONTINUATION: {
306333
mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
334+
if (mPartial.size() > mMaxMessageSize) {
335+
PLOG_WARNING << "WebSocket message is too large, truncating it";
336+
mPartial.resize(mMaxMessageSize);
337+
}
307338
if (frame.fin) {
308339
PLOG_DEBUG << "WebSocket finished message: type="
309340
<< (frame.opcode == TEXT_FRAME ? "text" : "binary")
310-
<< ", length=" << mPartial.size();
341+
<< ", size=" << mPartial.size();
311342
auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
312343
recv(make_message(mPartial.begin(), mPartial.end(), type));
313344
mPartial.clear();

src/impl/wstransport.hpp

+10-6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "common.hpp"
1313
#include "transport.hpp"
14+
#include "configuration.hpp"
1415
#include "wshandshake.hpp"
1516

1617
#if RTC_ENABLE_WEBSOCKET
@@ -25,11 +26,12 @@ class TlsTransport;
2526

2627
class WsTransport final : public Transport, public std::enable_shared_from_this<WsTransport> {
2728
public:
28-
WsTransport(
29-
variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>
30-
lower,
31-
shared_ptr<WsHandshake> handshake, int maxOutstandingPings, message_callback recvCallback,
32-
state_callback stateCallback);
29+
using LowerTransport =
30+
variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>;
31+
32+
WsTransport(LowerTransport lower, shared_ptr<WsHandshake> handshake,
33+
const WebSocketConfiguration &config, message_callback recvCallback,
34+
state_callback stateCallback);
3335
~WsTransport();
3436

3537
void start() override;
@@ -62,19 +64,21 @@ class WsTransport final : public Transport, public std::enable_shared_from_this<
6264
bool sendHttpError(int code);
6365
bool sendHttpResponse();
6466

65-
size_t readFrame(byte *buffer, size_t size, Frame &frame);
67+
size_t parseFrame(byte *buffer, size_t size, Frame &frame);
6668
void recvFrame(const Frame &frame);
6769
bool sendFrame(const Frame &frame);
6870

6971
void addOutstandingPing();
7072

7173
const shared_ptr<WsHandshake> mHandshake;
7274
const bool mIsClient;
75+
const size_t mMaxMessageSize;
7376
const int mMaxOutstandingPings;
7477

7578
binary mBuffer;
7679
binary mPartial;
7780
Opcode mPartialOpcode;
81+
size_t mIgnoreLength = 0;
7882
std::mutex mSendMutex;
7983
int mOutstandingPings = 0;
8084
std::atomic<bool> mCloseSent = false;

0 commit comments

Comments
 (0)