Skip to content

Commit 330225b

Browse files
committed
Use Optional<ExchangeHandle> variable instead of ExchangeContext in PairingSession.
This ensures that the ExchangeContext for the session is automatically reference counted without explicit manual Retain() and Release() calls. As a result, the ExchangeContext is held until PairingSession::Clear() gets called that, in turn, calls ClearValue() on the Optional to internally call Release() on the underlying target. Fixes Issue project-chip#32498.
1 parent 3dc2320 commit 330225b

File tree

4 files changed

+77
-63
lines changed

4 files changed

+77
-63
lines changed

src/protocols/secure_channel/CASESession.cpp

+30-23
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,8 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric
514514

515515
// We are setting the exchange context specifically before checking for error.
516516
// This is to make sure the exchange will get closed if Init() returned an error.
517-
mExchangeCtxt = exchangeCtxt;
517+
ExchangeHandle ecHandle(*exchangeCtxt);
518+
mExchangeCtxt.SetValue(ecHandle);
518519

519520
// From here onwards, let's go to exit on error, as some state might have already
520521
// been initialized
@@ -527,7 +528,7 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric
527528
mSessionResumptionStorage = sessionResumptionStorage;
528529
mLocalMRPConfig = mrpLocalConfig.ValueOr(GetDefaultMRPConfig());
529530

530-
mExchangeCtxt->UseSuggestedResponseTimeout(kExpectedSigma1ProcessingTime);
531+
mExchangeCtxt.Value()->UseSuggestedResponseTimeout(kExpectedSigma1ProcessingTime);
531532
mPeerNodeId = peerScopedNodeId.GetNodeId();
532533
mLocalNodeId = fabricInfo->GetNodeId();
533534

@@ -549,7 +550,8 @@ void CASESession::OnResponseTimeout(ExchangeContext * ec)
549550
{
550551
MATTER_TRACE_SCOPE("OnResponseTimeout", "CASESession");
551552
VerifyOrReturn(ec != nullptr, ChipLogError(SecureChannel, "CASESession::OnResponseTimeout was called by null exchange"));
552-
VerifyOrReturn(mExchangeCtxt == ec, ChipLogError(SecureChannel, "CASESession::OnResponseTimeout exchange doesn't match"));
553+
VerifyOrReturn(&mExchangeCtxt.Value().Get() == ec,
554+
ChipLogError(SecureChannel, "CASESession::OnResponseTimeout exchange doesn't match"));
553555
ChipLogError(SecureChannel, "CASESession timed out while waiting for a response from the peer. Current state was %u",
554556
to_underlying(mState));
555557
MATTER_TRACE_COUNTER("CASETimeout");
@@ -735,8 +737,8 @@ CHIP_ERROR CASESession::SendSigma1()
735737
ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ msg_R1->Start(), msg_R1->DataLength() }));
736738

737739
// Call delegate to send the msg to peer
738-
ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma1, std::move(msg_R1),
739-
SendFlags(SendMessageFlags::kExpectResponse)));
740+
ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma1, std::move(msg_R1),
741+
SendFlags(SendMessageFlags::kExpectResponse)));
740742

741743
mState = resuming ? State::kSentSigma1Resume : State::kSentSigma1;
742744

@@ -959,8 +961,9 @@ CHIP_ERROR CASESession::SendSigma2Resume()
959961
ReturnErrorOnFailure(tlvWriter.Finalize(&msg_R2_resume));
960962

961963
// Call delegate to send the msg to peer
962-
ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2Resume, std::move(msg_R2_resume),
963-
SendFlags(SendMessageFlags::kExpectResponse)));
964+
ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2Resume,
965+
std::move(msg_R2_resume),
966+
SendFlags(SendMessageFlags::kExpectResponse)));
964967

965968
mState = State::kSentSigma2Resume;
966969

@@ -1096,8 +1099,8 @@ CHIP_ERROR CASESession::SendSigma2()
10961099
ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ msg_R2->Start(), msg_R2->DataLength() }));
10971100

10981101
// Call delegate to send the msg to peer
1099-
ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2, std::move(msg_R2),
1100-
SendFlags(SendMessageFlags::kExpectResponse)));
1102+
ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2, std::move(msg_R2),
1103+
SendFlags(SendMessageFlags::kExpectResponse)));
11011104

11021105
mState = State::kSentSigma2;
11031106

@@ -1148,7 +1151,8 @@ CHIP_ERROR CASESession::HandleSigma2Resume(System::PacketBufferHandle && msg)
11481151
if (tlvReader.Next() != CHIP_END_OF_TLV)
11491152
{
11501153
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(4), tlvReader));
1151-
mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
1154+
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
1155+
GetRemoteSessionParameters());
11521156
}
11531157

