Skip to content

Commit 12e00da

Browse files
Merge branch 'v0.18'
2 parents b716262 + c20732b commit 12e00da

File tree

4 files changed

+87
-41
lines changed

4 files changed

+87
-41
lines changed

src/impl/dtlstransport.cpp

+45-17
Original file line numberDiff line numberDiff line change
@@ -819,13 +819,18 @@ void DtlsTransport::start() {
819819
registerIncoming();
820820
changeState(State::Connecting);
821821

822-
size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
823-
SSL_set_mtu(mSsl, static_cast<unsigned int>(mtu));
824-
PLOG_VERBOSE << "DTLS MTU set to " << mtu;
822+
{
823+
std::lock_guard lock(mSslMutex);
825824

826-
// Initiate the handshake
827-
int ret = SSL_do_handshake(mSsl);
828-
openssl::check(mSsl, ret, "Handshake initiation failed");
825+
size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
826+
SSL_set_mtu(mSsl, static_cast<unsigned int>(mtu));
827+
PLOG_VERBOSE << "DTLS MTU set to " << mtu;
828+
829+
// Initiate the handshake
830+
int ret = SSL_do_handshake(mSsl);
831+
832+
openssl::check(mSsl, ret, "Handshake initiation failed");
833+
}
829834

830835
handleTimeout();
831836
}
@@ -843,8 +848,10 @@ bool DtlsTransport::send(message_ptr message) {
843848

844849
PLOG_VERBOSE << "Send size=" << message->size();
845850

851+
std::lock_guard lock(mSslMutex);
846852
mCurrentDscp = message->dscp;
847853
int ret = SSL_write(mSsl, message->data(), int(message->size()));
854+
848855
if (!openssl::check(mSsl, ret))
849856
return false;
850857

@@ -910,23 +917,39 @@ void DtlsTransport::doRecv() {
910917

911918
if (state() == State::Connecting) {
912919
// Continue the handshake
913-
int ret = SSL_do_handshake(mSsl);
914-
if (!openssl::check(mSsl, ret, "Handshake failed"))
915-
break;
920+
bool finished;
921+
{
922+
std::lock_guard lock(mSslMutex);
923+
int ret = SSL_do_handshake(mSsl);
916924

917-
if (SSL_is_init_finished(mSsl)) {
925+
if (!openssl::check(mSsl, ret, "Handshake failed"))
926+
break;
927+
928+
finished = (SSL_is_init_finished(mSsl) != 0);
929+
}
930+
if (finished) {
918931
// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
919932
// See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
920-
SSL_set_mtu(mSsl, bufferSize + 1);
933+
{
934+
std::lock_guard lock(mSslMutex);
935+
SSL_set_mtu(mSsl, bufferSize + 1);
936+
}
921937

922938
PLOG_INFO << "DTLS handshake finished";
923939
postHandshake();
924940
changeState(State::Connected);
925941
}
926-
} else {
927-
int ret = SSL_read(mSsl, buffer, bufferSize);
928-
if (!openssl::check(mSsl, ret))
929-
break;
942+
}
943+
944+
if (state() == State::Connected) {
945+
int ret;
946+
{
947+
std::lock_guard lock(mSslMutex);
948+
ret = SSL_read(mSsl, buffer, bufferSize);
949+
950+
if (!openssl::check(mSsl, ret))
951+
break;
952+
}
930953

931954
if (ret > 0)
932955
recv(make_message(buffer, buffer + ret));
@@ -937,8 +960,6 @@ void DtlsTransport::doRecv() {
937960
PLOG_ERROR << "DTLS recv: " << e.what();
938961
}
939962

940-
SSL_shutdown(mSsl);
941-
942963
if (state() == State::Connected) {
943964
PLOG_INFO << "DTLS closed";
944965
changeState(State::Disconnected);
@@ -947,9 +968,16 @@ void DtlsTransport::doRecv() {
947968
PLOG_ERROR << "DTLS handshake failed";
948969
changeState(State::Failed);
949970
}
971+
972+
{
973+
std::lock_guard lock(mSslMutex);
974+
SSL_shutdown(mSsl);
975+
}
950976
}
951977

952978
void DtlsTransport::handleTimeout() {
979+
std::lock_guard lock(mSslMutex);
980+
953981
// Warning: This function breaks the usual return value convention
954982
int ret = DTLSv1_handle_timeout(mSsl);
955983
if (ret < 0) {

src/impl/dtlstransport.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class DtlsTransport : public Transport, public std::enable_shared_from_this<Dtls
9999
SSL_CTX *mCtx = NULL;
100100
SSL *mSsl = NULL;
101101
BIO *mInBio, *mOutBio;
102+
std::mutex mSslMutex;
102103

103104
void handleTimeout();
104105

src/impl/tlstransport.cpp

+40-16
Original file line numberDiff line numberDiff line change
@@ -639,9 +639,9 @@ void TlsTransport::start() {
639639
changeState(State::Connecting);
640640

641641
// Initiate the handshake
642+
std::lock_guard lock(mSslMutex);
642643
int ret = SSL_do_handshake(mSsl);
643644
openssl::check(mSsl, ret, "Handshake initiation failed");
644-
645645
flushOutput();
646646
}
647647

@@ -661,6 +661,7 @@ bool TlsTransport::send(message_ptr message) {
661661

662662
PLOG_VERBOSE << "Send size=" << message->size();
663663

664+
std::lock_guard lock(mSslMutex);
664665
int ret = SSL_write(mSsl, message->data(), int(message->size()));
665666
if (!openssl::check(mSsl, ret))
666667
throw std::runtime_error("TLS send failed");
@@ -711,13 +712,18 @@ void TlsTransport::doRecv() {
711712

712713
if (state() == State::Connecting) {
713714
// Continue the handshake
714-
int ret = SSL_do_handshake(mSsl);
715-
if (!openssl::check(mSsl, ret, "Handshake failed"))
716-
break;
715+
bool finished;
716+
{
717+
std::lock_guard lock(mSslMutex);
718+
int ret = SSL_do_handshake(mSsl);
719+
if (!openssl::check(mSsl, ret, "Handshake failed"))
720+
break;
717721

718-
flushOutput();
722+
flushOutput();
723+
finished = (SSL_is_init_finished(mSsl) != 0);
724+
}
719725

720-
if (SSL_is_init_finished(mSsl)) {
726+
if (finished) {
721727
PLOG_INFO << "TLS handshake finished";
722728
changeState(State::Connected);
723729
postHandshake();
@@ -726,20 +732,32 @@ void TlsTransport::doRecv() {
726732

727733
if (state() == State::Connected) {
728734
int ret;
729-
while ((ret = SSL_read(mSsl, buffer, bufferSize)) > 0)
730-
recv(make_message(buffer, buffer + ret));
735+
while (true) {
736+
{
737+
std::lock_guard lock(mSslMutex);
738+
ret = SSL_read(mSsl, buffer, bufferSize);
739+
}
731740

732-
if (!openssl::check(mSsl, ret))
733-
break;
741+
if (ret > 0)
742+
recv(make_message(buffer, buffer + ret));
743+
else
744+
break;
745+
}
746+
747+
{
748+
std::lock_guard lock(mSslMutex);
749+
if (!openssl::check(mSsl, ret))
750+
break;
751+
752+
flushOutput(); // SSL_read() can also cause write operations
753+
}
734754
}
735755
}
736756

737757
} catch (const std::exception &e) {
738758
PLOG_ERROR << "TLS recv: " << e.what();
739759
}
740760

741-
SSL_shutdown(mSsl);
742-
743761
if (state() == State::Connected) {
744762
PLOG_INFO << "TLS closed";
745763
changeState(State::Disconnected);
@@ -748,15 +766,21 @@ void TlsTransport::doRecv() {
748766
PLOG_ERROR << "TLS handshake failed";
749767
changeState(State::Failed);
750768
}
769+
770+
{
771+
std::lock_guard lock(mSslMutex);
772+
SSL_shutdown(mSsl);
773+
}
751774
}
752775

753776
bool TlsTransport::flushOutput() {
777+
// Requires mSslMutex to be locked
778+
bool result = true;
754779
const size_t bufferSize = 4096;
755780
byte buffer[bufferSize];
756-
int ret;
757-
bool result = true;
758-
while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
759-
result = outgoing(make_message(buffer, buffer + ret));
781+
int len;
782+
while ((len = BIO_read(mOutBio, buffer, bufferSize)) > 0)
783+
result = outgoing(make_message(buffer, buffer + len));
760784

761785
return result;
762786
}

src/impl/tlstransport.hpp

+1-8
Original file line numberDiff line numberDiff line change
@@ -85,20 +85,13 @@ class TlsTransport : public Transport, public std::enable_shared_from_this<TlsTr
8585
SSL_CTX *mCtx;
8686
SSL *mSsl;
8787
BIO *mInBio, *mOutBio;
88+
std::mutex mSslMutex;
8889

8990
bool flushOutput();
9091

91-
static BIO_METHOD *BioMethods;
9292
static int TransportExIndex;
93-
static std::mutex GlobalMutex;
9493

95-
static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx);
9694
static void InfoCallback(const SSL *ssl, int where, int ret);
97-
98-
static int BioMethodNew(BIO *bio);
99-
static int BioMethodFree(BIO *bio);
100-
static int BioMethodWrite(BIO *bio, const char *in, int inl);
101-
static long BioMethodCtrl(BIO *bio, int cmd, long num, void *ptr);
10295
#endif
10396
};
10497

0 commit comments

Comments
 (0)