diff --git a/src/app/CommandSender.cpp b/src/app/CommandSender.cpp index fe263eb8923cc7..dbf1e864cd3ac1 100644 --- a/src/app/CommandSender.cpp +++ b/src/app/CommandSender.cpp @@ -540,13 +540,18 @@ CHIP_ERROR CommandSender::FinishCommand(FinishCommandParameters & aFinishCommand CHIP_ERROR CommandSender::AddRequestData(const CommandPathParams & aCommandPath, const DataModel::EncodableToTLV & aEncodable, AddRequestDataParameters & aAddRequestDataParams) { + ReturnErrorOnFailure(AllocateBuffer()); + + RollbackInvokeRequest rollback(*this); PrepareCommandParameters prepareCommandParams(aAddRequestDataParams); ReturnErrorOnFailure(PrepareCommand(aCommandPath, prepareCommandParams)); TLV::TLVWriter * writer = GetCommandDataIBTLVWriter(); VerifyOrReturnError(writer != nullptr, CHIP_ERROR_INCORRECT_STATE); ReturnErrorOnFailure(aEncodable.EncodeTo(*writer, TLV::ContextTag(CommandDataIB::Tag::kFields))); FinishCommandParameters finishCommandParams(aAddRequestDataParams); - return FinishCommand(finishCommandParams); + ReturnErrorOnFailure(FinishCommand(finishCommandParams)); + rollback.DisableAutomaticRollback(); + return CHIP_NO_ERROR; } CHIP_ERROR CommandSender::FinishCommandInternal(FinishCommandParameters & aFinishCommandParams) @@ -657,5 +662,34 @@ void CommandSender::MoveToState(const State aTargetState) ChipLogDetail(DataManagement, "ICR moving to [%10.10s]", GetStateStr()); } +CommandSender::RollbackInvokeRequest::RollbackInvokeRequest(CommandSender & aCommandSender) : mCommandSender(aCommandSender) +{ + VerifyOrReturn(mCommandSender.mBufferAllocated); + VerifyOrReturn(mCommandSender.mState == State::Idle || mCommandSender.mState == State::AddedCommand); + VerifyOrReturn(mCommandSender.mInvokeRequestBuilder.GetInvokeRequests().GetError() == CHIP_NO_ERROR); + VerifyOrReturn(mCommandSender.mInvokeRequestBuilder.GetError() == CHIP_NO_ERROR); + mCommandSender.mInvokeRequestBuilder.Checkpoint(mBackupWriter); + mBackupState = mCommandSender.mState; + mRollbackInDestructor = true; +} + +CommandSender::RollbackInvokeRequest::~RollbackInvokeRequest() +{ + VerifyOrReturn(mRollbackInDestructor); + VerifyOrReturn(mCommandSender.mState == State::AddingCommand); + ChipLogDetail(DataManagement, "Rolling back response"); + // TODO(#30453): Rollback of mInvokeRequestBuilder should handle resetting + // InvokeRequests. + mCommandSender.mInvokeRequestBuilder.GetInvokeRequests().ResetError(); + mCommandSender.mInvokeRequestBuilder.Rollback(mBackupWriter); + mCommandSender.MoveToState(mBackupState); + mRollbackInDestructor = false; +} + +void CommandSender::RollbackInvokeRequest::DisableAutomaticRollback() +{ + mRollbackInDestructor = false; +} + } // namespace app } // namespace chip diff --git a/src/app/CommandSender.h b/src/app/CommandSender.h index 5542d2469bb743..5ea890a2ebf163 100644 --- a/src/app/CommandSender.h +++ b/src/app/CommandSender.h @@ -216,6 +216,12 @@ class CommandSender final : public Messaging::ExchangeDelegate AddRequestDataParameters(const Optional & aTimedInvokeTimeoutMs) : timedInvokeTimeoutMs(aTimedInvokeTimeoutMs) {} + AddRequestDataParameters & SetCommandRef(uint16_t aCommandRef) + { + commandRef.SetValue(aCommandRef); + return *this; + } + // When a value is provided for timedInvokeTimeoutMs, this invoke becomes a timed // invoke. CommandSender will use the minimum of all provided timeouts for execution. const Optional timedInvokeTimeoutMs; @@ -511,6 +517,34 @@ class CommandSender final : public Messaging::ExchangeDelegate AwaitingDestruction, ///< The object has completed its work and is awaiting destruction by the application. }; + /** + * Class to help backup CommandSender's buffer containing InvokeRequestMessage when adding InvokeRequest + * in case there is a failure to add InvokeRequest. Intended usage is as follows: + * - Allocate RollbackInvokeRequest on the stack. + * - Attempt adding InvokeRequest into InvokeRequestMessage buffer. + * - If modification is added successfully, call DisableAutomaticRollback() to prevent destructor from + * rolling back InvokeReqestMessage. + * - If there is an issue adding InvokeRequest, destructor will take care of rolling back + * InvokeRequestMessage to previously saved state. + */ + class RollbackInvokeRequest + { + public: + explicit RollbackInvokeRequest(CommandSender & aCommandSender); + ~RollbackInvokeRequest(); + + /** + * Disables rolling back to previously saved state for InvokeRequestMessage. + */ + void DisableAutomaticRollback(); + + private: + CommandSender & mCommandSender; + TLV::TLVWriter mBackupWriter; + State mBackupState; + bool mRollbackInDestructor = false; + }; + union CallbackHandle { CallbackHandle(Callback * apCallback) : legacyCallback(apCallback) {} diff --git a/src/app/tests/TestCommandInteraction.cpp b/src/app/tests/TestCommandInteraction.cpp index 22c7ccc1a9e9b4..46f52d106cd3f1 100644 --- a/src/app/tests/TestCommandInteraction.cpp +++ b/src/app/tests/TestCommandInteraction.cpp @@ -109,10 +109,9 @@ enum class ForcedSizeBufferLengthHint kSizeGreaterThan255, }; -struct ForcedSizeBuffer +class ForcedSizeBuffer : public app::DataModel::EncodableToTLV { - chip::Platform::ScopedMemoryBufferWithSize mBuffer; - +public: ForcedSizeBuffer(uint32_t size) { if (mBuffer.Alloc(size)) @@ -124,7 +123,7 @@ struct ForcedSizeBuffer // No significance with using 0x12 as the CommandId, just using a value. static constexpr chip::CommandId GetCommandId() { return 0x12; } - CHIP_ERROR Encode(TLV::TLVWriter & aWriter, TLV::Tag aTag) const + CHIP_ERROR EncodeTo(TLV::TLVWriter & aWriter, TLV::Tag aTag) const override { VerifyOrReturnError(mBuffer, CHIP_ERROR_NO_MEMORY); @@ -133,6 +132,9 @@ struct ForcedSizeBuffer ReturnErrorOnFailure(app::DataModel::Encode(aWriter, TLV::ContextTag(1), ByteSpan(mBuffer.Get(), mBuffer.AllocatedSize()))); return aWriter.EndContainer(outerContainerType); } + +private: + chip::Platform::ScopedMemoryBufferWithSize mBuffer; }; struct Fields @@ -385,6 +387,7 @@ class TestCommandInteraction : public ::testing::Test void TestCommandSender_WithProcessReceivedMsg(); void TestCommandSender_ExtendableApiWithProcessReceivedMsg(); void TestCommandSender_ExtendableApiWithProcessReceivedMsgContainingInvalidCommandRef(); + void TestCommandSender_ValidateSecondLargeAddRequestDataRollbacked(); void TestCommandHandler_WithoutResponderCallingAddStatus(); void TestCommandHandler_WithoutResponderCallingAddResponse(); void TestCommandHandler_WithoutResponderCallingDirectPrepareFinishCommandApis(); @@ -630,7 +633,8 @@ uint32_t TestCommandInteraction::GetAddResponseDataOverheadSizeForPath(const Con // When ForcedSizeBuffer exceeds 255, an extra byte is needed for length, affecting the overhead size required by // AddResponseData. In order to have this accounted for in overhead calculation we set the length to be 256. uint32_t sizeOfForcedSizeBuffer = aBufferSizeHint == ForcedSizeBufferLengthHint::kSizeGreaterThan255 ? 256 : 0; - EXPECT_EQ(commandHandler.AddResponseData(aRequestCommandPath, ForcedSizeBuffer(sizeOfForcedSizeBuffer)), CHIP_NO_ERROR); + ForcedSizeBuffer responseData(sizeOfForcedSizeBuffer); + EXPECT_EQ(commandHandler.AddResponseData(aRequestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR); uint32_t remainingSizeAfter = commandHandler.mInvokeResponseBuilder.GetWriter()->GetRemainingFreeLength(); uint32_t delta = remainingSizeBefore - remainingSizeAfter - sizeOfForcedSizeBuffer; @@ -655,7 +659,8 @@ void TestCommandInteraction::FillCurrentInvokeResponseBuffer(CommandHandlerImpl // Validating assumption. If this fails, it means overheadSizeNeededForAddingResponse is likely too large. EXPECT_GE(sizeToFill, 256u); - EXPECT_EQ(apCommandHandler->AddResponseData(aRequestCommandPath, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR); + ForcedSizeBuffer responseData(sizeToFill); + EXPECT_EQ(apCommandHandler->AddResponseData(aRequestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR); } void TestCommandInteraction::ValidateCommandHandlerEncodeInvokeResponseMessage(bool aNeedStatusCode) @@ -1085,6 +1090,47 @@ TEST_F_FROM_FIXTURE(TestCommandInteraction, TestCommandSender_ExtendableApiWithP EXPECT_EQ(mockCommandSenderExtendedDelegate.onErrorCalledTimes, 0); } +TEST_F_FROM_FIXTURE(TestCommandInteraction, TestCommandSender_ValidateSecondLargeAddRequestDataRollbacked) +{ + mockCommandSenderExtendedDelegate.ResetCounter(); + PendingResponseTrackerImpl pendingResponseTracker; + app::CommandSender commandSender(kCommandSenderTestOnlyMarker, &mockCommandSenderExtendedDelegate, + &mpTestContext->GetExchangeManager(), &pendingResponseTracker); + + app::CommandSender::AddRequestDataParameters addRequestDataParams; + + CommandSender::ConfigParameters config; + config.SetRemoteMaxPathsPerInvoke(2); + EXPECT_EQ(commandSender.SetCommandSenderConfig(config), CHIP_NO_ERROR); + + // The specific values chosen here are arbitrary. + uint16_t firstCommandRef = 1; + uint16_t secondCommandRef = 2; + auto commandPathParams = MakeTestCommandPath(); + SimpleTLVPayload simplePayloadWriter; + addRequestDataParams.SetCommandRef(firstCommandRef); + + EXPECT_EQ(commandSender.AddRequestData(commandPathParams, simplePayloadWriter, addRequestDataParams), CHIP_NO_ERROR); + + uint32_t remainingSize = commandSender.mInvokeRequestBuilder.GetWriter()->GetRemainingFreeLength(); + // Because request is made of both request data and request path (commandPathParams), using + // `remainingSize` is large enough fail. + ForcedSizeBuffer requestData(remainingSize); + + addRequestDataParams.SetCommandRef(secondCommandRef); + EXPECT_EQ(commandSender.AddRequestData(commandPathParams, requestData, addRequestDataParams), CHIP_ERROR_NO_MEMORY); + + // Confirm that we can still send out a request with the first command. + EXPECT_EQ(commandSender.SendCommandRequest(mpTestContext->GetSessionBobToAlice()), CHIP_NO_ERROR); + EXPECT_EQ(commandSender.GetInvokeResponseMessageCount(), 0u); + + mpTestContext->DrainAndServiceIO(); + + EXPECT_EQ(mockCommandSenderExtendedDelegate.onResponseCalledTimes, 1); + EXPECT_EQ(mockCommandSenderExtendedDelegate.onFinalCalledTimes, 1); + EXPECT_EQ(mockCommandSenderExtendedDelegate.onErrorCalledTimes, 0); +} + TEST_F(TestCommandInteraction, TestCommandHandlerEncodeSimpleCommandData) { // Send response which has simple command data and command path @@ -1186,7 +1232,8 @@ TEST_F_FROM_FIXTURE(TestCommandInteraction, TestCommandHandler_WithoutResponderC CommandHandlerImpl commandHandler(&mockCommandHandlerDelegate); uint32_t sizeToFill = 50; // This is an arbitrary number, we need to select a non-zero value. - EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR); + ForcedSizeBuffer responseData(sizeToFill); + EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR); // Since calling AddResponseData is supposed to be a no-operation when there is no responder, it is // hard to validate. Best way is to check that we are still in an Idle state afterwards @@ -1811,7 +1858,8 @@ TEST_F_FROM_FIXTURE(TestCommandInteraction, TestCommandHandler_FillUpInvokeRespo EXPECT_EQ(remainingSize, sizeToLeave); uint32_t sizeToFill = 50; - EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath2, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR); + ForcedSizeBuffer responseData(sizeToFill); + EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath2, responseData.GetCommandId(), responseData), CHIP_NO_ERROR); remainingSize = commandHandler.mInvokeResponseBuilder.GetWriter()->GetRemainingFreeLength(); EXPECT_GT(remainingSize, sizeToLeave);