11541158
ChipLogDetail(SecureChannel, "Peer assigned session session ID %d", responderSessionId);
@@ -1341,7 +1345,8 @@ CHIP_ERROR CASESession::HandleSigma2(System::PacketBufferHandle && msg)
13411345
if (tlvReader.Next() != CHIP_END_OF_TLV)
13421346
{
13431347
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(kTag_Sigma2_ResponderMRPParams), tlvReader));
1344-
mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
1348+
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
1349+
GetRemoteSessionParameters());
13451350
}
13461351

13471352
exit:
@@ -1410,7 +1415,7 @@ CHIP_ERROR CASESession::SendSigma3a()
14101415
{
14111416
SuccessOrExit(err = helper->ScheduleWork());
14121417
mSendSigma3Helper = helper;
1413-
mExchangeCtxt->WillSendMessage();
1418+
mExchangeCtxt.Value()->WillSendMessage();
14141419
mState = State::kSendSigma3Pending;
14151420
}
14161421
else
@@ -1537,8 +1542,8 @@ CHIP_ERROR CASESession::SendSigma3c(SendSigma3Data & data, CHIP_ERROR status)
15371542
SuccessOrExit(err);
15381543

15391544
// Call delegate to send the Msg3 to peer
1540-
err = mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma3, std::move(msg_R3),
1541-
SendFlags(SendMessageFlags::kExpectResponse));
1545+
err = mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma3, std::move(msg_R3),
1546+
SendFlags(SendMessageFlags::kExpectResponse));
15421547
SuccessOrExit(err);
15431548

15441549
ChipLogProgress(SecureChannel, "Sent Sigma3 msg");
@@ -1704,7 +1709,7 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg)
17041709

17051710
SuccessOrExit(err = helper->ScheduleWork());
17061711
mHandleSigma3Helper = helper;
1707-
mExchangeCtxt->WillSendMessage();
1712+
mExchangeCtxt.Value()->WillSendMessage();
17081713
mState = State::kHandleSigma3Pending;
17091714
}
17101715

@@ -2036,7 +2041,8 @@ CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader,
20362041
if (err == CHIP_NO_ERROR && tlvReader.GetTag() == ContextTag(kInitiatorMRPParamsTag))
20372042
{
20382043
ReturnErrorOnFailure(DecodeMRPParametersIfPresent(TLV::ContextTag(kInitiatorMRPParamsTag), tlvReader));
2039-
mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
2044+
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
2045+
GetRemoteSessionParameters());
20402046
err = tlvReader.Next();
20412047
}
20422048

@@ -2084,26 +2090,27 @@ CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader,
20842090
return CHIP_NO_ERROR;
20852091
}
20862092

2087-
CHIP_ERROR CASESession::ValidateReceivedMessage(ExchangeContext * ec, const PayloadHeader & payloadHeader,
2093+
CHIP_ERROR CASESession::ValidateReceivedMessage(chip::Messaging::ExchangeContext * ec, const PayloadHeader & payloadHeader,
20882094
const System::PacketBufferHandle & msg)
20892095
{
20902096
VerifyOrReturnError(ec != nullptr, CHIP_ERROR_INVALID_ARGUMENT);
20912097

20922098
// mExchangeCtxt can be nullptr if this is the first message (CASE_Sigma1) received by CASESession
20932099
// via UnsolicitedMessageHandler. The exchange context is allocated by exchange manager and provided
20942100
// to the handler (CASESession object).
2095-
if (mExchangeCtxt != nullptr)
2101+
if (mExchangeCtxt.HasValue())
20962102
{
2097-
if (mExchangeCtxt != ec)
2103+
if (&mExchangeCtxt.Value().Get() != ec)
20982104
{
20992105
ReturnErrorOnFailure(CHIP_ERROR_INVALID_ARGUMENT);
21002106
}
21012107
}
21022108
else
21032109
{
2104-
mExchangeCtxt = ec;
2110+
ExchangeHandle ecHandle(*ec);
2111+
mExchangeCtxt.SetValue(ecHandle);
21052112
}
2106-
mExchangeCtxt->UseSuggestedResponseTimeout(kExpectedHighProcessingTime);
2113+
mExchangeCtxt.Value()->UseSuggestedResponseTimeout(kExpectedHighProcessingTime);
21072114

21082115
VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT);
21092116
return CHIP_NO_ERROR;
@@ -2128,7 +2135,7 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PayloadHea
21282135
//
21292136
// Should you need to resume the CASESession, you could theoretically pass along the msg to a callback that gets
21302137
// registered when setting mStopHandshakeAtState.
2131-
mExchangeCtxt->WillSendMessage();
2138+
mExchangeCtxt.Value()->WillSendMessage();
21322139
return CHIP_NO_ERROR;
21332140
}
21342141
#endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST
@@ -2138,7 +2145,7 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PayloadHea
21382145
msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2Resume ||
21392146
msgType == Protocols::SecureChannel::MsgType::CASE_Sigma3)
21402147
{
2141-
SuccessOrExit(err = mExchangeCtxt->FlushAcks());
2148+
SuccessOrExit(err = mExchangeCtxt.Value()->FlushAcks());
21422149
}
21432150
#endif // CHIP_CONFIG_SLOW_CRYPTO
21442151

