From 2c45ebea83426850a9b6f05fd5bc82723cccbbca Mon Sep 17 00:00:00 2001 From: Andrei Litvin Date: Mon, 15 Apr 2024 17:18:53 -0400 Subject: [PATCH] Use RAII for group session iteration (#32970) * Use RAII for iterator management * Restyle * Fix typo * Fix naming a bit now that I made this a template * Make it clear that class member is initialized --- src/transport/SessionManager.cpp | 35 ++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 78f9d98d79600f..d9ab8de48e1133 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -58,6 +58,31 @@ using Transport::SecureSession; namespace { Global gGroupPeerTable; +/// RAII class for iterators that guarantees that Release() will be called +/// on the underlying type +template +class AutoRelease +{ +public: + AutoRelease(Releasable * iter) : mIter(iter) {} + ~AutoRelease() { Release(); } + + Releasable * operator->() { return mIter; } + const Releasable * operator->() const { return mIter; } + + bool IsNull() const { return mIter == nullptr; } + + void Release() + { + VerifyOrReturn(mIter != nullptr); + mIter->Release(); + mIter = nullptr; + } + +private: + Releasable * mIter = nullptr; +}; + // Helper function that strips off the interface ID from a peer address that is // not an IPv6 link-local address. For any other address type we should rely on // the device's routing table to route messages sent. Forcing messages down a @@ -883,8 +908,11 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & partialPack // Trial decryption with GroupDataProvider Credentials::GroupDataProvider::GroupSession groupContext; - auto iter = groups->IterateGroupSessions(partialPacketHeader.GetSessionId()); - if (iter == nullptr) + + AutoRelease iter( + groups->IterateGroupSessions(partialPacketHeader.GetSessionId())); + + if (iter.IsNull()) { ChipLogError(Inet, "Failed to retrieve Groups iterator. Discarding everything"); return; @@ -931,7 +959,7 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & partialPack } #endif // CHIP_CONFIG_PRIVACY_ACCEPT_NONSPEC_SVE2 } - iter->Release(); + iter.Release(); if (!decrypted) { @@ -969,7 +997,6 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & partialPack gGroupPeerTable->FindOrAddPeer(groupContext.fabric_index, packetHeaderCopy.GetSourceNodeId().Value(), packetHeaderCopy.IsSecureSessionControlMsg(), counter)) { - if (Credentials::GroupDataProvider::SecurityPolicy::kTrustFirst == groupContext.security_policy) { err = counter->VerifyOrTrustFirstGroup(packetHeaderCopy.GetMessageCounter());