Skip to content

Commit e1ab84d

Browse files
Refactor OpenSSL logic
1 parent 33f42ae commit e1ab84d

File tree

4 files changed

+79
-66
lines changed

4 files changed

+79
-66
lines changed

src/impl/dtlstransport.cpp

+25-19
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,7 @@ void DtlsTransport::start() {
819819
registerIncoming();
820820
changeState(State::Connecting);
821821

822+
int ret, err;
822823
{
823824
std::lock_guard lock(mSslMutex);
824825

@@ -827,11 +828,12 @@ void DtlsTransport::start() {
827828
PLOG_VERBOSE << "DTLS MTU set to " << mtu;
828829

829830
// Initiate the handshake
830-
int ret = SSL_do_handshake(mSsl);
831-
832-
openssl::check(mSsl, ret, "Handshake initiation failed");
831+
ret = SSL_do_handshake(mSsl);
832+
err = SSL_get_error(mSsl, ret);
833833
}
834834

835+
openssl::check_error(err, "Handshake failed");
836+
835837
handleTimeout();
836838
}
837839

@@ -848,11 +850,15 @@ bool DtlsTransport::send(message_ptr message) {
848850

849851
PLOG_VERBOSE << "Send size=" << message->size();
850852

851-
std::lock_guard lock(mSslMutex);
852-
mCurrentDscp = message->dscp;
853-
int ret = SSL_write(mSsl, message->data(), int(message->size()));
853+
int ret, err;
854+
{
855+
std::lock_guard lock(mSslMutex);
856+
mCurrentDscp = message->dscp;
857+
ret = SSL_write(mSsl, message->data(), int(message->size()));
858+
err = SSL_get_error(mSsl, ret);
859+
}
854860

855-
if (!openssl::check(mSsl, ret))
861+
if (!openssl::check_error(err))
856862
return false;
857863

858864
return mOutgoingResult;
@@ -917,17 +923,14 @@ void DtlsTransport::doRecv() {
917923

918924
if (state() == State::Connecting) {
919925
// Continue the handshake
920-
bool finished;
926+
int ret, err;
921927
{
922928
std::lock_guard lock(mSslMutex);
923-
int ret = SSL_do_handshake(mSsl);
924-
925-
if (!openssl::check(mSsl, ret, "Handshake failed"))
926-
break;
927-
928-
finished = (SSL_is_init_finished(mSsl) != 0);
929+
ret = SSL_do_handshake(mSsl);
930+
err = SSL_get_error(mSsl, ret);
929931
}
930-
if (finished) {
932+
933+
if (openssl::check_error(err, "Handshake failed")) {
931934
// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
932935
// See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
933936
{
@@ -942,16 +945,19 @@ void DtlsTransport::doRecv() {
942945
}
943946

944947
if (state() == State::Connected) {
945-
int ret;
948+
int ret, err;
946949
{
947950
std::lock_guard lock(mSslMutex);
948951
ret = SSL_read(mSsl, buffer, bufferSize);
952+
err = SSL_get_error(mSsl, ret);
953+
}
949954

950-
if (!openssl::check(mSsl, ret))
951-
break;
955+
if (err == SSL_ERROR_ZERO_RETURN) {
956+
PLOG_DEBUG << "TLS connection cleanly closed";
957+
break;
952958
}
953959

954-
if (ret > 0)
960+
if (openssl::check_error(err))
955961
recv(make_message(buffer, buffer + ret));
956962
}
957963
}

src/impl/tls.cpp

+16-21
Original file line numberDiff line numberDiff line change
@@ -177,34 +177,29 @@ bool check(int success, const string &message) {
177177
if (success > 0)
178178
return true;
179179

180-
string str = message;
181-
if (last_error != 0)
182-
str += ": " + error_string(last_error);
183-
184-
throw std::runtime_error(str);
180+
throw std::runtime_error(message + (last_error != 0 ? ": " + error_string(last_error) : ""));
185181
}
186182

187-
// Return false on EOF
188-
bool check(SSL *ssl, int ret, const string &message) {
183+
// Return false on recoverable error
184+
bool check_error(int err, const string &message) {
189185
unsigned long last_error = ERR_peek_last_error();
190186
ERR_clear_error();
191187

192-
int err = SSL_get_error(ssl, ret);
193-
if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
188+
if (err == SSL_ERROR_NONE)
194189
return true;
195-
}
196-
if (err == SSL_ERROR_ZERO_RETURN) {
197-
return false;
198-
}
199190

200-
string str = message;
201-
if (err == SSL_ERROR_SYSCALL) {
202-
str += ": fatal I/O error";
203-
} else if (err == SSL_ERROR_SSL) {
204-
if (last_error != 0)
205-
str += ": " + error_string(last_error);
206-
}
207-
throw std::runtime_error(str);
191+
if (err == SSL_ERROR_ZERO_RETURN)
192+
throw std::runtime_error(message + ": peer closed connection");
193+
194+
if (err == SSL_ERROR_SYSCALL)
195+
throw std::runtime_error(message + ": fatal I/O error");
196+
197+
if (err == SSL_ERROR_SSL)
198+
throw std::runtime_error(message +
199+
(last_error != 0 ? ": " + error_string(last_error) : ""));
200+
201+
// SSL_ERROR_WANT_READ and SSL_ERROR_WANT_WRITE end up here
202+
return false;
208203
}
209204

210205
BIO *BIO_new_from_file(const string &filename) {

src/impl/tls.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ void init();
8585
string error_string(unsigned long error);
8686

8787
bool check(int success, const string &message = "OpenSSL error");
88-
bool check(SSL *ssl, int ret, const string &message = "OpenSSL error");
88+
bool check_error(int err, const string &message = "OpenSSL error");
8989

9090
BIO *BIO_new_from_file(const string &filename);
9191

src/impl/tlstransport.cpp

+37-25
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,8 @@ bool TlsTransport::send(message_ptr message) {
386386
int(message->size()));
387387
} while (ret == MBEDTLS_ERR_SSL_WANT_WRITE);
388388

389-
mbedtls::check(ret);
389+
if (!mbedtls::check(ret))
390+
throw std::runtime_error("TLS send failed");
390391

391392
return mOutgoingResult;
392393
}
@@ -639,10 +640,15 @@ void TlsTransport::start() {
639640
changeState(State::Connecting);
640641

641642
// Initiate the handshake
642-
std::lock_guard lock(mSslMutex);
643-
int ret = SSL_do_handshake(mSsl);
644-
openssl::check(mSsl, ret, "Handshake initiation failed");
645-
flushOutput();
643+
int ret, err;
644+
{
645+
std::lock_guard lock(mSslMutex);
646+
ret = SSL_do_handshake(mSsl);
647+
err = SSL_get_error(mSsl, ret);
648+
flushOutput();
649+
}
650+
651+
openssl::check_error(err, "Handshake failed");
646652
}
647653

648654
void TlsTransport::stop() {
@@ -661,12 +667,19 @@ bool TlsTransport::send(message_ptr message) {
661667

662668
PLOG_VERBOSE << "Send size=" << message->size();
663669

664-
std::lock_guard lock(mSslMutex);
665-
int ret = SSL_write(mSsl, message->data(), int(message->size()));
666-
if (!openssl::check(mSsl, ret))
670+
int err;
671+
bool result;
672+
{
673+
std::lock_guard lock(mSslMutex);
674+
int ret = SSL_write(mSsl, message->data(), int(message->size()));
675+
err = SSL_get_error(mSsl, ret);
676+
result = flushOutput();
677+
}
678+
679+
if (!openssl::check_error(err))
667680
throw std::runtime_error("TLS send failed");
668681

669-
return flushOutput();
682+
return result;
670683
}
671684

672685
void TlsTransport::incoming(message_ptr message) {
@@ -698,7 +711,7 @@ void TlsTransport::doRecv() {
698711
const size_t bufferSize = 4096;
699712
byte buffer[bufferSize];
700713

701-
// Process incoming messages
714+
// Read incoming messages
702715
while (mIncomingQueue.running()) {
703716
auto next = mIncomingQueue.pop();
704717
if (!next)
@@ -712,44 +725,43 @@ void TlsTransport::doRecv() {
712725

713726
if (state() == State::Connecting) {
714727
// Continue the handshake
715-
bool finished;
728+
int ret, err;
716729
{
717730
std::lock_guard lock(mSslMutex);
718-
int ret = SSL_do_handshake(mSsl);
719-
if (!openssl::check(mSsl, ret, "Handshake failed"))
720-
break;
721-
731+
ret = SSL_do_handshake(mSsl);
732+
err = SSL_get_error(mSsl, ret);
722733
flushOutput();
723-
finished = (SSL_is_init_finished(mSsl) != 0);
724734
}
725735

726-
if (finished) {
736+
if (openssl::check_error(err, "Handshake failed")) {
727737
PLOG_INFO << "TLS handshake finished";
728738
changeState(State::Connected);
729739
postHandshake();
730740
}
731741
}
732742

733743
if (state() == State::Connected) {
734-
int ret;
744+
int ret, err;
735745
while (true) {
736746
{
737747
std::lock_guard lock(mSslMutex);
738748
ret = SSL_read(mSsl, buffer, bufferSize);
749+
err = SSL_get_error(mSsl, ret);
750+
flushOutput(); // SSL_read() can also cause write operations
739751
}
740752

741-
if (ret > 0)
753+
if (err == SSL_ERROR_ZERO_RETURN)
754+
break;
755+
756+
if (openssl::check_error(err))
742757
recv(make_message(buffer, buffer + ret));
743758
else
744759
break;
745760
}
746761

747-
{
748-
std::lock_guard lock(mSslMutex);
749-
if (!openssl::check(mSsl, ret))
750-
break;
751-
752-
flushOutput(); // SSL_read() can also cause write operations
762+
if (err == SSL_ERROR_ZERO_RETURN) {
763+
PLOG_DEBUG << "TLS connection cleanly closed";
764+
break; // No more data can be read
753765
}
754766
}
755767
}

0 commit comments

Comments
 (0)