src/protocols/secure_channel/PASESession.cpp

+27-20
Original file line numberDiff line numberDiff line change
@@ -213,17 +213,18 @@ CHIP_ERROR PASESession::Pair(SessionManager & sessionManager, uint32_t peerSetUp
213213
{
214214
MATTER_TRACE_SCOPE("Pair", "PASESession");
215215
ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT);
216+
ExchangeHandle ecHandle(*exchangeCtxt);
216217
CHIP_ERROR err = Init(sessionManager, peerSetUpPINCode, delegate);
217218
SuccessOrExit(err);
218219

219220
mRole = CryptoContext::SessionRole::kInitiator;
220221

221-
mExchangeCtxt = exchangeCtxt;
222+
mExchangeCtxt.SetValue(ecHandle);
222223

223224
// When commissioning starts, the peer is assumed to be active.
224-
mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->MarkActiveRx();
225+
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->MarkActiveRx();
225226

226-
mExchangeCtxt->UseSuggestedResponseTimeout(kExpectedLowProcessingTime);
227+
mExchangeCtxt.Value()->UseSuggestedResponseTimeout(kExpectedLowProcessingTime);
227228

228229
mLocalMRPConfig = mrpLocalConfig.ValueOr(GetDefaultMRPConfig());
229230

@@ -244,7 +245,7 @@ void PASESession::OnResponseTimeout(ExchangeContext * ec)
244245
{
245246
MATTER_TRACE_SCOPE("OnResponseTimeout", "PASESession");
246247
VerifyOrReturn(ec != nullptr, ChipLogError(SecureChannel, "PASESession::OnResponseTimeout was called by null exchange"));
247-
VerifyOrReturn(mExchangeCtxt == nullptr || mExchangeCtxt == ec,
248+
VerifyOrReturn(!mExchangeCtxt.HasValue() || &mExchangeCtxt.Value().Get() == ec,
248249
ChipLogError(SecureChannel, "PASESession::OnResponseTimeout exchange doesn't match"));
249250
// If we were waiting for something, mNextExpectedMsg had better have a value.
250251
ChipLogError(SecureChannel, "PASESession timed out while waiting for a response from the peer. Expected message type was %u",
@@ -308,8 +309,8 @@ CHIP_ERROR PASESession::SendPBKDFParamRequest()
308309
// Update commissioning hash with the pbkdf2 param request that's being sent.
309310
ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ req->Start(), req->DataLength() }));
310311

311-
ReturnErrorOnFailure(
312-
mExchangeCtxt->SendMessage(MsgType::PBKDFParamRequest, std::move(req), SendFlags(SendMessageFlags::kExpectResponse)));
312+
ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(MsgType::PBKDFParamRequest, std::move(req),
313+
SendFlags(SendMessageFlags::kExpectResponse)));
313314

314315
mNextExpectedMsg.SetValue(MsgType::PBKDFParamResponse);
315316

@@ -364,7 +365,8 @@ CHIP_ERROR PASESession::HandlePBKDFParamRequest(System::PacketBufferHandle && ms
364365
if (tlvReader.Next() != CHIP_END_OF_TLV)
365366
{
366367
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader));
367-
mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
368+
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
369+
GetRemoteSessionParameters());
368370
}
369371

370372
err = SendPBKDFParamResponse(ByteSpan(initiatorRandom), hasPBKDFParameters);
@@ -428,8 +430,8 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in
428430
ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ resp->Start(), resp->DataLength() }));
429431
ReturnErrorOnFailure(SetupSpake2p());
430432

431-
ReturnErrorOnFailure(
432-
mExchangeCtxt->SendMessage(MsgType::PBKDFParamResponse, std::move(resp), SendFlags(SendMessageFlags::kExpectResponse)));
433+
ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(MsgType::PBKDFParamResponse, std::move(resp),
434+
SendFlags(SendMessageFlags::kExpectResponse)));
433435
ChipLogDetail(SecureChannel, "Sent PBKDF param response");
434436

435437
mNextExpectedMsg.SetValue(MsgType::PASE_Pake1);
@@ -483,7 +485,8 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m
483485
if (tlvReader.Next() != CHIP_END_OF_TLV)
484486
{
485487
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader));
486-
mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
488+
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
489+
GetRemoteSessionParameters());
487490
}
488491

