-
-
Notifications
You must be signed in to change notification settings - Fork 392
/
Copy pathwshandshake.cpp
258 lines (196 loc) · 6.91 KB
/
wshandshake.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
/**
* Copyright (c) 2020-2021 Paul-Louis Ageneau
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at https://mozilla.org/MPL/2.0/.
*/
#include "wshandshake.hpp"
#include "http.hpp"
#include "internals.hpp"
#include "sha.hpp"
#include "utils.hpp"
#if RTC_ENABLE_WEBSOCKET
#include <algorithm>
#include <cctype>
#include <chrono>
#include <climits>
#include <iostream>
#include <iterator>
#include <random>
#include <sstream>
#include <stdexcept>
using std::string;
namespace rtc::impl {
using std::to_string;
using std::chrono::system_clock;
// Default construction: members keep their default initializers. On this path the handshake
// state (path, host, key, protocols) is later filled in by parseHttpRequest().
WsHandshake::WsHandshake() = default;
// Builds a handshake for an outgoing request.
// host/path: values sent in the Host header and the request line by generateHttpRequest().
// protocols: optional subprotocols advertised in Sec-WebSocket-Protocol.
// Throws std::invalid_argument if host or path is empty (host is checked first).
WsHandshake::WsHandshake(string host, string path, std::vector<string> protocols)
    : mHost(std::move(host)), mPath(std::move(path)), mProtocols(std::move(protocols)) {
	const char *missing = nullptr;
	if (mHost.empty())
		missing = "WebSocket HTTP host cannot be empty";
	else if (mPath.empty())
		missing = "WebSocket HTTP path cannot be empty";

	if (missing)
		throw std::invalid_argument(missing);
}
// Returns a copy of the host, taken under mMutex because parseHttpRequest() may reassign it.
string WsHandshake::host() const {
	std::lock_guard guard(mMutex);
	string result = mHost;
	return result;
}
// Returns a copy of the request path, taken under mMutex because parseHttpRequest() may
// reassign it.
string WsHandshake::path() const {
	std::lock_guard guard(mMutex);
	string result = mPath;
	return result;
}
// Returns a copy of the subprotocol list, taken under mMutex because parseHttpRequest() may
// reassign it.
std::vector<string> WsHandshake::protocols() const {
	std::lock_guard guard(mMutex);
	std::vector<string> result = mProtocols;
	return result;
}
// Builds the client opening handshake (HTTP/1.1 GET with Upgrade: websocket).
// A fresh Sec-WebSocket-Key is generated and stored in mKey on every call, so that
// parseHttpResponse() can later verify the server's Sec-WebSocket-Accept against it.
string WsHandshake::generateHttpRequest() {
	std::lock_guard guard(mMutex);
	mKey = generateKey();

	string request;
	request += "GET " + mPath + " HTTP/1.1\r\n";
	request += "Host: " + mHost + "\r\n";
	request += "Connection: Upgrade\r\n";
	request += "Upgrade: websocket\r\n";
	request += "Sec-WebSocket-Version: 13\r\n";
	request += "Sec-WebSocket-Key: " + mKey + "\r\n";
	if (!mProtocols.empty())
		request += "Sec-WebSocket-Protocol: " + utils::implode(mProtocols, ',') + "\r\n";

	// Blank line terminates the header block
	request += "\r\n";
	return request;
}
// Builds the server's "101 Switching Protocols" reply to a request previously parsed by
// parseHttpRequest(). Sec-WebSocket-Accept is derived from the client's key stored in mKey.
string WsHandshake::generateHttpResponse() {
	std::unique_lock lock(mMutex);
	string out = "HTTP/1.1 101 Switching Protocols\r\n"
	             "Server: libdatachannel\r\n"
	             "Connection: Upgrade\r\n"
	             "Upgrade: websocket\r\n"
	             "Sec-WebSocket-Accept: " +
	             computeAcceptKey(mKey) + "\r\n";

	// RFC 6455 4.2.2: the server selects at most ONE of the subprotocols offered by the client
	// and echoes that single value. Echoing the whole comma-separated list makes compliant
	// clients (e.g. browsers) fail the connection. No selection policy is exposed here, so pick
	// the client's first entry; the client lists them in order of preference (RFC 6455 4.1).
	if (!mProtocols.empty())
		out += "Sec-WebSocket-Protocol: " + mProtocols.front() + "\r\n";

	out += "\r\n";
	return out;
}
namespace {

// Returns the reason phrase for the HTTP error codes this handshake emits.
// Any code not listed in the table falls back to the generic phrase "Error".
string GetHttpErrorName(int responseCode) {
	struct Reason {
		int code;
		const char *phrase;
	};
	static constexpr Reason kReasons[] = {
	    {400, "Bad Request"},
	    {404, "Not Found"},
	    {405, "Method Not Allowed"},
	    {426, "Upgrade Required"},
	    {500, "Internal Server Error"},
	};

	for (const auto &reason : kReasons)
		if (reason.code == responseCode)
			return reason.phrase;

	return "Error";
}

} // namespace
// Builds a plain-text HTTP error reply for a rejected handshake.
// The status text (e.g. "404 Not Found") is used both in the status line and as the
// body, and Content-Length is set to its size.
string WsHandshake::generateHttpError(int responseCode) {
	std::lock_guard guard(mMutex);
	const string status = to_string(responseCode) + " " + GetHttpErrorName(responseCode);

	string response = "HTTP/1.1 " + status + "\r\n";
	response += "Server: libdatachannel\r\n";
	response += "Content-Type: text/plain\r\n";
	response += "Content-Length: " + to_string(status.size()) + "\r\n";
	response += "Access-Control-Allow-Origin: *\r\n";
	response += "\r\n";
	response += status;
	return response;
}
size_t WsHandshake::parseHttpRequest(const byte *buffer, size_t size) {
if (!isHttpRequest(buffer, size))
throw RequestError("Invalid HTTP request for WebSocket", 400);
std::unique_lock lock(mMutex);
std::list<string> lines;
size_t length = parseHttpLines(buffer, size, lines);
if (length == 0)
return 0;
if (lines.empty())
throw RequestError("Invalid HTTP request for WebSocket", 400);
std::istringstream requestLine(std::move(lines.front()));
lines.pop_front();
string method, path, protocol;
requestLine >> method >> path >> protocol;
PLOG_DEBUG << "WebSocket request method=\"" << method << "\", path=\"" << path << "\"";
if (method != "GET")
throw RequestError("Invalid request method \"" + method + "\" for WebSocket", 405);
mPath = std::move(path);
auto headers = parseHttpHeaders(lines);
auto h = headers.find("host");
if (h == headers.end())
throw RequestError("WebSocket host header missing in request", 400);
mHost = std::move(h->second);
h = headers.find("upgrade");
if (h == headers.end())
throw RequestError("WebSocket upgrade header missing in request", 426);
string upgrade;
std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade),
[](char c) { return std::tolower(c); });
if (upgrade != "websocket")
throw RequestError("WebSocket upgrade header mismatching", 426);
h = headers.find("sec-websocket-key");
if (h == headers.end())
throw RequestError("WebSocket key header missing in request", 400);
mKey = std::move(h->second);
h = headers.find("sec-websocket-protocol");
if (h != headers.end())
mProtocols = utils::explode(h->second, ',');
return length;
}
size_t WsHandshake::parseHttpResponse(const byte *buffer, size_t size) {
std::unique_lock lock(mMutex);
std::list<string> lines;
size_t length = parseHttpLines(buffer, size, lines);
if (length == 0)
return 0;
if (lines.empty())
throw Error("Invalid HTTP response for WebSocket");
std::istringstream status(std::move(lines.front()));
lines.pop_front();
string protocol;
unsigned int code = 0;
status >> protocol >> code;
PLOG_DEBUG << "WebSocket response code=" << code;
if (code != 101)
throw std::runtime_error("Unexpected response code " + to_string(code) + " for WebSocket");
auto headers = parseHttpHeaders(lines);
auto h = headers.find("upgrade");
if (h == headers.end())
throw Error("WebSocket update header missing");
string upgrade;
std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade),
[](char c) { return std::tolower(c); });
if (upgrade != "websocket")
throw Error("WebSocket update header mismatching");
h = headers.find("sec-websocket-accept");
if (h == headers.end())
throw Error("WebSocket accept header missing");
if (h->second != computeAcceptKey(mKey))
throw Error("WebSocket accept header is invalid");
return length;
}
// Generates a fresh Sec-WebSocket-Key: 16 random bytes, base64-encoded.
// RFC 6455: The request MUST include a header field with the name Sec-WebSocket-Key. The value
// of this header field MUST be a nonce consisting of a randomly selected 16-byte value that has
// been base64-encoded. [...] The nonce MUST be selected randomly for each connection.
string WsHandshake::generateKey() {
	constexpr size_t nonceSize = 16;
	binary nonce(nonceSize);

	// Fill the buffer byte by byte from the project's random byte generator
	auto *first = reinterpret_cast<uint8_t *>(nonce.data());
	auto *last = first + nonce.size();
	std::generate(first, last, utils::random_bytes_engine());

	return utils::base64_encode(nonce);
}
// Computes the Sec-WebSocket-Accept value for a given Sec-WebSocket-Key:
// base64(SHA-1(key + GUID)), where the GUID is the fixed value from RFC 6455.
// The concatenation already yields a new string, so no extra copy of key is made.
string WsHandshake::computeAcceptKey(const string &key) {
	static constexpr const char kWebSocketGuid[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
	return utils::base64_encode(Sha1(key + kWebSocketGuid));
}
WsHandshake::Error::Error(const string &w) : std::runtime_error(w) {}
// Request-parsing failure that also records the HTTP status code to answer with
// (see parseHttpRequest for the codes used).
WsHandshake::RequestError::RequestError(const string &message, int code)
    : Error(message), mResponseCode(code) {}
// HTTP status code to send back for this request error.
// (The original name was over-qualified as RequestError::RequestError:: via the
// injected-class-name; the plain qualification names the same member.)
int WsHandshake::RequestError::responseCode() const { return mResponseCode; }
} // namespace rtc::impl
#endif