@@ -43,19 +43,18 @@ using std::to_integer;
43
43
using std::to_string;
44
44
using std::chrono::system_clock;
45
45
46
- WsTransport::WsTransport (
47
- variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>, shared_ptr<TlsTransport>>
48
- lower,
49
- shared_ptr<WsHandshake> handshake, int maxOutstandingPings, message_callback recvCallback,
50
- state_callback stateCallback)
46
+ WsTransport::WsTransport (LowerTransport lower, shared_ptr<WsHandshake> handshake,
47
+ const WebSocketConfiguration &config, message_callback recvCallback,
48
+ state_callback stateCallback)
51
49
: Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
52
50
std::move (stateCallback)),
53
51
mHandshake (std::move(handshake)),
54
52
mIsClient (
55
53
std::visit (rtc::overloaded{[](auto l) { return l->isActive (); },
56
54
[](shared_ptr<TlsTransport> l) { return l->isClient (); }},
57
55
lower)),
58
- mMaxOutstandingPings(maxOutstandingPings) {
56
+ mMaxMessageSize(config.maxMessageSize.value_or(DEFAULT_MAX_MESSAGE_SIZE)),
57
+ mMaxOutstandingPings(config.maxOutstandingPings.value_or(0 )) {
59
58
60
59
onRecv (std::move (recvCallback));
61
60
@@ -75,7 +74,10 @@ void WsTransport::start() {
75
74
void WsTransport::stop () { close (); }
76
75
77
76
bool WsTransport::send (message_ptr message) {
78
- if (!message || state () != State::Connected)
77
+ if (state () != State::Connected)
78
+ throw std::runtime_error (" WebSocket is not open" );
79
+
80
+ if (!message)
79
81
return false ;
80
82
81
83
PLOG_VERBOSE << " Send size=" << message->size ();
@@ -146,10 +148,22 @@ void WsTransport::incoming(message_ptr message) {
146
148
sendFrame ({PING, reinterpret_cast <byte *>(&dummy), 4 , true , mIsClient });
147
149
addOutstandingPing ();
148
150
} else {
149
- Frame frame;
150
- while (size_t len = readFrame (mBuffer .data (), mBuffer .size (), frame)) {
151
- recvFrame (frame);
151
+ if (mIgnoreLength > 0 ) {
152
+ size_t len = std::min (mIgnoreLength , mBuffer .size ());
152
153
mBuffer .erase (mBuffer .begin (), mBuffer .begin () + len);
154
+ mIgnoreLength -= len;
155
+ }
156
+ if (mIgnoreLength == 0 ) {
157
+ Frame frame;
158
+ while (size_t len = parseFrame (mBuffer .data (), mBuffer .size (), frame)) {
159
+ recvFrame (frame);
160
+ if (len > mBuffer .size ()) {
161
+ mIgnoreLength = len - mBuffer .size ();
162
+ mBuffer .clear ();
163
+ break ;
164
+ }
165
+ mBuffer .erase (mBuffer .begin (), mBuffer .begin () + len);
166
+ }
153
167
}
154
168
}
155
169
}
@@ -229,7 +243,7 @@ bool WsTransport::sendHttpError(int code) {
229
243
// | Payload Data continued ... |
230
244
// +---------------------------------------------------------------+
231
245
232
- size_t WsTransport::readFrame (byte *buffer, size_t size, Frame &frame) {
246
+ size_t WsTransport::parseFrame (byte *buffer, size_t size, Frame &frame) {
233
247
const byte *end = buffer + size;
234
248
if (end - buffer < 2 )
235
249
return 0 ;
@@ -263,16 +277,25 @@ size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
263
277
cur += 4 ;
264
278
}
265
279
266
- if (size_t (end - cur) < frame.length )
280
+ const size_t maxControlFrameLength = 125 ;
281
+ const size_t maxFrameLength = std::max (maxControlFrameLength, mMaxMessageSize );
282
+ if (size_t (end - cur) < std::min (frame.length , maxFrameLength))
267
283
return 0 ;
268
284
285
+ size_t length = frame.length ;
286
+ if (frame.length > maxFrameLength) {
287
+ PLOG_WARNING << " WebSocket frame is too large (length=" << frame.length
288
+ << " ), truncating it" ;
289
+ frame.length = maxFrameLength;
290
+ }
291
+
269
292
frame.payload = cur;
293
+
270
294
if (maskingKey)
271
295
for (size_t i = 0 ; i < frame.length ; ++i)
272
296
frame.payload [i] ^= maskingKey[i % 4 ];
273
- cur += frame.length ;
274
297
275
- return size_t (cur - buffer);
298
+ return frame. payload + length - buffer; // can be more than buffer size
276
299
}
277
300
278
301
void WsTransport::recvFrame (const Frame &frame) {
@@ -282,32 +305,40 @@ void WsTransport::recvFrame(const Frame &frame) {
282
305
switch (frame.opcode ) {
283
306
case TEXT_FRAME:
284
307
case BINARY_FRAME: {
308
+ size_t size = frame.length ;
309
+ if (size > mMaxMessageSize ) {
310
+ PLOG_WARNING << " WebSocket message is too large, truncating it" ;
311
+ size = mMaxMessageSize ;
312
+ }
285
313
if (!mPartial .empty ()) {
286
314
PLOG_WARNING << " WebSocket unfinished message: type="
287
315
<< (mPartialOpcode == TEXT_FRAME ? " text" : " binary" )
288
- << " , length =" << mPartial .size ();
316
+ << " , size =" << mPartial .size ();
289
317
auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
290
318
recv (make_message (mPartial .begin (), mPartial .end (), type));
291
319
mPartial .clear ();
292
320
}
293
321
mPartialOpcode = frame.opcode ;
294
322
if (frame.fin ) {
295
323
PLOG_DEBUG << " WebSocket finished message: type="
296
- << (frame.opcode == TEXT_FRAME ? " text" : " binary" )
297
- << " , length=" << frame.length ;
324
+ << (frame.opcode == TEXT_FRAME ? " text" : " binary" ) << " , size=" << size;
298
325
auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary;
299
- recv (make_message (frame.payload , frame.payload + frame. length , type));
326
+ recv (make_message (frame.payload , frame.payload + size , type));
300
327
} else {
301
- mPartial .insert (mPartial .end (), frame.payload , frame.payload + frame. length );
328
+ mPartial .insert (mPartial .end (), frame.payload , frame.payload + size );
302
329
}
303
330
break ;
304
331
}
305
332
case CONTINUATION: {
306
333
mPartial .insert (mPartial .end (), frame.payload , frame.payload + frame.length );
334
+ if (mPartial .size () > mMaxMessageSize ) {
335
+ PLOG_WARNING << " WebSocket message is too large, truncating it" ;
336
+ mPartial .resize (mMaxMessageSize );
337
+ }
307
338
if (frame.fin ) {
308
339
PLOG_DEBUG << " WebSocket finished message: type="
309
340
<< (frame.opcode == TEXT_FRAME ? " text" : " binary" )
310
- << " , length =" << mPartial .size ();
341
+ << " , size =" << mPartial .size ();
311
342
auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
312
343
recv (make_message (mPartial .begin (), mPartial .end (), type));
313
344
mPartial .clear ();
0 commit comments