489492
// TODO - Add a unit test that exercises mHavePBKDFParameters path
@@ -508,7 +511,8 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m
508511
if (tlvReader.Next() != CHIP_END_OF_TLV)
509512
{
510513
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader));
511-
mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
514+
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(
515+
GetRemoteSessionParameters());
512516
}
513517
}
514518

@@ -558,7 +562,7 @@ CHIP_ERROR PASESession::SendMsg1()
558562
ReturnErrorOnFailure(tlvWriter.Finalize(&msg));
559563

560564
ReturnErrorOnFailure(
561-
mExchangeCtxt->SendMessage(MsgType::PASE_Pake1, std::move(msg), SendFlags(SendMessageFlags::kExpectResponse)));
565+
mExchangeCtxt.Value()->SendMessage(MsgType::PASE_Pake1, std::move(msg), SendFlags(SendMessageFlags::kExpectResponse)));
562566
ChipLogDetail(SecureChannel, "Sent spake2p msg1");
563567

564568
mNextExpectedMsg.SetValue(MsgType::PASE_Pake2);
@@ -620,7 +624,8 @@ CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(System::PacketBufferHandle && ms
620624
SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType));
621625
SuccessOrExit(err = tlvWriter.Finalize(&msg2));
622626

623-
err = mExchangeCtxt->SendMessage(MsgType::PASE_Pake2, std::move(msg2), SendFlags(SendMessageFlags::kExpectResponse));
627+
err =
628+
mExchangeCtxt.Value()->SendMessage(MsgType::PASE_Pake2, std::move(msg2), SendFlags(SendMessageFlags::kExpectResponse));
624629
SuccessOrExit(err);
625630

626631
mNextExpectedMsg.SetValue(MsgType::PASE_Pake3);
@@ -696,7 +701,8 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(System::PacketBufferHandle && ms
696701
SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType));
697702
SuccessOrExit(err = tlvWriter.Finalize(&msg3));
698703

699-
err = mExchangeCtxt->SendMessage(MsgType::PASE_Pake3, std::move(msg3), SendFlags(SendMessageFlags::kExpectResponse));
704+
err =
705+
mExchangeCtxt.Value()->SendMessage(MsgType::PASE_Pake3, std::move(msg3), SendFlags(SendMessageFlags::kExpectResponse));
700706
SuccessOrExit(err);
701707

702708
mNextExpectedMsg.SetValue(MsgType::StatusReport);
@@ -786,25 +792,26 @@ CHIP_ERROR PASESession::ValidateReceivedMessage(ExchangeContext * exchange, cons
786792
// mExchangeCtxt can be nullptr if this is the first message (PBKDFParamRequest) received by PASESession
787793
// via UnsolicitedMessageHandler. The exchange context is allocated by exchange manager and provided
788794
// to the handler (PASESession object).
789-
if (mExchangeCtxt != nullptr)
795+
if (mExchangeCtxt.HasValue())
790796
{
791-
if (mExchangeCtxt != exchange)
797+
if (&mExchangeCtxt.Value().Get() != exchange)
792798
{
793799
ReturnErrorOnFailure(CHIP_ERROR_INVALID_ARGUMENT);
794800
}
795801
}
796802
else
797803
{
798-
mExchangeCtxt = exchange;
804+
ExchangeHandle ecHandle(*exchange);
805+
mExchangeCtxt.SetValue(ecHandle);
799806
}
800807

801-
if (!mExchangeCtxt->GetSessionHandle()->IsUnauthenticatedSession())
808+
if (!mExchangeCtxt.Value()->GetSessionHandle()->IsUnauthenticatedSession())
802809
{
803810
ChipLogError(SecureChannel, "PASESession received PBKDFParamRequest over encrypted session. Ignoring.");
804811
return CHIP_ERROR_INCORRECT_STATE;
805812
}
806813

807-
mExchangeCtxt->UseSuggestedResponseTimeout(kExpectedHighProcessingTime);
814+
mExchangeCtxt.Value()->UseSuggestedResponseTimeout(kExpectedHighProcessingTime);
808815

809816
VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT);
810817
VerifyOrReturnError((mNextExpectedMsg.HasValue() && payloadHeader.HasMessageType(mNextExpectedMsg.Value())) ||
@@ -833,7 +840,7 @@ CHIP_ERROR PASESession::OnMessageReceived(ExchangeContext * exchange, const Payl
833840
if (msgType == MsgType::PBKDFParamRequest || msgType == MsgType::PBKDFParamResponse || msgType == MsgType::PASE_Pake1 ||
834841
msgType == MsgType::PASE_Pake2 || msgType == MsgType::PASE_Pake3)
835842
{
836-
SuccessOrExit(err = mExchangeCtxt->FlushAcks());
843+
SuccessOrExit(err = mExchangeCtxt.Value()->FlushAcks());
837844
}
838845
#endif // CHIP_CONFIG_SLOW_CRYPTO
839846

0 commit comments

Comments
 (0)