diff --git a/examples/network-manager-app/linux/args.gni b/examples/network-manager-app/linux/args.gni index 53bb53d2b1b2a1..e97ddb13e7c46e 100644 --- a/examples/network-manager-app/linux/args.gni +++ b/examples/network-manager-app/linux/args.gni @@ -22,3 +22,6 @@ chip_project_config_include_dirs = [ ] chip_config_network_layer_ble = false + +# This enables AccessRestrictionList (ARL) support used by the NIM sample app +chip_enable_access_restrictions = true diff --git a/examples/network-manager-app/network-manager-common/network-manager-app.matter b/examples/network-manager-app/network-manager-common/network-manager-app.matter index 627838288db0d1..57118d365577f1 100644 --- a/examples/network-manager-app/network-manager-common/network-manager-app.matter +++ b/examples/network-manager-app/network-manager-common/network-manager-app.matter @@ -1623,16 +1623,23 @@ endpoint 0 { server cluster AccessControl { emits event AccessControlEntryChanged; emits event AccessControlExtensionChanged; + emits event AccessRestrictionEntryChanged; + emits event FabricRestrictionReviewUpdate; callback attribute acl; callback attribute extension; callback attribute subjectsPerAccessControlEntry; callback attribute targetsPerAccessControlEntry; callback attribute accessControlEntriesPerFabric; + callback attribute commissioningARL; + callback attribute arl; callback attribute generatedCommandList; callback attribute acceptedCommandList; callback attribute attributeList; - ram attribute featureMap default = 0; + ram attribute featureMap default = 1; callback attribute clusterRevision; + + handle command ReviewFabricRestrictions; + handle command ReviewFabricRestrictionsResponse; } server cluster BasicInformation { diff --git a/examples/network-manager-app/network-manager-common/network-manager-app.zap b/examples/network-manager-app/network-manager-common/network-manager-app.zap index 1d27e3c346f320..240cd495a7fb4f 100644 --- a/examples/network-manager-app/network-manager-common/network-manager-app.zap +++ b/examples/network-manager-app/network-manager-common/network-manager-app.zap @@ -314,6 +314,24 @@ "define": "ACCESS_CONTROL_CLUSTER", "side": "server", "enabled": 1, + "commands": [ + { + "name": "ReviewFabricRestrictions", + "code": 0, + "mfgCode": null, + "source": "client", + "isIncoming": 1, + "isEnabled": 1 + }, + { + "name": "ReviewFabricRestrictionsResponse", + "code": 1, + "mfgCode": null, + "source": "server", + "isIncoming": 0, + "isEnabled": 1 + } + ], "attributes": [ { "name": "ACL", @@ -395,6 +413,38 @@ "maxInterval": 65534, "reportableChange": 0 }, + { + "name": "CommissioningARL", + "code": 5, + "mfgCode": null, + "side": "server", + "type": "array", + "included": 1, + "storageOption": "External", + "singleton": 0, + "bounded": 0, + "defaultValue": null, + "reportable": 1, + "minInterval": 1, + "maxInterval": 65534, + "reportableChange": 0 + }, + { + "name": "ARL", + "code": 6, + "mfgCode": null, + "side": "server", + "type": "array", + "included": 1, + "storageOption": "External", + "singleton": 0, + "bounded": 0, + "defaultValue": "", + "reportable": 1, + "minInterval": 1, + "maxInterval": 65534, + "reportableChange": 0 + }, { "name": "GeneratedCommandList", "code": 65528, @@ -453,7 +503,7 @@ "storageOption": "RAM", "singleton": 0, "bounded": 0, - "defaultValue": "0", + "defaultValue": "1", "reportable": 1, "minInterval": 1, "maxInterval": 65534, @@ -490,6 +540,20 @@ "mfgCode": null, "side": "server", "included": 1 + }, + { + "name": "AccessRestrictionEntryChanged", + "code": 2, + "mfgCode": null, + "side": "server", + "included": 1 + }, + { + "name": "FabricRestrictionReviewUpdate", + "code": 3, + "mfgCode": null, + "side": "server", + "included": 1 } ] }, diff --git a/examples/platform/linux/AppMain.cpp b/examples/platform/linux/AppMain.cpp index 307b3428126db2..074f078af003b4 100644 --- a/examples/platform/linux/AppMain.cpp +++ b/examples/platform/linux/AppMain.cpp @@ -103,6 +103,10 @@ #include "AppMain.h" #include "CommissionableInit.h" +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS +#include "ExampleAccessRestrictionProvider.h" +#endif + #if CHIP_DEVICE_LAYER_TARGET_DARWIN #include #if CHIP_DEVICE_CONFIG_ENABLE_WIFI @@ -121,6 +125,7 @@ using namespace chip::DeviceLayer; using namespace chip::Inet; using namespace chip::Transport; using namespace chip::app::Clusters; +using namespace chip::Access; // Network comissioning implementation namespace { @@ -180,6 +185,10 @@ Optional sWiFiNetworkCommissionin app::Clusters::NetworkCommissioning::Instance sEthernetNetworkCommissioningInstance(kRootEndpointId, &sEthernetDriver); #endif // CHIP_APP_MAIN_HAS_ETHERNET_DRIVER +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS +auto exampleAccessRestrictionProvider = std::make_unique(); +#endif + void EnableThreadNetworkCommissioning() { #if CHIP_APP_MAIN_HAS_THREAD_DRIVER @@ -593,9 +602,27 @@ void ChipLinuxAppMainLoop(AppMainLoopImplementation * impl) chip::app::RuntimeOptionsProvider::Instance().SetSimulateNoInternalTime( LinuxDeviceOptions::GetInstance().mSimulateNoInternalTime); +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + initParams.accessRestrictionProvider = exampleAccessRestrictionProvider.get(); +#endif + // Init ZCL Data Model and CHIP App Server Server::GetInstance().Init(initParams); +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + if (LinuxDeviceOptions::GetInstance().commissioningArlEntries.HasValue()) + { + exampleAccessRestrictionProvider->SetCommissioningEntries( + LinuxDeviceOptions::GetInstance().commissioningArlEntries.Value()); + } + + if (LinuxDeviceOptions::GetInstance().arlEntries.HasValue()) + { + // This example use of the ARL feature proactively installs the provided entries on fabric index 1 + exampleAccessRestrictionProvider->SetEntries(1, LinuxDeviceOptions::GetInstance().arlEntries.Value()); + } +#endif + #if CONFIG_BUILD_FOR_HOST_UNIT_TEST // Set ReadHandler Capacity for Subscriptions chip::app::InteractionModelEngine::GetInstance()->SetHandlerCapacityForSubscriptions( diff --git a/examples/platform/linux/BUILD.gn b/examples/platform/linux/BUILD.gn index 1fcee183f131b3..336534aae65043 100644 --- a/examples/platform/linux/BUILD.gn +++ b/examples/platform/linux/BUILD.gn @@ -19,6 +19,10 @@ import("${chip_root}/src/lib/core/core.gni") import("${chip_root}/src/lib/lib.gni") import("${chip_root}/src/tracing/tracing_args.gni") +if (current_os != "nuttx") { + import("//build_overrides/jsoncpp.gni") +} + declare_args() { chip_enable_smoke_co_trigger = false chip_enable_boolean_state_configuration_trigger = false @@ -101,6 +105,10 @@ source_set("app-main") { "${chip_root}/src/app/server", ] + if (current_os != "nuttx") { + public_deps += [ jsoncpp_root ] + } + if (chip_enable_pw_rpc) { defines += [ "PW_RPC_ENABLED" ] } diff --git a/examples/platform/linux/ExampleAccessRestrictionProvider.h b/examples/platform/linux/ExampleAccessRestrictionProvider.h new file mode 100644 index 00000000000000..731a8ae5845a21 --- /dev/null +++ b/examples/platform/linux/ExampleAccessRestrictionProvider.h @@ -0,0 +1,55 @@ +/* + * + * Copyright (c) 2024 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * AccessRestriction implementation for Linux examples. + */ + +#pragma once + +#include +#include +#include + +namespace chip { +namespace Access { + +class ExampleAccessRestrictionProvider : public AccessRestrictionProvider +{ +public: + ExampleAccessRestrictionProvider() : AccessRestrictionProvider() {} + + ~ExampleAccessRestrictionProvider() {} + +protected: + CHIP_ERROR DoRequestFabricRestrictionReview(const FabricIndex fabricIndex, uint64_t token, const std::vector & arl) + { + // this example simply removes all restrictions and will generate AccessRestrictionEntryChanged events + Access::GetAccessControl().GetAccessRestrictionProvider()->SetEntries(fabricIndex, std::vector{}); + + chip::app::Clusters::AccessControl::Events::FabricRestrictionReviewUpdate::Type event{ .token = token, + .fabricIndex = fabricIndex }; + EventNumber eventNumber; + ReturnErrorOnFailure(chip::app::LogEvent(event, kRootEndpointId, eventNumber)); + + return CHIP_NO_ERROR; + } +}; + +} // namespace Access +} // namespace chip diff --git a/examples/platform/linux/Options.cpp b/examples/platform/linux/Options.cpp index 9b83d126c1f495..6f8afea5bb496b 100644 --- a/examples/platform/linux/Options.cpp +++ b/examples/platform/linux/Options.cpp @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -47,6 +48,11 @@ using namespace chip; using namespace chip::ArgParser; +using namespace chip::Platform; + +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS +using namespace chip::Access; +#endif namespace { LinuxDeviceOptions gDeviceOptions; @@ -82,6 +88,10 @@ enum kDeviceOption_TraceFile, kDeviceOption_TraceLog, kDeviceOption_TraceDecode, +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + kDeviceOption_CommissioningArlEntries, + kDeviceOption_ArlEntries, +#endif kOptionCSRResponseCSRIncorrectType, kOptionCSRResponseCSRNonceIncorrectType, kOptionCSRResponseCSRNonceTooLong, @@ -154,6 +164,10 @@ OptionDef sDeviceOptionDefs[] = { { "trace_log", kArgumentRequired, kDeviceOption_TraceLog }, { "trace_decode", kArgumentRequired, kDeviceOption_TraceDecode }, #endif // CHIP_CONFIG_TRANSPORT_TRACE_ENABLED +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + { "commissioning-arl-entries", kArgumentRequired, kDeviceOption_CommissioningArlEntries }, + { "arl-entries", kArgumentRequired, kDeviceOption_ArlEntries }, +#endif // CHIP_CONFIG_USE_ACCESS_RESTRICTIONS { "cert_error_csr_incorrect_type", kNoArgument, kOptionCSRResponseCSRIncorrectType }, { "cert_error_csr_existing_keypair", kNoArgument, kOptionCSRResponseCSRExistingKeyPair }, { "cert_error_csr_nonce_incorrect_type", kNoArgument, kOptionCSRResponseCSRNonceIncorrectType }, @@ -280,6 +294,14 @@ const char * sDeviceOptionHelp = " --trace_decode <1/0>\n" " A value of 1 enables traces decoding, 0 disables this (default 0).\n" #endif // CHIP_CONFIG_TRANSPORT_TRACE_ENABLED +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + " --commissioning-arl-entries \n" + " Enable ACL cluster access restrictions used during commissioning with the provided JSON. Example:\n" + " \"[{\\\"endpoint\\\": 1,\\\"cluster\\\": 1105,\\\"restrictions\\\": [{\\\"type\\\": 0,\\\"id\\\": 0}]}]\"\n" + " --arl-entries \n" + " Enable ACL cluster access restrictions applied to fabric index 1 with the provided JSON. Example:\n" + " \"[{\\\"endpoint\\\": 1,\\\"cluster\\\": 1105,\\\"restrictions\\\": [{\\\"type\\\": 0,\\\"id\\\": 0}]}]\"\n" +#endif // CHIP_CONFIG_USE_ACCESS_RESTRICTIONS " --cert_error_csr_incorrect_type\n" " Configure the CSRResponse to be built with an invalid CSR type.\n" " --cert_error_csr_existing_keypair\n" @@ -320,6 +342,39 @@ const char * sDeviceOptionHelp = #endif "\n"; +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS +bool ParseAccessRestrictionEntriesFromJson(const char * jsonString, std::vector & entries) +{ + Json::Value root; + Json::Reader reader; + VerifyOrReturnValue(reader.parse(jsonString, root), false); + + for (Json::Value::const_iterator eIt = root.begin(); eIt != root.end(); eIt++) + { + AccessRestrictionProvider::Entry entry; + + entry.endpointNumber = static_cast((*eIt)["endpoint"].asUInt()); + entry.clusterId = static_cast((*eIt)["cluster"].asUInt()); + + Json::Value restrictions = (*eIt)["restrictions"]; + for (Json::Value::const_iterator rIt = restrictions.begin(); rIt != restrictions.end(); rIt++) + { + AccessRestrictionProvider::Restriction restriction; + restriction.restrictionType = static_cast((*rIt)["type"].asUInt()); + if ((*rIt).isMember("id")) + { + restriction.id.SetValue((*rIt)["id"].asUInt()); + } + entry.restrictions.push_back(restriction); + } + + entries.push_back(entry); + } + + return true; +} +#endif // CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + bool Base64ArgToVector(const char * arg, size_t maxSize, std::vector & outVector) { size_t maxBase64Size = BASE64_ENCODED_LEN(maxSize); @@ -529,6 +584,28 @@ bool HandleOption(const char * aProgram, OptionSet * aOptions, int aIdentifier, break; #endif // CHIP_CONFIG_TRANSPORT_TRACE_ENABLED +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + // TODO(#35189): change to use a path to JSON files instead + case kDeviceOption_CommissioningArlEntries: { + std::vector entries; + retval = ParseAccessRestrictionEntriesFromJson(aValue, entries); + if (retval) + { + LinuxDeviceOptions::GetInstance().commissioningArlEntries.SetValue(std::move(entries)); + } + } + break; + case kDeviceOption_ArlEntries: { + std::vector entries; + retval = ParseAccessRestrictionEntriesFromJson(aValue, entries); + if (retval) + { + LinuxDeviceOptions::GetInstance().arlEntries.SetValue(std::move(entries)); + } + } + break; +#endif // CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + case kOptionCSRResponseCSRIncorrectType: LinuxDeviceOptions::GetInstance().mCSRResponseOptions.csrIncorrectType = true; break; diff --git a/examples/platform/linux/Options.h b/examples/platform/linux/Options.h index f921bee4ced554..11a9061efcade8 100644 --- a/examples/platform/linux/Options.h +++ b/examples/platform/linux/Options.h @@ -28,6 +28,7 @@ #include #include +#include #include #include #include @@ -38,6 +39,10 @@ #include #include +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS +#include +#endif + struct LinuxDeviceOptions { chip::PayloadContents payload; @@ -81,6 +86,10 @@ struct LinuxDeviceOptions #if CONFIG_BUILD_FOR_HOST_UNIT_TEST int32_t subscriptionCapacity = CHIP_IM_MAX_NUM_SUBSCRIPTIONS; int32_t subscriptionResumptionRetryIntervalSec = -1; +#endif +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + chip::Optional> commissioningArlEntries; + chip::Optional> arlEntries; #endif static LinuxDeviceOptions & GetInstance(); }; diff --git a/scripts/tools/check_includes_config.py b/scripts/tools/check_includes_config.py index 2af375d7c4bdfd..b5195f4ab05eda 100644 --- a/scripts/tools/check_includes_config.py +++ b/scripts/tools/check_includes_config.py @@ -185,4 +185,5 @@ 'src/app/icd/client/DefaultICDStorageKey.h': {'vector'}, 'src/controller/CHIPDeviceController.cpp': {'string'}, 'src/qrcodetool/setup_payload_commands.cpp': {'string'}, + 'src/access/AccessRestrictionProvider.h': {'vector', 'map'}, } diff --git a/src/access/AccessConfig.h b/src/access/AccessConfig.h new file mode 100644 index 00000000000000..b9318a10d9e4d5 --- /dev/null +++ b/src/access/AccessConfig.h @@ -0,0 +1,22 @@ +/* + * + * Copyright (c) 2024 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#if CHIP_HAVE_CONFIG_H +#include +#endif diff --git a/src/access/AccessControl.cpp b/src/access/AccessControl.cpp index fcb5a43d975f8e..8302fb0b122265 100644 --- a/src/access/AccessControl.cpp +++ b/src/access/AccessControl.cpp @@ -181,7 +181,7 @@ char GetRequestTypeStringForLogging(RequestType requestType) return 'w'; case RequestType::kCommandInvokeRequest: return 'i'; - case RequestType::kEventReadOrSubscribeRequest: + case RequestType::kEventReadRequest: return 'e'; default: return '?'; @@ -325,7 +325,11 @@ void AccessControl::RemoveEntryListener(EntryListener & listener) bool AccessControl::IsAccessRestrictionListSupported() const { - return false; // not yet supported +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + return mAccessRestrictionProvider != nullptr; +#else + return false; +#endif } CHIP_ERROR AccessControl::Check(const SubjectDescriptor & subjectDescriptor, const RequestPath & requestPath, @@ -333,6 +337,21 @@ CHIP_ERROR AccessControl::Check(const SubjectDescriptor & subjectDescriptor, con { VerifyOrReturnError(IsInitialized(), CHIP_ERROR_INCORRECT_STATE); + CHIP_ERROR result = CheckACL(subjectDescriptor, requestPath, requestPrivilege); + +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + if (result == CHIP_NO_ERROR) + { + result = CheckARL(subjectDescriptor, requestPath, requestPrivilege); + } +#endif + + return result; +} + +CHIP_ERROR AccessControl::CheckACL(const SubjectDescriptor & subjectDescriptor, const RequestPath & requestPath, + Privilege requestPrivilege) +{ #if CHIP_PROGRESS_LOGGING && CHIP_CONFIG_ACCESS_CONTROL_POLICY_LOGGING_VERBOSITY > 1 { constexpr size_t kMaxCatsToLog = 6; @@ -347,11 +366,6 @@ CHIP_ERROR AccessControl::Check(const SubjectDescriptor & subjectDescriptor, con } #endif // CHIP_PROGRESS_LOGGING && CHIP_CONFIG_ACCESS_CONTROL_POLICY_LOGGING_VERBOSITY > 1 - if (IsAccessRestrictionListSupported()) - { - VerifyOrReturnError(requestPath.requestType != RequestType::kRequestTypeUnknown, CHIP_ERROR_INVALID_ARGUMENT); - } - { CHIP_ERROR result = mDelegate->Check(subjectDescriptor, requestPath, requestPrivilege); if (result != CHIP_ERROR_NOT_IMPLEMENTED) @@ -368,6 +382,7 @@ CHIP_ERROR AccessControl::Check(const SubjectDescriptor & subjectDescriptor, con (result == CHIP_ERROR_ACCESS_DENIED) ? "denied" : "error"); } #endif // CHIP_CONFIG_ACCESS_CONTROL_POLICY_LOGGING_VERBOSITY > 0 + return result; } } @@ -497,6 +512,45 @@ CHIP_ERROR AccessControl::Check(const SubjectDescriptor & subjectDescriptor, con return CHIP_ERROR_ACCESS_DENIED; } +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS +CHIP_ERROR AccessControl::CheckARL(const SubjectDescriptor & subjectDescriptor, const RequestPath & requestPath, + Privilege requestPrivilege) +{ + CHIP_ERROR result = CHIP_NO_ERROR; + + VerifyOrReturnError(requestPath.requestType != RequestType::kRequestTypeUnknown, CHIP_ERROR_INVALID_ARGUMENT); + + if (!IsAccessRestrictionListSupported()) + { + // Access Restriction support is compiled in, but not configured/enabled. Nothing to restrict. + return CHIP_NO_ERROR; + } + + if (subjectDescriptor.isCommissioning) + { + result = mAccessRestrictionProvider->CheckForCommissioning(subjectDescriptor, requestPath); + } + else + { + result = mAccessRestrictionProvider->Check(subjectDescriptor, requestPath); + } + + if (result != CHIP_NO_ERROR) + { + ChipLogProgress(DataManagement, "AccessControl: %s", +#if 0 + // TODO(#35177): new error code coming when access check plumbing are fixed in callers + (result == CHIP_ERROR_ACCESS_RESTRICTED_BY_ARL) ? "denied (restricted)" : "denied (restriction error)"); +#else + (result == CHIP_ERROR_ACCESS_DENIED) ? "denied (restricted)" : "denied (restriction error)"); +#endif + return result; + } + + return result; +} +#endif + #if CHIP_ACCESS_CONTROL_DUMP_ENABLED CHIP_ERROR AccessControl::Dump(const Entry & entry) { diff --git a/src/access/AccessControl.h b/src/access/AccessControl.h index a7c3472f5d99b4..df986864b8ff4c 100644 --- a/src/access/AccessControl.h +++ b/src/access/AccessControl.h @@ -18,6 +18,12 @@ #pragma once +#include + +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS +#include "AccessRestrictionProvider.h" +#endif + #include "Privilege.h" #include "RequestPath.h" #include "SubjectDescriptor.h" @@ -627,6 +633,16 @@ class AccessControl // Removes a listener from the listener list, if in the list. void RemoveEntryListener(EntryListener & listener); +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + // Set an optional AcceessRestriction object for MNGD feature. + void SetAccessRestrictionProvider(AccessRestrictionProvider * accessRestrictionProvider) + { + mAccessRestrictionProvider = accessRestrictionProvider; + } + + AccessRestrictionProvider * GetAccessRestrictionProvider() { return mAccessRestrictionProvider; } +#endif + /** * Check whether or not Access Restriction List is supported. * @@ -638,6 +654,8 @@ class AccessControl * Check whether access (by a subject descriptor, to a request path, * requiring a privilege) should be allowed or denied. * + * If an AccessRestrictionProvider object is set, it will be checked for additional access restrictions. + * * @retval #CHIP_ERROR_ACCESS_DENIED if denied. * @retval other errors should also be treated as denied. * @retval #CHIP_NO_ERROR if allowed. @@ -656,12 +674,29 @@ class AccessControl void NotifyEntryChanged(const SubjectDescriptor * subjectDescriptor, FabricIndex fabric, size_t index, const Entry * entry, EntryListener::ChangeType changeType); + /** + * Check ACL for whether access (by a subject descriptor, to a request path, + * requiring a privilege) should be allowed or denied. + */ + CHIP_ERROR CheckACL(const SubjectDescriptor & subjectDescriptor, const RequestPath & requestPath, Privilege requestPrivilege); + + /** + * Check CommissioningARL or ARL (as appropriate) for whether access (by a + * subject descriptor, to a request path, requiring a privilege) should + * be allowed or denied. + */ + CHIP_ERROR CheckARL(const SubjectDescriptor & subjectDescriptor, const RequestPath & requestPath, Privilege requestPrivilege); + private: Delegate * mDelegate = nullptr; DeviceTypeResolver * mDeviceTypeResolver = nullptr; EntryListener * mEntryListener = nullptr; + +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + AccessRestrictionProvider * mAccessRestrictionProvider; +#endif }; /** diff --git a/src/access/AccessRestrictionProvider.cpp b/src/access/AccessRestrictionProvider.cpp new file mode 100644 index 00000000000000..23e8082353abc8 --- /dev/null +++ b/src/access/AccessRestrictionProvider.cpp @@ -0,0 +1,249 @@ +/* + * + * Copyright (c) 2024 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "AccessRestrictionProvider.h" + +#include +#include + +using namespace chip::Platform; + +namespace chip { +namespace Access { + +void AccessRestrictionProvider::AddListener(Listener & listener) +{ + if (mListeners == nullptr) + { + mListeners = &listener; + listener.mNext = nullptr; + return; + } + + for (Listener * l = mListeners; /**/; l = l->mNext) + { + if (l == &listener) + { + return; + } + + if (l->mNext == nullptr) + { + l->mNext = &listener; + listener.mNext = nullptr; + return; + } + } +} + +void AccessRestrictionProvider::RemoveListener(Listener & listener) +{ + if (mListeners == &listener) + { + mListeners = listener.mNext; + listener.mNext = nullptr; + return; + } + + for (Listener * l = mListeners; l != nullptr; l = l->mNext) + { + if (l->mNext == &listener) + { + l->mNext = listener.mNext; + listener.mNext = nullptr; + return; + } + } +} + +CHIP_ERROR AccessRestrictionProvider::SetCommissioningEntries(const std::vector & entries) +{ + for (auto & entry : entries) + { + if (!mExceptionChecker.AreRestrictionsAllowed(entry.endpointNumber, entry.clusterId)) + { + ChipLogError(DataManagement, "AccessRestrictionProvider: invalid entry"); + return CHIP_ERROR_INVALID_ARGUMENT; + } + } + + mCommissioningEntries = entries; + + for (Listener * listener = mListeners; listener != nullptr; listener = listener->mNext) + { + listener->MarkCommissioningRestrictionListChanged(); + } + + return CHIP_NO_ERROR; +} + +CHIP_ERROR AccessRestrictionProvider::SetEntries(const FabricIndex fabricIndex, const std::vector & entries) +{ + std::vector updatedEntries; + + for (auto & entry : entries) + { + if (!mExceptionChecker.AreRestrictionsAllowed(entry.endpointNumber, entry.clusterId)) + { + ChipLogError(DataManagement, "AccessRestrictionProvider: invalid entry"); + return CHIP_ERROR_INVALID_ARGUMENT; + } + + Entry updatedEntry = entry; + updatedEntry.fabricIndex = fabricIndex; + updatedEntries.push_back(updatedEntry); + } + + mFabricEntries[fabricIndex] = std::move(updatedEntries); + + for (Listener * listener = mListeners; listener != nullptr; listener = listener->mNext) + { + listener->MarkRestrictionListChanged(fabricIndex); + } + + return CHIP_NO_ERROR; +} + +bool AccessRestrictionProvider::StandardAccessRestrictionExceptionChecker::AreRestrictionsAllowed(EndpointId endpoint, + ClusterId cluster) +{ + if (endpoint != kRootEndpointId && + (cluster == app::Clusters::WiFiNetworkManagement::Id || cluster == app::Clusters::ThreadBorderRouterManagement::Id || + cluster == app::Clusters::ThreadNetworkDirectory::Id)) + { + return true; + } + + return false; +} + +CHIP_ERROR AccessRestrictionProvider::CheckForCommissioning(const SubjectDescriptor & subjectDescriptor, + const RequestPath & requestPath) +{ + return DoCheck(mCommissioningEntries, subjectDescriptor, requestPath); +} + +CHIP_ERROR AccessRestrictionProvider::Check(const SubjectDescriptor & subjectDescriptor, const RequestPath & requestPath) +{ + return DoCheck(mFabricEntries[subjectDescriptor.fabricIndex], subjectDescriptor, requestPath); +} + +CHIP_ERROR AccessRestrictionProvider::DoCheck(const std::vector & entries, const SubjectDescriptor & subjectDescriptor, + const RequestPath & requestPath) +{ + if (!mExceptionChecker.AreRestrictionsAllowed(requestPath.endpoint, requestPath.cluster)) + { + ChipLogProgress(DataManagement, "AccessRestrictionProvider: skipping checks for unrestrictable request path"); + return CHIP_NO_ERROR; + } + + ChipLogProgress(DataManagement, "AccessRestrictionProvider: action %d", to_underlying(requestPath.requestType)); + + if (requestPath.requestType == RequestType::kRequestTypeUnknown) + { + ChipLogError(DataManagement, "AccessRestrictionProvider: RequestPath type is unknown"); + return CHIP_ERROR_INVALID_ARGUMENT; + } + + // wildcard event subscriptions are allowed since wildcard is only used when setting up the subscription and + // we want that request to succeed (when generating the report, this method will be called with the specific + // event id). All other requests require an entity id + if (!requestPath.entityId.has_value()) + { + if (requestPath.requestType == RequestType::kEventReadRequest) + { + return CHIP_NO_ERROR; + } + else + { + return CHIP_ERROR_INVALID_ARGUMENT; + } + } + + for (auto & entry : entries) + { + if (entry.endpointNumber != requestPath.endpoint || entry.clusterId != requestPath.cluster) + { + continue; + } + + for (auto & restriction : entry.restrictions) + { + // A missing id is a wildcard + bool idMatch = !restriction.id.HasValue() || restriction.id.Value() == requestPath.entityId.value(); + if (!idMatch) + { + continue; + } + + switch (restriction.restrictionType) + { + case Type::kAttributeAccessForbidden: + if (requestPath.requestType == RequestType::kAttributeReadRequest || + requestPath.requestType == RequestType::kAttributeWriteRequest) + { +#if 0 + // TODO(#35177): use new ARL error code when access checks are fixed + return CHIP_ERROR_ACCESS_RESTRICTED_BY_ARL; +#else + return CHIP_ERROR_ACCESS_DENIED; +#endif + } + break; + case Type::kAttributeWriteForbidden: + if (requestPath.requestType == RequestType::kAttributeWriteRequest) + { +#if 0 + // TODO(#35177): use new ARL error code when access checks are fixed + return CHIP_ERROR_ACCESS_RESTRICTED_BY_ARL; +#else + return CHIP_ERROR_ACCESS_DENIED; +#endif + } + break; + case Type::kCommandForbidden: + if (requestPath.requestType == RequestType::kCommandInvokeRequest) + { +#if 0 + // TODO(#35177): use new ARL error code when access checks are fixed + return CHIP_ERROR_ACCESS_RESTRICTED_BY_ARL; +#else + return CHIP_ERROR_ACCESS_DENIED; +#endif + } + break; + case Type::kEventForbidden: + if (requestPath.requestType == RequestType::kEventReadRequest) + { +#if 0 + // TODO(#35177): use new ARL error code when access checks are fixed + return CHIP_ERROR_ACCESS_RESTRICTED_BY_ARL; +#else + return CHIP_ERROR_ACCESS_DENIED; +#endif + } + break; + } + } + } + + return CHIP_NO_ERROR; +} + +} // namespace Access +} // namespace chip diff --git a/src/access/AccessRestrictionProvider.h b/src/access/AccessRestrictionProvider.h new file mode 100644 index 00000000000000..705d9c365f8f1f --- /dev/null +++ b/src/access/AccessRestrictionProvider.h @@ -0,0 +1,275 @@ +/* + * + * Copyright (c) 2024 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "Privilege.h" +#include "RequestPath.h" +#include "SubjectDescriptor.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace chip { +namespace Access { + +class AccessRestrictionProvider +{ +public: + static constexpr size_t kNumberOfFabrics = CHIP_CONFIG_MAX_FABRICS; + static constexpr size_t kEntriesPerFabric = CHIP_CONFIG_ACCESS_RESTRICTION_MAX_ENTRIES_PER_FABRIC; + static constexpr size_t kRestrictionsPerEntry = CHIP_CONFIG_ACCESS_RESTRICTION_MAX_RESTRICTIONS_PER_ENTRY; + + /** + * Defines the type of access restriction, which is used to determine the meaning of the restriction's id. + */ + enum class Type : uint8_t + { + kAttributeAccessForbidden = 0, + kAttributeWriteForbidden = 1, + kCommandForbidden = 2, + kEventForbidden = 3 + }; + + /** + * Defines a single restriction on an attribute, command, or event. + * + * If id is not set, the restriction applies to all attributes, commands, or events of the given type (wildcard). + */ + struct Restriction + { + Type restrictionType; + Optional id; + }; + + /** + * Defines a single entry in the access restriction list, which contains a list of restrictions + * for a cluster on an endpoint. + */ + struct Entry + { + FabricIndex fabricIndex; + EndpointId endpointNumber; + ClusterId clusterId; + std::vector restrictions; + }; + + /** + * Defines the interface for a checker for access restriction exceptions. + */ + class AccessRestrictionExceptionChecker + { + public: + virtual ~AccessRestrictionExceptionChecker() = default; + + /** + * Check if any restrictions are allowed to be applied on the given endpoint and cluster + * because of constraints against their use in ARLs. + * + * @retval true if ARL checks are allowed to be applied to the cluster on the endpoint, false otherwise + */ + virtual bool AreRestrictionsAllowed(EndpointId endpoint, ClusterId cluster) = 0; + }; + + /** + * Define a standard implementation of the AccessRestrictionExceptionChecker interface + * which is the default implementation used by AccessResrictionProvider. + */ + class StandardAccessRestrictionExceptionChecker : public AccessRestrictionExceptionChecker + { + public: + StandardAccessRestrictionExceptionChecker() = default; + ~StandardAccessRestrictionExceptionChecker() = default; + + bool AreRestrictionsAllowed(EndpointId endpoint, ClusterId cluster) override; + }; + + /** + * Used to notify of changes in the access restriction list and active reviews. + */ + class Listener + { + public: + virtual ~Listener() = default; + + /** + * Notifies of a change in the commissioning access restriction list. + */ + virtual void MarkCommissioningRestrictionListChanged() = 0; + + /** + * Notifies of a change in the access restriction list. + * + * @param [in] fabricIndex The index of the fabric in which the list has changed. + */ + virtual void MarkRestrictionListChanged(FabricIndex fabricIndex) = 0; + + /** + * Notifies of an update to an active review with instructions and an optional redirect URL. + * + * @param [in] fabricIndex The index of the fabric in which the entry has changed. + * @param [in] token The token of the review being updated (obtained from ReviewFabricRestrictionsResponse) + * @param [in] instruction Optional instructions to be displayed to the user. + * @param [in] redirectUrl An optional URL to redirect the user to for more information. + */ + virtual void OnFabricRestrictionReviewUpdate(FabricIndex fabricIndex, uint64_t token, Optional instruction, + Optional redirectUrl) = 0; + + private: + Listener * mNext = nullptr; + + friend class AccessRestrictionProvider; + }; + + AccessRestrictionProvider() = default; + virtual ~AccessRestrictionProvider() = default; + + AccessRestrictionProvider(const AccessRestrictionProvider &) = delete; + AccessRestrictionProvider & operator=(const AccessRestrictionProvider &) = delete; + + /** + * Set the restriction entries that are to be used during commissioning when there is no accessing fabric. + * + * @param [in] entries The entries to set. + */ + CHIP_ERROR SetCommissioningEntries(const std::vector & entries); + + /** + * Set the restriction entries for a fabric. + * + * @param [in] fabricIndex The index of the fabric for which to create entries. + * @param [in] entries The entries to set for the fabric. + */ + CHIP_ERROR SetEntries(const FabricIndex, const std::vector & entries); + + /** + * Add a listener to be notified of changes in the access restriction list and active reviews. + * + * @param [in] listener The listener to add. + */ + void AddListener(Listener & listener); + + /** + * Remove a listener from being notified of changes in the access restriction list and active reviews. + * + * @param [in] listener The listener to remove. + */ + void RemoveListener(Listener & listener); + + /** + * Check whether access by a subject descriptor to a request path should be restricted (denied) for the given action + * during commissioning by using the CommissioningEntries. + * + * These restrictions are are only a part of overall access evaluation. + * + * If access is not restricted, CHIP_NO_ERROR will be returned. + * + * @retval CHIP_ERROR_ACCESS_DENIED if access is denied. + * @retval other errors should also be treated as restricted/denied. + * @retval CHIP_NO_ERROR if access is not restricted/denied. + */ + CHIP_ERROR CheckForCommissioning(const SubjectDescriptor & subjectDescriptor, const RequestPath & requestPath); + + /** + * Check whether access by a subject descriptor to a request path should be restricted (denied) for the given action. + * These restrictions are are only a part of overall access evaluation. + * + * If access is not restricted, CHIP_NO_ERROR will be returned. + * + * @retval CHIP_ERROR_ACCESS_DENIED if access is denied. + * @retval other errors should also be treated as restricted/denied. + * @retval CHIP_NO_ERROR if access is not restricted/denied. + */ + CHIP_ERROR Check(const SubjectDescriptor & subjectDescriptor, const RequestPath & requestPath); + + /** + * Request a review of the access restrictions for a fabric. + * + * @param [in] fabricIndex The index of the fabric requesting a review. + * @param [in] arl An optinal list of access restriction entries to review. If null, all entries will be reviewed. + * @param [out] token The unique token for the review, which can be matched to a review update event. + */ + CHIP_ERROR RequestFabricRestrictionReview(FabricIndex fabricIndex, const std::vector & arl, uint64_t & token) + { + token = mNextToken++; + return DoRequestFabricRestrictionReview(fabricIndex, token, arl); + } + + /** + * Get the commissioning restriction entries. + * + * @retval the commissioning restriction entries. + */ + const std::vector & GetCommissioningEntries() const { return mCommissioningEntries; } + + /** + * Get the restriction entries for a fabric. + * + * @param [in] fabricIndex the index of the fabric for which to get entries. + * @param [out] entries vector to hold the entries. + */ + CHIP_ERROR GetEntries(const FabricIndex fabricIndex, std::vector & entries) const + { + auto it = mFabricEntries.find(fabricIndex); + if (it == mFabricEntries.end()) + { + return CHIP_ERROR_NOT_FOUND; + } + + entries = (it->second); + + return CHIP_NO_ERROR; + } + +protected: + /** + * Initiate a review of the access restrictions for a fabric. This method should be implemented by the platform and be + * non-blocking. + * + * @param [in] fabricIndex The index of the fabric requesting a review. + * @param [in] token The unique token for the review, which can be matched to a review update event. + * @param [in] arl An optinal list of access restriction entries to review. If null, all entries will be reviewed. + * @return CHIP_NO_ERROR if the review was successfully requested, or an error code if the request failed. + */ + virtual CHIP_ERROR DoRequestFabricRestrictionReview(const FabricIndex fabricIndex, uint64_t token, + const std::vector & arl) = 0; + +private: + /** + * Perform the access restriction check using the given entries. + */ + CHIP_ERROR DoCheck(const std::vector & entries, const SubjectDescriptor & subjectDescriptor, + const RequestPath & requestPath); + + uint64_t mNextToken = 1; + Listener * mListeners = nullptr; + StandardAccessRestrictionExceptionChecker mExceptionChecker; + std::vector mCommissioningEntries; + std::map> mFabricEntries; +}; + +} // namespace Access +} // namespace chip diff --git a/src/access/BUILD.gn b/src/access/BUILD.gn index 8d2c4b504975ee..3f3aaafe8a6a9b 100644 --- a/src/access/BUILD.gn +++ b/src/access/BUILD.gn @@ -13,6 +13,25 @@ # limitations under the License. import("//build_overrides/chip.gni") +import("${chip_root}/build/chip/buildconfig_header.gni") +import("${chip_root}/src/access/access.gni") + +buildconfig_header("access_buildconfig") { + header = "AccessBuildConfig.h" + header_dir = "access" + + defines = [ + "CHIP_CONFIG_USE_ACCESS_RESTRICTIONS=${chip_enable_access_restrictions}", + ] + + visibility = [ ":access_config" ] +} + +source_set("access_config") { + sources = [ "AccessConfig.h" ] + + deps = [ ":access_buildconfig" ] +} source_set("types") { sources = [ @@ -23,6 +42,7 @@ source_set("types") { ] public_deps = [ + ":access_config", "${chip_root}/src/lib/core", "${chip_root}/src/lib/core:types", ] @@ -43,10 +63,18 @@ static_library("access") { cflags = [ "-Wconversion" ] public_deps = [ + ":access_config", ":types", "${chip_root}/src/lib/core", "${chip_root}/src/lib/core:types", "${chip_root}/src/lib/support", "${chip_root}/src/platform", ] + + if (chip_enable_access_restrictions) { + sources += [ + "AccessRestrictionProvider.cpp", + "AccessRestrictionProvider.h", + ] + } } diff --git a/src/access/RequestPath.h b/src/access/RequestPath.h index af791d73eb151d..920d5fed372eb1 100644 --- a/src/access/RequestPath.h +++ b/src/access/RequestPath.h @@ -30,7 +30,7 @@ enum class RequestType : uint8_t kAttributeReadRequest, kAttributeWriteRequest, kCommandInvokeRequest, - kEventReadOrSubscribeRequest + kEventReadRequest }; struct RequestPath diff --git a/src/access/SubjectDescriptor.h b/src/access/SubjectDescriptor.h index ec6abec0b38a30..9cde4102750d25 100644 --- a/src/access/SubjectDescriptor.h +++ b/src/access/SubjectDescriptor.h @@ -42,6 +42,10 @@ struct SubjectDescriptor // CASE Authenticated Tags (CATs) only valid if auth mode is CASE. CATValues cats; + + // Whether the subject is currently a pending commissionee. See `IsCommissioning` + // definition in Core Specification's ACL Architecture pseudocode. + bool isCommissioning = false; }; } // namespace Access diff --git a/src/access/access.gni b/src/access/access.gni new file mode 100644 index 00000000000000..bd18f1387b66b9 --- /dev/null +++ b/src/access/access.gni @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Project CHIP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +declare_args() { + # Enable ARL features of Access Control + chip_enable_access_restrictions = false +} diff --git a/src/access/tests/BUILD.gn b/src/access/tests/BUILD.gn index d8b43e6a17a01d..64226adfacc260 100644 --- a/src/access/tests/BUILD.gn +++ b/src/access/tests/BUILD.gn @@ -15,6 +15,7 @@ import("//build_overrides/build.gni") import("//build_overrides/chip.gni") import("//build_overrides/pigweed.gni") +import("${chip_root}/src/access/access.gni") import("${chip_root}/build/chip/chip_test_suite.gni") @@ -29,4 +30,8 @@ chip_test_suite("tests") { "${chip_root}/src/lib/support:test_utils", "${dir_pw_unit_test}", ] + + if (chip_enable_access_restrictions) { + test_sources += [ "TestAccessRestrictionProvider.cpp" ] + } } diff --git a/src/access/tests/TestAccessRestrictionProvider.cpp b/src/access/tests/TestAccessRestrictionProvider.cpp new file mode 100644 index 00000000000000..ddf58ae2488f4d --- /dev/null +++ b/src/access/tests/TestAccessRestrictionProvider.cpp @@ -0,0 +1,722 @@ +/* + * + * Copyright (c) 2024 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "access/AccessControl.h" +#include "access/AccessRestrictionProvider.h" +#include "access/examples/ExampleAccessControlDelegate.h" + +#include + +#include +#include +#include +namespace chip { +namespace Access { + +class TestAccessRestrictionProvider : public AccessRestrictionProvider +{ + CHIP_ERROR DoRequestFabricRestrictionReview(const FabricIndex fabricIndex, uint64_t token, const std::vector & arl) + { + return CHIP_NO_ERROR; + } +}; + +AccessControl accessControl; +TestAccessRestrictionProvider accessRestrictionProvider; + +constexpr ClusterId kNetworkCommissioningCluster = app::Clusters::NetworkCommissioning::Id; +constexpr ClusterId kDescriptorCluster = app::Clusters::Descriptor::Id; +constexpr ClusterId kOnOffCluster = app::Clusters::OnOff::Id; + +// Clusters allowed to have restrictions +constexpr ClusterId kWiFiNetworkManagementCluster = app::Clusters::WiFiNetworkManagement::Id; +constexpr ClusterId kThreadBorderRouterMgmtCluster = app::Clusters::ThreadBorderRouterManagement::Id; +constexpr ClusterId kThreadNetworkDirectoryCluster = app::Clusters::ThreadNetworkDirectory::Id; + +constexpr NodeId kOperationalNodeId1 = 0x1111111111111111; +constexpr NodeId kOperationalNodeId2 = 0x2222222222222222; +constexpr NodeId kOperationalNodeId3 = 0x3333333333333333; + +bool operator==(const AccessRestrictionProvider::Restriction & lhs, const AccessRestrictionProvider::Restriction & rhs) +{ + return lhs.restrictionType == rhs.restrictionType && lhs.id == rhs.id; +} + +bool operator==(const AccessRestrictionProvider::Entry & lhs, const AccessRestrictionProvider::Entry & rhs) +{ + return lhs.fabricIndex == rhs.fabricIndex && lhs.endpointNumber == rhs.endpointNumber && lhs.clusterId == rhs.clusterId && + lhs.restrictions == rhs.restrictions; +} + +struct AclEntryData +{ + FabricIndex fabricIndex = kUndefinedFabricIndex; + Privilege privilege = Privilege::kView; + AuthMode authMode = AuthMode::kNone; + NodeId subject; +}; + +constexpr AclEntryData aclEntryData[] = { + { + .fabricIndex = 1, + .privilege = Privilege::kAdminister, + .authMode = AuthMode::kCase, + .subject = kOperationalNodeId1, + }, + { + .fabricIndex = 2, + .privilege = Privilege::kAdminister, + .authMode = AuthMode::kCase, + .subject = kOperationalNodeId2, + }, +}; +constexpr size_t aclEntryDataCount = ArraySize(aclEntryData); + +struct CheckData +{ + SubjectDescriptor subjectDescriptor; + RequestPath requestPath; + Privilege privilege; + bool allow; +}; + +constexpr CheckData checkDataNoRestrictions[] = { + // Checks for implicit PASE + { .subjectDescriptor = { .fabricIndex = 0, .authMode = AuthMode::kPase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = 1, .endpoint = 1, .requestType = RequestType::kAttributeReadRequest, .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 0, .authMode = AuthMode::kPase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = 1, .endpoint = 1, .requestType = RequestType::kAttributeWriteRequest, .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kPase, .subject = kOperationalNodeId2 }, + .requestPath = { .cluster = 1, .endpoint = 1, .requestType = RequestType::kCommandInvokeRequest, .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 2, .authMode = AuthMode::kPase, .subject = kOperationalNodeId3 }, + .requestPath = { .cluster = 1, .endpoint = 1, .requestType = RequestType::kEventReadRequest, .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + // Checks for entry 0 + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = 1, .endpoint = 1, .requestType = RequestType::kAttributeReadRequest, .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = 1, .endpoint = 1, .requestType = RequestType::kAttributeWriteRequest, .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = 1, .endpoint = 1, .requestType = RequestType::kCommandInvokeRequest, .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = 1, .endpoint = 1, .requestType = RequestType::kEventReadRequest, .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + // Checks for entry 1 + { .subjectDescriptor = { .fabricIndex = 2, .authMode = AuthMode::kCase, .subject = kOperationalNodeId2 }, + .requestPath = { .cluster = 1, .endpoint = 1, .requestType = RequestType::kAttributeReadRequest, .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 2, .authMode = AuthMode::kCase, .subject = kOperationalNodeId2 }, + .requestPath = { .cluster = 1, .endpoint = 1, .requestType = RequestType::kAttributeWriteRequest, .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 2, .authMode = AuthMode::kCase, .subject = kOperationalNodeId2 }, + .requestPath = { .cluster = 1, .endpoint = 1, .requestType = RequestType::kCommandInvokeRequest, .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 2, .authMode = AuthMode::kCase, .subject = kOperationalNodeId2 }, + .requestPath = { .cluster = 1, .endpoint = 1, .requestType = RequestType::kEventReadRequest, .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, +}; + +CHIP_ERROR LoadEntry(AccessControl::Entry & entry, const AclEntryData & entryData) +{ + ReturnErrorOnFailure(entry.SetAuthMode(entryData.authMode)); + ReturnErrorOnFailure(entry.SetFabricIndex(entryData.fabricIndex)); + ReturnErrorOnFailure(entry.SetPrivilege(entryData.privilege)); + ReturnErrorOnFailure(entry.AddSubject(nullptr, entryData.subject)); + return CHIP_NO_ERROR; +} + +CHIP_ERROR LoadAccessControl(AccessControl & ac, const AclEntryData * entryData, size_t count) +{ + AccessControl::Entry entry; + for (size_t i = 0; i < count; ++i, ++entryData) + { + ReturnErrorOnFailure(ac.PrepareEntry(entry)); + ReturnErrorOnFailure(LoadEntry(entry, *entryData)); + ReturnErrorOnFailure(ac.CreateEntry(nullptr, entry)); + } + return CHIP_NO_ERROR; +} + +void RunChecks(const CheckData * checkData, size_t count) +{ + for (size_t i = 0; i < count; i++) + { + CHIP_ERROR expectedResult = checkData[i].allow ? CHIP_NO_ERROR : CHIP_ERROR_ACCESS_DENIED; + EXPECT_EQ(accessControl.Check(checkData[i].subjectDescriptor, checkData[i].requestPath, checkData[i].privilege), + expectedResult); + } +} + +class DeviceTypeResolver : public AccessControl::DeviceTypeResolver +{ +public: + bool IsDeviceTypeOnEndpoint(DeviceTypeId deviceType, EndpointId endpoint) override { return false; } +} testDeviceTypeResolver; + +class TestAccessRestriction : public ::testing::Test +{ +public: // protected + void SetUp() override + { + accessRestrictionProvider.SetCommissioningEntries(std::vector()); + accessRestrictionProvider.SetEntries(0, std::vector()); + accessRestrictionProvider.SetEntries(1, std::vector()); + accessRestrictionProvider.SetEntries(2, std::vector()); + } + + static void SetUpTestSuite() + { + ASSERT_EQ(chip::Platform::MemoryInit(), CHIP_NO_ERROR); + AccessControl::Delegate * delegate = Examples::GetAccessControlDelegate(); + SetAccessControl(accessControl); + GetAccessControl().SetAccessRestrictionProvider(&accessRestrictionProvider); + VerifyOrDie(GetAccessControl().Init(delegate, testDeviceTypeResolver) == CHIP_NO_ERROR); + EXPECT_EQ(LoadAccessControl(accessControl, aclEntryData, aclEntryDataCount), CHIP_NO_ERROR); + } + static void TearDownTestSuite() + { + GetAccessControl().Finish(); + ResetAccessControlToDefault(); + } +}; + +// basic data check without restrictions +TEST_F(TestAccessRestriction, MetaTest) +{ + for (const auto & checkData : checkDataNoRestrictions) + { + CHIP_ERROR expectedResult = checkData.allow ? CHIP_NO_ERROR : CHIP_ERROR_ACCESS_DENIED; + EXPECT_EQ(accessControl.Check(checkData.subjectDescriptor, checkData.requestPath, checkData.privilege), expectedResult); + } +} + +// ensure failure when adding restrictons on endpoint 0 (any cluster, including those allowed on other endpoints) +TEST_F(TestAccessRestriction, InvalidRestrictionsOnEndpointZeroTest) +{ + std::vector entries; + AccessRestrictionProvider::Entry entry; + entry.endpointNumber = 0; + entry.fabricIndex = 1; + entry.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kAttributeAccessForbidden }); + + entry.clusterId = kDescriptorCluster; + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_ERROR_INVALID_ARGUMENT); + + entries.clear(); + entry.clusterId = kNetworkCommissioningCluster; + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_ERROR_INVALID_ARGUMENT); + + entries.clear(); + entry.clusterId = kWiFiNetworkManagementCluster; + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_ERROR_INVALID_ARGUMENT); + + entries.clear(); + entry.clusterId = kThreadBorderRouterMgmtCluster; + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_ERROR_INVALID_ARGUMENT); + + entries.clear(); + entry.clusterId = kThreadNetworkDirectoryCluster; + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_ERROR_INVALID_ARGUMENT); + + // also test a cluster on endpoint 0 that isnt in the special allowed list + entries.clear(); + entry.clusterId = kOnOffCluster; + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_ERROR_INVALID_ARGUMENT); +} + +// ensure no failure adding restrictions on endpoint 1 for allowed clusters only: +// wifi network management, thread border router, thread network directory +TEST_F(TestAccessRestriction, ValidRestrictionsOnEndpointOneTest) +{ + std::vector entries; + AccessRestrictionProvider::Entry entry; + entry.endpointNumber = 1; + entry.fabricIndex = 1; + entry.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kAttributeAccessForbidden }); + + entry.clusterId = kWiFiNetworkManagementCluster; + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_NO_ERROR); + + entries.clear(); + entry.clusterId = kThreadBorderRouterMgmtCluster; + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_NO_ERROR); + + entries.clear(); + entry.clusterId = kThreadNetworkDirectoryCluster; + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_NO_ERROR); + + // also test a cluster on endpoint 1 that isnt in the special allowed list + entries.clear(); + entry.clusterId = kOnOffCluster; + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_ERROR_INVALID_ARGUMENT); +} + +TEST_F(TestAccessRestriction, InvalidRestrictionsOnEndpointOneTest) +{ + std::vector entries; + AccessRestrictionProvider::Entry entry; + entry.endpointNumber = 1; + entry.fabricIndex = 1; + entry.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kAttributeAccessForbidden }); + entry.clusterId = kOnOffCluster; + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_ERROR_INVALID_ARGUMENT); +} + +constexpr CheckData accessAttributeRestrictionTestData[] = { + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeWriteRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = false }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeReadRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = false }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kCommandInvokeRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kEventReadRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, +}; + +TEST_F(TestAccessRestriction, AccessAttributeRestrictionTest) +{ + std::vector entries; + AccessRestrictionProvider::Entry entry; + entry.fabricIndex = 1; + entry.endpointNumber = 1; + entry.clusterId = kWiFiNetworkManagementCluster; + entry.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kAttributeAccessForbidden }); + + // test wildcarded entity id + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_NO_ERROR); + RunChecks(accessAttributeRestrictionTestData, ArraySize(accessAttributeRestrictionTestData)); + + // test specific entity id + entries.clear(); + entry.restrictions[0].id.SetValue(1); + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_NO_ERROR); + RunChecks(accessAttributeRestrictionTestData, ArraySize(accessAttributeRestrictionTestData)); +} + +constexpr CheckData writeAttributeRestrictionTestData[] = { + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeWriteRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = false }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeReadRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kCommandInvokeRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kEventReadRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, +}; + +TEST_F(TestAccessRestriction, WriteAttributeRestrictionTest) +{ + std::vector entries; + AccessRestrictionProvider::Entry entry; + entry.fabricIndex = 1; + entry.endpointNumber = 1; + entry.clusterId = kWiFiNetworkManagementCluster; + entry.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kAttributeWriteForbidden }); + + // test wildcarded entity id + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_NO_ERROR); + RunChecks(writeAttributeRestrictionTestData, ArraySize(writeAttributeRestrictionTestData)); + + // test specific entity id + entries.clear(); + entry.restrictions[0].id.SetValue(1); + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_NO_ERROR); + RunChecks(writeAttributeRestrictionTestData, ArraySize(writeAttributeRestrictionTestData)); +} + +constexpr CheckData commandAttributeRestrictionTestData[] = { + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeWriteRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeReadRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kCommandInvokeRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = false }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kEventReadRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, +}; + +TEST_F(TestAccessRestriction, CommandRestrictionTest) +{ + std::vector entries; + AccessRestrictionProvider::Entry entry; + entry.fabricIndex = 1; + entry.endpointNumber = 1; + entry.clusterId = kWiFiNetworkManagementCluster; + entry.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kCommandForbidden }); + + // test wildcarded entity id + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_NO_ERROR); + RunChecks(commandAttributeRestrictionTestData, ArraySize(commandAttributeRestrictionTestData)); + + // test specific entity id + entries.clear(); + entry.restrictions[0].id.SetValue(1); + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_NO_ERROR); + RunChecks(commandAttributeRestrictionTestData, ArraySize(commandAttributeRestrictionTestData)); +} + +constexpr CheckData eventAttributeRestrictionTestData[] = { + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeWriteRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeReadRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kCommandInvokeRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kEventReadRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = false }, +}; + +TEST_F(TestAccessRestriction, EventRestrictionTest) +{ + std::vector entries; + AccessRestrictionProvider::Entry entry; + entry.fabricIndex = 1; + entry.endpointNumber = 1; + entry.clusterId = kWiFiNetworkManagementCluster; + entry.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kEventForbidden }); + + // test wildcarded entity id + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_NO_ERROR); + RunChecks(eventAttributeRestrictionTestData, ArraySize(eventAttributeRestrictionTestData)); + + // test specific entity id + entries.clear(); + entry.restrictions[0].id.SetValue(1); + entries.push_back(entry); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_NO_ERROR); + RunChecks(eventAttributeRestrictionTestData, ArraySize(eventAttributeRestrictionTestData)); +} + +constexpr CheckData combinedRestrictionTestData[] = { + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeWriteRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = false }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeReadRequest, + .entityId = 2 }, + .privilege = Privilege::kAdminister, + .allow = false }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeWriteRequest, + .entityId = 3 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeReadRequest, + .entityId = 4 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeReadRequest, + .entityId = 3 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kCommandInvokeRequest, + .entityId = 4 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, .authMode = AuthMode::kCase, .subject = kOperationalNodeId1 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kEventReadRequest, + .entityId = 5 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 2, .authMode = AuthMode::kCase, .subject = kOperationalNodeId2 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kCommandInvokeRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = false }, + { .subjectDescriptor = { .fabricIndex = 2, .authMode = AuthMode::kCase, .subject = kOperationalNodeId2 }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeWriteRequest, + .entityId = 2 }, + .privilege = Privilege::kAdminister, + .allow = true }, +}; + +TEST_F(TestAccessRestriction, CombinedRestrictionTest) +{ + // a restriction for all access to attribute 1 and 2, attributes 3 and 4 are allowed + std::vector entries1; + AccessRestrictionProvider::Entry entry1; + entry1.fabricIndex = 1; + entry1.endpointNumber = 1; + entry1.clusterId = kWiFiNetworkManagementCluster; + entry1.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kAttributeWriteForbidden }); + entry1.restrictions[0].id.SetValue(1); + entry1.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kAttributeAccessForbidden }); + entry1.restrictions[1].id.SetValue(2); + entries1.push_back(entry1); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries1), CHIP_NO_ERROR); + + // a restriction for fabric 2 that forbids command 1 and 2. Check that command 1 is blocked on invoke, but attribute 2 write is + // allowed + std::vector entries2; + AccessRestrictionProvider::Entry entry2; + entry2.fabricIndex = 2; + entry2.endpointNumber = 1; + entry2.clusterId = kWiFiNetworkManagementCluster; + entry2.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kCommandForbidden }); + entry2.restrictions[0].id.SetValue(1); + entry2.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kCommandForbidden }); + entry2.restrictions[1].id.SetValue(2); + entries2.push_back(entry2); + EXPECT_EQ(accessRestrictionProvider.SetEntries(2, entries2), CHIP_NO_ERROR); + + RunChecks(combinedRestrictionTestData, ArraySize(combinedRestrictionTestData)); +} + +TEST_F(TestAccessRestriction, AttributeStorageSeperationTest) +{ + std::vector commissioningEntries; + AccessRestrictionProvider::Entry entry1; + entry1.fabricIndex = 1; + entry1.endpointNumber = 1; + entry1.clusterId = kWiFiNetworkManagementCluster; + entry1.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kAttributeWriteForbidden }); + entry1.restrictions[0].id.SetValue(1); + commissioningEntries.push_back(entry1); + EXPECT_EQ(accessRestrictionProvider.SetCommissioningEntries(commissioningEntries), CHIP_NO_ERROR); + + std::vector entries; + AccessRestrictionProvider::Entry entry2; + entry2.fabricIndex = 2; + entry2.endpointNumber = 2; + entry2.clusterId = kThreadBorderRouterMgmtCluster; + entry2.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kCommandForbidden }); + entry2.restrictions[0].id.SetValue(2); + entries.push_back(entry2); + EXPECT_EQ(accessRestrictionProvider.SetEntries(2, entries), CHIP_NO_ERROR); + + auto commissioningEntriesFetched = accessRestrictionProvider.GetCommissioningEntries(); + std::vector arlEntriesFetched; + EXPECT_EQ(accessRestrictionProvider.GetEntries(2, arlEntriesFetched), CHIP_NO_ERROR); + EXPECT_EQ(commissioningEntriesFetched[0], entry1); + EXPECT_EQ(commissioningEntriesFetched.size(), static_cast(1)); + EXPECT_EQ(arlEntriesFetched[0], entry2); + EXPECT_EQ(arlEntriesFetched.size(), static_cast(1)); + EXPECT_FALSE(commissioningEntriesFetched[0] == arlEntriesFetched[0]); +} + +constexpr CheckData listSelectionDuringCommissioningData[] = { + { .subjectDescriptor = { .fabricIndex = 1, + .authMode = AuthMode::kCase, + .subject = kOperationalNodeId1, + .isCommissioning = true }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeReadRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, + { .subjectDescriptor = { .fabricIndex = 1, + .authMode = AuthMode::kCase, + .subject = kOperationalNodeId1, + .isCommissioning = true }, + .requestPath = { .cluster = kThreadBorderRouterMgmtCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeReadRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = false }, + { .subjectDescriptor = { .fabricIndex = 1, + .authMode = AuthMode::kCase, + .subject = kOperationalNodeId1, + .isCommissioning = false }, + .requestPath = { .cluster = kWiFiNetworkManagementCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeReadRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = false }, + { .subjectDescriptor = { .fabricIndex = 1, + .authMode = AuthMode::kCase, + .subject = kOperationalNodeId1, + .isCommissioning = false }, + .requestPath = { .cluster = kThreadBorderRouterMgmtCluster, + .endpoint = 1, + .requestType = RequestType::kAttributeReadRequest, + .entityId = 1 }, + .privilege = Privilege::kAdminister, + .allow = true }, +}; + +TEST_F(TestAccessRestriction, ListSelectiondDuringCommissioningTest) +{ + // during commissioning, read is allowed on WifiNetworkManagement and disallowed on ThreadBorderRouterMgmt + // after commissioning, read is disallowed on WifiNetworkManagement and allowed on ThreadBorderRouterMgmt + + std::vector entries; + AccessRestrictionProvider::Entry entry1; + entry1.fabricIndex = 1; + entry1.endpointNumber = 1; + entry1.clusterId = kThreadBorderRouterMgmtCluster; + entry1.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kAttributeAccessForbidden }); + entry1.restrictions[0].id.SetValue(1); + entries.push_back(entry1); + EXPECT_EQ(accessRestrictionProvider.SetCommissioningEntries(entries), CHIP_NO_ERROR); + + entries.clear(); + AccessRestrictionProvider::Entry entry2; + entry2.fabricIndex = 1; + entry2.endpointNumber = 1; + entry2.clusterId = kWiFiNetworkManagementCluster; + entry2.restrictions.push_back({ .restrictionType = AccessRestrictionProvider::Type::kAttributeAccessForbidden }); + entry2.restrictions[0].id.SetValue(1); + entries.push_back(entry2); + EXPECT_EQ(accessRestrictionProvider.SetEntries(1, entries), CHIP_NO_ERROR); + + RunChecks(listSelectionDuringCommissioningData, ArraySize(listSelectionDuringCommissioningData)); +} + +} // namespace Access +} // namespace chip diff --git a/src/app/EventManagement.cpp b/src/app/EventManagement.cpp index 4c5faab3cdcda4..2419d564d0bc7a 100644 --- a/src/app/EventManagement.cpp +++ b/src/app/EventManagement.cpp @@ -556,7 +556,7 @@ CHIP_ERROR EventManagement::CheckEventContext(EventLoadOutContext * eventLoadOut Access::RequestPath requestPath{ .cluster = event.mClusterId, .endpoint = event.mEndpointId, - .requestType = Access::RequestType::kEventReadOrSubscribeRequest, + .requestType = Access::RequestType::kEventReadRequest, .entityId = event.mEventId }; Access::Privilege requestPrivilege = RequiredPrivilege::ForReadEvent(path); CHIP_ERROR accessControlError = diff --git a/src/app/InteractionModelEngine.cpp b/src/app/InteractionModelEngine.cpp index 64d30bc6c0e42f..88c1777b6c8224 100644 --- a/src/app/InteractionModelEngine.cpp +++ b/src/app/InteractionModelEngine.cpp @@ -544,7 +544,7 @@ static bool CanAccessEvent(const Access::SubjectDescriptor & aSubjectDescriptor, { Access::RequestPath requestPath{ .cluster = aPath.mClusterId, .endpoint = aPath.mEndpointId, - .requestType = Access::RequestType::kEventReadOrSubscribeRequest }; + .requestType = Access::RequestType::kEventReadRequest }; // leave requestPath.entityId optional value unset to indicate wildcard CHIP_ERROR err = Access::GetAccessControl().Check(aSubjectDescriptor, requestPath, aNeededPrivilege); return (err == CHIP_NO_ERROR); @@ -555,7 +555,7 @@ static bool CanAccessEvent(const Access::SubjectDescriptor & aSubjectDescriptor, { Access::RequestPath requestPath{ .cluster = aPath.mClusterId, .endpoint = aPath.mEndpointId, - .requestType = Access::RequestType::kEventReadOrSubscribeRequest, + .requestType = Access::RequestType::kEventReadRequest, .entityId = aPath.mEventId }; CHIP_ERROR err = Access::GetAccessControl().Check(aSubjectDescriptor, requestPath, RequiredPrivilege::ForReadEvent(aPath)); return (err == CHIP_NO_ERROR); diff --git a/src/app/chip_data_model.gni b/src/app/chip_data_model.gni index 01d47a47cb7184..a1d9a811a1b1e7 100644 --- a/src/app/chip_data_model.gni +++ b/src/app/chip_data_model.gni @@ -436,6 +436,12 @@ template("chip_data_model") { "${_app_root}/clusters/${cluster}/PresetStructWithOwnedMembers.h", "${_app_root}/clusters/${cluster}/thermostat-delegate.h", ] + } else if (cluster == "access-control-server") { + sources += [ + "${_app_root}/clusters/${cluster}/${cluster}.cpp", + "${_app_root}/clusters/${cluster}/ArlEncoder.cpp", + "${_app_root}/clusters/${cluster}/ArlEncoder.h", + ] } else { sources += [ "${_app_root}/clusters/${cluster}/${cluster}.cpp" ] } diff --git a/src/app/clusters/access-control-server/ArlEncoder.cpp b/src/app/clusters/access-control-server/ArlEncoder.cpp new file mode 100644 index 00000000000000..810c7345a13c7e --- /dev/null +++ b/src/app/clusters/access-control-server/ArlEncoder.cpp @@ -0,0 +1,145 @@ +/* + * + * Copyright (c) 2024 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ArlEncoder.h" + +using namespace chip; +using namespace chip::app; +using namespace chip::Access; + +using Entry = AccessRestrictionProvider::Entry; +using EntryListener = AccessRestrictionProvider::Listener; +using StagingRestrictionType = Clusters::AccessControl::AccessRestrictionTypeEnum; +using StagingRestriction = Clusters::AccessControl::Structs::AccessRestrictionStruct::Type; + +namespace { + +CHIP_ERROR StageEntryRestrictions(const std::vector & source, + StagingRestriction destination[], size_t destinationCount) +{ + size_t count = source.size(); + if (count > 0 && count <= destinationCount) + { + for (size_t i = 0; i < count; i++) + { + const auto & restriction = source[i]; + ReturnErrorOnFailure(ArlEncoder::Convert(restriction.restrictionType, destination[i].type)); + + if (restriction.id.HasValue()) + { + destination[i].id.SetNonNull(restriction.id.Value()); + } + } + } + else + { + return CHIP_ERROR_INVALID_ARGUMENT; + } + + return CHIP_NO_ERROR; +} + +} // namespace + +namespace chip { +namespace app { + +CHIP_ERROR ArlEncoder::Convert(Clusters::AccessControl::AccessRestrictionTypeEnum from, + Access::AccessRestrictionProvider::Type & to) +{ + switch (from) + { + case StagingRestrictionType::kAttributeAccessForbidden: + to = AccessRestrictionProvider::Type::kAttributeAccessForbidden; + break; + case StagingRestrictionType::kAttributeWriteForbidden: + to = AccessRestrictionProvider::Type::kAttributeWriteForbidden; + break; + case StagingRestrictionType::kCommandForbidden: + to = AccessRestrictionProvider::Type::kCommandForbidden; + break; + case StagingRestrictionType::kEventForbidden: + to = AccessRestrictionProvider::Type::kEventForbidden; + break; + default: + return CHIP_ERROR_INVALID_ARGUMENT; + } + return CHIP_NO_ERROR; +} + +CHIP_ERROR ArlEncoder::Convert(Access::AccessRestrictionProvider::Type from, + Clusters::AccessControl::AccessRestrictionTypeEnum & to) +{ + switch (from) + { + case AccessRestrictionProvider::Type::kAttributeAccessForbidden: + to = StagingRestrictionType::kAttributeAccessForbidden; + break; + case AccessRestrictionProvider::Type::kAttributeWriteForbidden: + to = StagingRestrictionType::kAttributeWriteForbidden; + break; + case AccessRestrictionProvider::Type::kCommandForbidden: + to = StagingRestrictionType::kCommandForbidden; + break; + case AccessRestrictionProvider::Type::kEventForbidden: + to = StagingRestrictionType::kEventForbidden; + break; + default: + return CHIP_ERROR_INVALID_ARGUMENT; + } + return CHIP_NO_ERROR; +} + +CHIP_ERROR ArlEncoder::CommissioningEncodableEntry::Encode(TLV::TLVWriter & writer, TLV::Tag tag) const +{ + ReturnErrorOnFailure(Stage()); + ReturnErrorOnFailure(mStagingEntry.Encode(writer, tag)); + return CHIP_NO_ERROR; +} + +CHIP_ERROR ArlEncoder::EncodableEntry::EncodeForRead(TLV::TLVWriter & writer, TLV::Tag tag, FabricIndex fabric) const +{ + ReturnErrorOnFailure(Stage()); + ReturnErrorOnFailure(mStagingEntry.EncodeForRead(writer, tag, fabric)); + return CHIP_NO_ERROR; +} + +CHIP_ERROR ArlEncoder::CommissioningEncodableEntry::Stage() const +{ + mStagingEntry.endpoint = mEntry.endpointNumber; + mStagingEntry.cluster = mEntry.clusterId; + ReturnErrorOnFailure(StageEntryRestrictions(mEntry.restrictions, mStagingRestrictions, + sizeof(mStagingRestrictions) / sizeof(mStagingRestrictions[0]))); + mStagingEntry.restrictions = Span(mStagingRestrictions, mEntry.restrictions.size()); + + return CHIP_NO_ERROR; +} + +CHIP_ERROR ArlEncoder::EncodableEntry::Stage() const +{ + mStagingEntry.fabricIndex = mEntry.fabricIndex; + mStagingEntry.endpoint = mEntry.endpointNumber; + mStagingEntry.cluster = mEntry.clusterId; + ReturnErrorOnFailure(StageEntryRestrictions(mEntry.restrictions, mStagingRestrictions, + sizeof(mStagingRestrictions) / sizeof(mStagingRestrictions[0]))); + mStagingEntry.restrictions = Span(mStagingRestrictions, mEntry.restrictions.size()); + + return CHIP_NO_ERROR; +} + +} // namespace app +} // namespace chip diff --git a/src/app/clusters/access-control-server/ArlEncoder.h b/src/app/clusters/access-control-server/ArlEncoder.h new file mode 100644 index 00000000000000..8050bf4d739db6 --- /dev/null +++ b/src/app/clusters/access-control-server/ArlEncoder.h @@ -0,0 +1,113 @@ +/* + * + * Copyright (c) 2024 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace chip { +namespace app { + +/** + * This class provides facilities for converting between access restriction + * entries (as used by the system module) and access restriction entries (as used + * by the generated cluster code). + */ +class ArlEncoder +{ +public: + ArlEncoder() = default; + ~ArlEncoder() = default; + + static CHIP_ERROR Convert(Clusters::AccessControl::AccessRestrictionTypeEnum from, + Access::AccessRestrictionProvider::Type & to); + + static CHIP_ERROR Convert(Access::AccessRestrictionProvider::Type from, + Clusters::AccessControl::AccessRestrictionTypeEnum & to); + + /** + * Used for encoding commissionable access restriction entries. + * + * Typically used temporarily on the stack to encode: + * - source: system level access restriction entry + * - staging: generated cluster level code + */ + class CommissioningEncodableEntry + { + using Entry = Access::AccessRestrictionProvider::Entry; + using StagingEntry = Clusters::AccessControl::Structs::CommissioningAccessRestrictionEntryStruct::Type; + using StagingRestriction = Clusters::AccessControl::Structs::AccessRestrictionStruct::Type; + + public: + CommissioningEncodableEntry(const Entry & entry) : mEntry(entry) {} + + /** + * Encode the constructor-provided entry into the TLV writer. + */ + CHIP_ERROR Encode(TLV::TLVWriter & aWriter, TLV::Tag aTag) const; + + static constexpr bool kIsFabricScoped = false; + + private: + CHIP_ERROR Stage() const; + + Entry mEntry; + mutable StagingEntry mStagingEntry; + mutable StagingRestriction mStagingRestrictions[CHIP_CONFIG_ACCESS_RESTRICTION_MAX_RESTRICTIONS_PER_ENTRY]; + }; + + /** + * Used for encoding access restriction entries. + * + * Typically used temporarily on the stack to encode: + * - source: system level access restriction entry + * - staging: generated cluster level code + */ + class EncodableEntry + { + using Entry = Access::AccessRestrictionProvider::Entry; + using StagingEntry = Clusters::AccessControl::Structs::AccessRestrictionEntryStruct::Type; + using StagingRestriction = Clusters::AccessControl::Structs::AccessRestrictionStruct::Type; + + public: + EncodableEntry(const Entry & entry) : mEntry(entry) {} + + /** + * Encode the constructor-provided entry into the TLV writer. + */ + CHIP_ERROR EncodeForRead(TLV::TLVWriter & writer, TLV::Tag tag, FabricIndex fabric) const; + + FabricIndex GetFabricIndex() const { return mEntry.fabricIndex; } + + static constexpr bool kIsFabricScoped = true; + + private: + /** + * Constructor-provided entry is staged into a staging entry. + */ + CHIP_ERROR Stage() const; + + Entry mEntry; + mutable StagingEntry mStagingEntry; + mutable StagingRestriction mStagingRestrictions[CHIP_CONFIG_ACCESS_RESTRICTION_MAX_RESTRICTIONS_PER_ENTRY]; + }; +}; + +} // namespace app +} // namespace chip diff --git a/src/app/clusters/access-control-server/access-control-server.cpp b/src/app/clusters/access-control-server/access-control-server.cpp index 321f7aa92a483a..295b5535df2f95 100644 --- a/src/app/clusters/access-control-server/access-control-server.cpp +++ b/src/app/clusters/access-control-server/access-control-server.cpp @@ -17,6 +17,11 @@ #include +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS +#include "ArlEncoder.h" +#include +#endif + #include #include @@ -25,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -41,6 +47,10 @@ using Entry = AccessControl::Entry; using EntryListener = AccessControl::EntryListener; using ExtensionEvent = Clusters::AccessControl::Events::AccessControlExtensionChanged::Type; +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS +using ArlReviewEvent = Clusters::AccessControl::Events::FabricRestrictionReviewUpdate::Type; +#endif + // TODO(#13590): generated code doesn't automatically handle max length so do it manually constexpr int kExtensionDataMaxLength = 128; @@ -48,7 +58,12 @@ constexpr uint16_t kClusterRevision = 1; namespace { -class AccessControlAttribute : public AttributeAccessInterface, public EntryListener +class AccessControlAttribute : public AttributeAccessInterface, + public AccessControl::EntryListener +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + , + public AccessRestrictionProvider::Listener +#endif { public: AccessControlAttribute() : AttributeAccessInterface(Optional(0), AccessControlCluster::Id) {} @@ -64,8 +79,17 @@ class AccessControlAttribute : public AttributeAccessInterface, public EntryList CHIP_ERROR Write(const ConcreteDataAttributePath & aPath, AttributeValueDecoder & aDecoder) override; public: - void OnEntryChanged(const SubjectDescriptor * subjectDescriptor, FabricIndex fabric, size_t index, const Entry * entry, - ChangeType changeType) override; + void OnEntryChanged(const SubjectDescriptor * subjectDescriptor, FabricIndex fabric, size_t index, + const AccessControl::Entry * entry, AccessControl::EntryListener::ChangeType changeType) override; + +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + void MarkCommissioningRestrictionListChanged() override; + + void MarkRestrictionListChanged(FabricIndex fabricIndex) override; + + void OnFabricRestrictionReviewUpdate(FabricIndex fabricIndex, uint64_t token, Optional instruction, + Optional redirectUrl) override; +#endif private: /// Business logic implementation of write, returns generic CHIP_ERROR. @@ -78,6 +102,10 @@ class AccessControlAttribute : public AttributeAccessInterface, public EntryList CHIP_ERROR ReadExtension(AttributeValueEncoder & aEncoder); CHIP_ERROR WriteAcl(const ConcreteDataAttributePath & aPath, AttributeValueDecoder & aDecoder); CHIP_ERROR WriteExtension(const ConcreteDataAttributePath & aPath, AttributeValueDecoder & aDecoder); +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + CHIP_ERROR ReadCommissioningArl(AttributeValueEncoder & aEncoder); + CHIP_ERROR ReadArl(AttributeValueEncoder & aEncoder); +#endif } sAttribute; CHIP_ERROR LogExtensionChangedEvent(const AccessControlCluster::Structs::AccessControlExtensionStruct::Type & item, @@ -114,8 +142,8 @@ CHIP_ERROR CheckExtensionEntryDataFormat(const ByteSpan & data) TLV::TLVReader reader; reader.Init(data); - auto containerType = chip::TLV::kTLVType_List; - err = reader.Next(containerType, chip::TLV::AnonymousTag()); + auto containerType = TLV::kTLVType_List; + err = reader.Next(containerType, TLV::AnonymousTag()); VerifyOrReturnError(err == CHIP_NO_ERROR, CHIP_IM_GLOBAL_STATUS(ConstraintError)); err = reader.EnterContainer(containerType); @@ -123,7 +151,7 @@ CHIP_ERROR CheckExtensionEntryDataFormat(const ByteSpan & data) while ((err = reader.Next()) == CHIP_NO_ERROR) { - VerifyOrReturnError(chip::TLV::IsProfileTag(reader.GetTag()), CHIP_IM_GLOBAL_STATUS(ConstraintError)); + VerifyOrReturnError(TLV::IsProfileTag(reader.GetTag()), CHIP_IM_GLOBAL_STATUS(ConstraintError)); } VerifyOrReturnError(err == CHIP_END_OF_TLV, CHIP_IM_GLOBAL_STATUS(ConstraintError)); @@ -159,6 +187,12 @@ CHIP_ERROR AccessControlAttribute::ReadImpl(const ConcreteReadAttributePath & aP ReturnErrorOnFailure(GetAccessControl().GetMaxEntriesPerFabric(value)); return aEncoder.Encode(static_cast(value)); } +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + case AccessControlCluster::Attributes::CommissioningARL::Id: + return ReadCommissioningArl(aEncoder); + case AccessControlCluster::Attributes::Arl::Id: + return ReadArl(aEncoder); +#endif case AccessControlCluster::Attributes::ClusterRevision::Id: return aEncoder.Encode(kClusterRevision); } @@ -375,7 +409,7 @@ CHIP_ERROR AccessControlAttribute::WriteExtension(const ConcreteDataAttributePat } void AccessControlAttribute::OnEntryChanged(const SubjectDescriptor * subjectDescriptor, FabricIndex fabric, size_t index, - const Entry * entry, ChangeType changeType) + const AccessControl::Entry * entry, AccessControl::EntryListener::ChangeType changeType) { // NOTE: If the entry was changed internally by the system (e.g. creating // entries at startup from persistent storage, or deleting entries when a @@ -389,11 +423,11 @@ void AccessControlAttribute::OnEntryChanged(const SubjectDescriptor * subjectDes CHIP_ERROR err; AclEvent event{ .changeType = ChangeTypeEnum::kChanged, .fabricIndex = subjectDescriptor->fabricIndex }; - if (changeType == ChangeType::kAdded) + if (changeType == AccessControl::EntryListener::ChangeType::kAdded) { event.changeType = ChangeTypeEnum::kAdded; } - else if (changeType == ChangeType::kRemoved) + else if (changeType == AccessControl::EntryListener::ChangeType::kRemoved) { event.changeType = ChangeTypeEnum::kRemoved; } @@ -428,6 +462,86 @@ void AccessControlAttribute::OnEntryChanged(const SubjectDescriptor * subjectDes ChipLogError(DataManagement, "AccessControlCluster: event failed %" CHIP_ERROR_FORMAT, err.Format()); } +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS +CHIP_ERROR AccessControlAttribute::ReadCommissioningArl(AttributeValueEncoder & aEncoder) +{ + auto accessRestrictionProvider = GetAccessControl().GetAccessRestrictionProvider(); + + return aEncoder.EncodeList([&](const auto & encoder) -> CHIP_ERROR { + if (accessRestrictionProvider != nullptr) + { + auto entries = accessRestrictionProvider->GetCommissioningEntries(); + + for (auto & entry : entries) + { + ArlEncoder::CommissioningEncodableEntry encodableEntry(entry); + ReturnErrorOnFailure(encoder.Encode(encodableEntry)); + } + } + return CHIP_NO_ERROR; + }); +} + +CHIP_ERROR AccessControlAttribute::ReadArl(AttributeValueEncoder & aEncoder) +{ + auto accessRestrictionProvider = GetAccessControl().GetAccessRestrictionProvider(); + + return aEncoder.EncodeList([&](const auto & encoder) -> CHIP_ERROR { + if (accessRestrictionProvider != nullptr) + { + for (const auto & info : Server::GetInstance().GetFabricTable()) + { + auto fabric = info.GetFabricIndex(); + // get entries for fabric + std::vector entries; + ReturnErrorOnFailure(accessRestrictionProvider->GetEntries(fabric, entries)); + for (const auto & entry : entries) + { + ArlEncoder::EncodableEntry encodableEntry(entry); + ReturnErrorOnFailure(encoder.Encode(encodableEntry)); + } + } + } + return CHIP_NO_ERROR; + }); +} +void AccessControlAttribute::MarkCommissioningRestrictionListChanged() +{ + MatterReportingAttributeChangeCallback(kRootEndpointId, AccessControlCluster::Id, + AccessControlCluster::Attributes::CommissioningARL::Id); +} + +void AccessControlAttribute::MarkRestrictionListChanged(FabricIndex fabricIndex) +{ + MatterReportingAttributeChangeCallback(kRootEndpointId, AccessControlCluster::Id, AccessControlCluster::Attributes::Arl::Id); +} + +void AccessControlAttribute::OnFabricRestrictionReviewUpdate(FabricIndex fabricIndex, uint64_t token, + Optional instruction, Optional redirectUrl) +{ + CHIP_ERROR err; + ArlReviewEvent event{ .token = token, .fabricIndex = fabricIndex }; + + if (instruction.HasValue()) + { + event.instruction.SetNonNull(instruction.Value()); + } + + if (redirectUrl.HasValue()) + { + event.redirectURL.SetNonNull(redirectUrl.Value()); + } + + EventNumber eventNumber; + SuccessOrExit(err = LogEvent(event, kRootEndpointId, eventNumber)); + + return; + +exit: + ChipLogError(DataManagement, "AccessControlCluster: review event failed: %" CHIP_ERROR_FORMAT, err.Format()); +} +#endif + CHIP_ERROR ChipErrorToImErrorMap(CHIP_ERROR err) { // Map some common errors into an underlying IM error @@ -481,4 +595,88 @@ void MatterAccessControlPluginServerInitCallback() AttributeAccessInterfaceRegistry::Instance().Register(&sAttribute); GetAccessControl().AddEntryListener(sAttribute); + +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + auto accessRestrictionProvider = GetAccessControl().GetAccessRestrictionProvider(); + if (accessRestrictionProvider != nullptr) + { + accessRestrictionProvider->AddListener(sAttribute); + } +#endif +} + +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS +bool emberAfAccessControlClusterReviewFabricRestrictionsCallback( + CommandHandler * commandObj, const ConcreteCommandPath & commandPath, + const Clusters::AccessControl::Commands::ReviewFabricRestrictions::DecodableType & commandData) +{ + if (commandPath.mEndpointId != kRootEndpointId) + { + ChipLogError(DataManagement, "AccessControlCluster: invalid endpoint in ReviewFabricRestrictions request"); + commandObj->AddStatus(commandPath, Protocols::InteractionModel::Status::InvalidCommand); + return true; + } + + uint64_t token; + std::vector entries; + auto entryIter = commandData.arl.begin(); + while (entryIter.Next()) + { + AccessRestrictionProvider::Entry entry; + entry.fabricIndex = commandObj->GetAccessingFabricIndex(); + entry.endpointNumber = entryIter.GetValue().endpoint; + entry.clusterId = entryIter.GetValue().cluster; + + auto restrictionIter = entryIter.GetValue().restrictions.begin(); + while (restrictionIter.Next()) + { + AccessRestrictionProvider::Restriction restriction; + if (ArlEncoder::Convert(restrictionIter.GetValue().type, restriction.restrictionType) != CHIP_NO_ERROR) + { + ChipLogError(DataManagement, "AccessControlCluster: invalid restriction type conversion"); + commandObj->AddStatus(commandPath, Protocols::InteractionModel::Status::InvalidCommand); + return true; + } + + if (!restrictionIter.GetValue().id.IsNull()) + { + restriction.id.SetValue(restrictionIter.GetValue().id.Value()); + } + entry.restrictions.push_back(restriction); + } + + if (restrictionIter.GetStatus() != CHIP_NO_ERROR) + { + ChipLogError(DataManagement, "AccessControlCluster: invalid ARL data"); + commandObj->AddStatus(commandPath, Protocols::InteractionModel::Status::InvalidCommand); + return true; + } + + entries.push_back(entry); + } + + if (entryIter.GetStatus() != CHIP_NO_ERROR) + { + ChipLogError(DataManagement, "AccessControlCluster: invalid ARL data"); + commandObj->AddStatus(commandPath, Protocols::InteractionModel::Status::InvalidCommand); + return true; + } + + CHIP_ERROR err = GetAccessControl().GetAccessRestrictionProvider()->RequestFabricRestrictionReview( + commandObj->GetAccessingFabricIndex(), entries, token); + + if (err == CHIP_NO_ERROR) + { + Clusters::AccessControl::Commands::ReviewFabricRestrictionsResponse::Type response; + response.token = token; + commandObj->AddResponse(commandPath, response); + } + else + { + ChipLogError(DataManagement, "AccessControlCluster: restriction review failed: %" CHIP_ERROR_FORMAT, err.Format()); + commandObj->AddStatus(commandPath, Protocols::InteractionModel::ClusterStatusCode(err)); + } + + return true; } +#endif diff --git a/src/app/reporting/Engine.cpp b/src/app/reporting/Engine.cpp index 6d79e9e7b34af0..99d038a8c35471 100644 --- a/src/app/reporting/Engine.cpp +++ b/src/app/reporting/Engine.cpp @@ -341,7 +341,7 @@ CHIP_ERROR Engine::CheckAccessDeniedEventPaths(TLV::TLVWriter & aWriter, bool & Access::RequestPath requestPath{ .cluster = current->mValue.mClusterId, .endpoint = current->mValue.mEndpointId, - .requestType = RequestType::kEventReadOrSubscribeRequest, + .requestType = RequestType::kEventReadRequest, .entityId = current->mValue.mEventId }; Access::Privilege requestPrivilege = RequiredPrivilege::ForReadEvent(path); diff --git a/src/app/server/BUILD.gn b/src/app/server/BUILD.gn index 401356d7b753a4..58524c3a648aab 100644 --- a/src/app/server/BUILD.gn +++ b/src/app/server/BUILD.gn @@ -13,6 +13,7 @@ # limitations under the License. import("//build_overrides/chip.gni") +import("${chip_root}/src/access/access.gni") import("${chip_root}/src/app/common_flags.gni") import("${chip_root}/src/app/icd/icd.gni") diff --git a/src/app/server/Server.cpp b/src/app/server/Server.cpp index 22cd274ba87e39..426e3c686572c0 100644 --- a/src/app/server/Server.cpp +++ b/src/app/server/Server.cpp @@ -177,6 +177,13 @@ CHIP_ERROR Server::Init(const ServerInitParams & initParams) SuccessOrExit(err = mAccessControl.Init(initParams.accessDelegate, sDeviceTypeResolver)); Access::SetAccessControl(mAccessControl); +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + if (initParams.accessRestrictionProvider != nullptr) + { + mAccessControl.SetAccessRestrictionProvider(initParams.accessRestrictionProvider); + } +#endif + mAclStorage = initParams.aclStorage; SuccessOrExit(err = mAclStorage->Init(*mDeviceStorage, mFabrics.begin(), mFabrics.end())); diff --git a/src/app/server/Server.h b/src/app/server/Server.h index 2f6126a4ace635..27c850563f6e98 100644 --- a/src/app/server/Server.h +++ b/src/app/server/Server.h @@ -163,6 +163,13 @@ struct ServerInitParams // ACL storage: MUST be injected. Used to store ACL entries in persistent storage. Must NOT // be initialized before being provided. app::AclStorage * aclStorage = nullptr; + +#if CHIP_CONFIG_USE_ACCESS_RESTRICTIONS + // Access Restriction implementation: MUST be injected if MNGD feature enabled. Used to enforce + // access restrictions that are managed by the device. + Access::AccessRestrictionProvider * accessRestrictionProvider = nullptr; +#endif + // Network native params can be injected depending on the // selected Endpoint implementation void * endpointNativeParams = nullptr; diff --git a/src/credentials/FabricTable.cpp b/src/credentials/FabricTable.cpp index b847addbe10773..f999cf7a4f9a9d 100644 --- a/src/credentials/FabricTable.cpp +++ b/src/credentials/FabricTable.cpp @@ -71,6 +71,44 @@ constexpr size_t IndexInfoTLVMaxSize() return TLV::EstimateStructOverhead(sizeof(FabricIndex), CHIP_CONFIG_MAX_FABRICS * (1 + sizeof(FabricIndex)) + 1); } +CHIP_ERROR AddNewFabricForTestInternal(FabricTable & fabricTable, bool leavePending, const ByteSpan & rootCert, + const ByteSpan & icacCert, const ByteSpan & nocCert, const ByteSpan & opKeySpan, + FabricIndex * outFabricIndex) +{ + VerifyOrReturnError(outFabricIndex != nullptr, CHIP_ERROR_INVALID_ARGUMENT); + + CHIP_ERROR err = CHIP_ERROR_INTERNAL; + + Crypto::P256Keypair injectedOpKey; + Crypto::P256SerializedKeypair injectedOpKeysSerialized; + + Crypto::P256Keypair * opKey = nullptr; + if (!opKeySpan.empty()) + { + VerifyOrReturnError(opKeySpan.size() == injectedOpKeysSerialized.Capacity(), CHIP_ERROR_INVALID_ARGUMENT); + + memcpy(injectedOpKeysSerialized.Bytes(), opKeySpan.data(), opKeySpan.size()); + SuccessOrExit(err = injectedOpKeysSerialized.SetLength(opKeySpan.size())); + SuccessOrExit(err = injectedOpKey.Deserialize(injectedOpKeysSerialized)); + opKey = &injectedOpKey; + } + + SuccessOrExit(err = fabricTable.AddNewPendingTrustedRootCert(rootCert)); + SuccessOrExit(err = + fabricTable.AddNewPendingFabricWithProvidedOpKey(nocCert, icacCert, VendorId::TestVendor1, opKey, + /*isExistingOpKeyExternallyOwned =*/false, outFabricIndex)); + if (!leavePending) + { + SuccessOrExit(err = fabricTable.CommitPendingFabricData()); + } +exit: + if (err != CHIP_NO_ERROR) + { + fabricTable.RevertPendingFabricData(); + } + return err; +} + } // anonymous namespace CHIP_ERROR FabricInfo::Init(const FabricInfo::InitParams & initParams) @@ -695,34 +733,14 @@ CHIP_ERROR FabricTable::LoadFromStorage(FabricInfo * fabric, FabricIndex newFabr CHIP_ERROR FabricTable::AddNewFabricForTest(const ByteSpan & rootCert, const ByteSpan & icacCert, const ByteSpan & nocCert, const ByteSpan & opKeySpan, FabricIndex * outFabricIndex) { - VerifyOrReturnError(outFabricIndex != nullptr, CHIP_ERROR_INVALID_ARGUMENT); - - CHIP_ERROR err = CHIP_ERROR_INTERNAL; - - Crypto::P256Keypair injectedOpKey; - Crypto::P256SerializedKeypair injectedOpKeysSerialized; - - Crypto::P256Keypair * opKey = nullptr; - if (!opKeySpan.empty()) - { - VerifyOrReturnError(opKeySpan.size() == injectedOpKeysSerialized.Capacity(), CHIP_ERROR_INVALID_ARGUMENT); - - memcpy(injectedOpKeysSerialized.Bytes(), opKeySpan.data(), opKeySpan.size()); - SuccessOrExit(err = injectedOpKeysSerialized.SetLength(opKeySpan.size())); - SuccessOrExit(err = injectedOpKey.Deserialize(injectedOpKeysSerialized)); - opKey = &injectedOpKey; - } + return AddNewFabricForTestInternal(*this, /*leavePending=*/false, rootCert, icacCert, nocCert, opKeySpan, outFabricIndex); +} - SuccessOrExit(err = AddNewPendingTrustedRootCert(rootCert)); - SuccessOrExit(err = AddNewPendingFabricWithProvidedOpKey(nocCert, icacCert, VendorId::TestVendor1, opKey, - /*isExistingOpKeyExternallyOwned =*/false, outFabricIndex)); - SuccessOrExit(err = CommitPendingFabricData()); -exit: - if (err != CHIP_NO_ERROR) - { - RevertPendingFabricData(); - } - return err; +CHIP_ERROR FabricTable::AddNewUncommittedFabricForTest(const ByteSpan & rootCert, const ByteSpan & icacCert, + const ByteSpan & nocCert, const ByteSpan & opKeySpan, + FabricIndex * outFabricIndex) +{ + return AddNewFabricForTestInternal(*this, /*leavePending=*/true, rootCert, icacCert, nocCert, opKeySpan, outFabricIndex); } /* @@ -1546,6 +1564,16 @@ bool FabricTable::SetPendingDataFabricIndex(FabricIndex fabricIndex) return isLegal; } +FabricIndex FabricTable::GetPendingNewFabricIndex() const +{ + if (mStateFlags.Has(StateFlags::kIsAddPending)) + { + return mFabricIndexWithPendingState; + } + + return kUndefinedFabricIndex; +} + CHIP_ERROR FabricTable::AllocatePendingOperationalKey(Optional fabricIndex, MutableByteSpan & outputCsr) { // We can only manage commissionable pending fail-safe state if we have a keystore diff --git a/src/credentials/FabricTable.h b/src/credentials/FabricTable.h index aef60f6a17fbbe..af90d781283963 100644 --- a/src/credentials/FabricTable.h +++ b/src/credentials/FabricTable.h @@ -736,6 +736,18 @@ class DLL_EXPORT FabricTable */ bool HasOperationalKeyForFabric(FabricIndex fabricIndex) const; + /** + * @brief If a newly-added fabric is pending, this returns its index, or kUndefinedFabricIndex if none are pending. + * + * A newly-added fabric is pending if AddNOC has been previously called successfully but the + * fabric is not yet fully committed by CommissioningComplete. + * + * NOTE: that this never returns a value other than kUndefinedFabricIndex when UpdateNOC is pending. + * + * @return the fabric index of the pending fabric, or kUndefinedFabricIndex if no fabrics are pending. + */ + FabricIndex GetPendingNewFabricIndex() const; + /** * @brief Returns the operational keystore. This is used for * CASE and the only way the keystore should be used. @@ -968,6 +980,11 @@ class DLL_EXPORT FabricTable CHIP_ERROR AddNewFabricForTest(const ByteSpan & rootCert, const ByteSpan & icacCert, const ByteSpan & nocCert, const ByteSpan & opKeySpan, FabricIndex * outFabricIndex); + // Add a new fabric for testing. The Operational Key is a raw P256Keypair (public key and private key raw bits) that will + // get copied (directly) into the fabric table. The fabric will NOT be committed, and will remain pending. + CHIP_ERROR AddNewUncommittedFabricForTest(const ByteSpan & rootCert, const ByteSpan & icacCert, const ByteSpan & nocCert, + const ByteSpan & opKeySpan, FabricIndex * outFabricIndex); + // Same as AddNewFabricForTest, but ignore if we are colliding with same , so // that a single fabric table can have N nodes for same fabric. This usually works, but is bad form. CHIP_ERROR AddNewFabricForTestIgnoringCollisions(const ByteSpan & rootCert, const ByteSpan & icacCert, const ByteSpan & nocCert, diff --git a/src/credentials/tests/TestFabricTable.cpp b/src/credentials/tests/TestFabricTable.cpp index c5b78fc13e5b4f..0d34f48bce6599 100644 --- a/src/credentials/tests/TestFabricTable.cpp +++ b/src/credentials/tests/TestFabricTable.cpp @@ -551,6 +551,7 @@ TEST_F(TestFabricTable, TestBasicAddNocUpdateNocFlow) FabricTable & fabricTable = fabricTableHolder.GetFabricTable(); EXPECT_EQ(fabricTable.FabricCount(), 0); + EXPECT_EQ(fabricTable.GetPendingNewFabricIndex(), kUndefinedFabricIndex); { FabricIndex nextFabricIndex = kUndefinedFabricIndex; @@ -604,6 +605,7 @@ TEST_F(TestFabricTable, TestBasicAddNocUpdateNocFlow) EXPECT_EQ(fabricTable.FetchPendingNonFabricAssociatedRootCert(fetchedSpan), CHIP_NO_ERROR); EXPECT_TRUE(fetchedSpan.data_equal(rcac)); } + EXPECT_EQ(fabricTable.GetPendingNewFabricIndex(), kUndefinedFabricIndex); FabricIndex newFabricIndex = kUndefinedFabricIndex; bool keyIsExternallyOwned = true; @@ -614,6 +616,11 @@ TEST_F(TestFabricTable, TestBasicAddNocUpdateNocFlow) CHIP_NO_ERROR); EXPECT_EQ(newFabricIndex, 1); EXPECT_EQ(fabricTable.FabricCount(), 1); + + // After adding the pending new fabric (equivalent of AddNOC processing), the new + // fabric must be pending. + EXPECT_EQ(fabricTable.GetPendingNewFabricIndex(), 1); + { // No more pending root cert; it's associated with a fabric now. MutableByteSpan fetchedSpan{ rcacBuf }; @@ -661,6 +668,9 @@ TEST_F(TestFabricTable, TestBasicAddNocUpdateNocFlow) EXPECT_EQ(nextFabricIndex, 2); } + // Fabric can't be pending anymore. + EXPECT_EQ(fabricTable.GetPendingNewFabricIndex(), kUndefinedFabricIndex); + // Validate contents const auto * fabricInfo = fabricTable.FindFabricWithIndex(1); ASSERT_NE(fabricInfo, nullptr); @@ -732,12 +742,16 @@ TEST_F(TestFabricTable, TestBasicAddNocUpdateNocFlow) } EXPECT_EQ(fabricTable.AddNewPendingTrustedRootCert(rcac), CHIP_NO_ERROR); + EXPECT_EQ(fabricTable.GetPendingNewFabricIndex(), kUndefinedFabricIndex); + FabricIndex newFabricIndex = kUndefinedFabricIndex; EXPECT_EQ(fabricTable.FabricCount(), 1); EXPECT_EQ(fabricTable.AddNewPendingFabricWithOperationalKeystore(noc, icac, kVendorId, &newFabricIndex), CHIP_NO_ERROR); EXPECT_EQ(fabricTable.FabricCount(), 2); EXPECT_EQ(newFabricIndex, 2); + EXPECT_EQ(fabricTable.GetPendingNewFabricIndex(), 2); + // No storage yet EXPECT_EQ(storage.GetNumKeys(), numStorageAfterFirstAdd); // Next fabric index has not been updated yet. @@ -1897,6 +1911,8 @@ TEST_F(TestFabricTable, TestUpdateNocFailSafe) uint8_t csrBuf[chip::Crypto::kMIN_CSR_Buffer_Size]; MutableByteSpan csrSpan{ csrBuf }; + EXPECT_EQ(fabricTable.GetPendingNewFabricIndex(), kUndefinedFabricIndex); + // Make sure to tag fabric index to pending opkey: otherwise the UpdateNOC fails EXPECT_EQ(fabricTable.AllocatePendingOperationalKey(chip::MakeOptional(static_cast(1)), csrSpan), CHIP_NO_ERROR); @@ -1908,6 +1924,7 @@ TEST_F(TestFabricTable, TestUpdateNocFailSafe) EXPECT_EQ(fabricTable.FabricCount(), 1); EXPECT_EQ(fabricTable.UpdatePendingFabricWithOperationalKeystore(1, noc, ByteSpan{}), CHIP_NO_ERROR); + EXPECT_EQ(fabricTable.GetPendingNewFabricIndex(), kUndefinedFabricIndex); EXPECT_EQ(fabricTable.FabricCount(), 1); // No storage yet @@ -1936,6 +1953,7 @@ TEST_F(TestFabricTable, TestUpdateNocFailSafe) // Revert, should see Node ID 999 again fabricTable.RevertPendingFabricData(); EXPECT_EQ(fabricTable.FabricCount(), 1); + EXPECT_EQ(fabricTable.GetPendingNewFabricIndex(), kUndefinedFabricIndex); EXPECT_EQ(storage.GetNumKeys(), numStorageAfterAdd); diff --git a/src/lib/core/CHIPConfig.h b/src/lib/core/CHIPConfig.h index 7600b4eaa4c315..8936cc19c4f76d 100644 --- a/src/lib/core/CHIPConfig.h +++ b/src/lib/core/CHIPConfig.h @@ -1206,6 +1206,27 @@ extern const char CHIP_NON_PRODUCTION_MARKER[]; "Please enable at least one of CHIP_CONFIG_EXAMPLE_ACCESS_CONTROL_FAST_COPY_SUPPORT or CHIP_CONFIG_EXAMPLE_ACCESS_CONTROL_FLEXIBLE_COPY_SUPPORT" #endif +/** + * @def CHIP_CONFIG_ACCESS_RESTRICTION_MAX_ENTRIES_PER_FABRIC + * + * Defines the maximum number of access restriction list entries per + * fabric in the access control code's ARL attribute. + */ +#ifndef CHIP_CONFIG_ACCESS_RESTRICTION_MAX_ENTRIES_PER_FABRIC +#define CHIP_CONFIG_ACCESS_RESTRICTION_MAX_ENTRIES_PER_FABRIC 10 +#endif + +/** + * @def CHIP_CONFIG_ACCESS_RESTRICTION_MAX_RESTRICTIONS_PER_ENTRY + * + * Defines the maximum number of access restrictions for each entry + * in the ARL attribute (each entry is for a specific cluster on an + * endpoint on a fabric). + */ +#ifndef CHIP_CONFIG_ACCESS_RESTRICTION_MAX_RESTRICTIONS_PER_ENTRY +#define CHIP_CONFIG_ACCESS_RESTRICTION_MAX_RESTRICTIONS_PER_ENTRY 10 +#endif + /** * @def CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE * diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index 68b6ed2852d377..6390e4eca1920e 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -92,6 +92,12 @@ class MockAppDelegate : public UnsolicitedMessageHandler, public ExchangeDelegat System::PacketBufferHandle && buffer) override { IsOnMessageReceivedCalled = true; + + if (ec->HasSessionHandle() && ec->GetSessionHolder()->IsSecureSession()) + { + mLastSubjectDescriptor = ec->GetSessionHolder()->AsSecureSession()->GetSubjectDescriptor(); + } + if (payloadHeader.IsAckMsg()) { mReceivedPiggybackAck = true; @@ -125,6 +131,7 @@ class MockAppDelegate : public UnsolicitedMessageHandler, public ExchangeDelegat EXPECT_EQ(buffer->TotalLength(), sizeof(PAYLOAD)); EXPECT_EQ(memcmp(buffer->Start(), PAYLOAD, buffer->TotalLength()), 0); + return CHIP_NO_ERROR; } @@ -151,6 +158,8 @@ class MockAppDelegate : public UnsolicitedMessageHandler, public ExchangeDelegat } } + Access::SubjectDescriptor mLastSubjectDescriptor{}; + bool IsOnMessageReceivedCalled = false; bool mReceivedPiggybackAck = false; bool mRetainExchange = false; @@ -1830,9 +1839,12 @@ TEST_F(TestReliableMessageProtocol, CheckApplicationResponseDelayed) EXPECT_EQ(loopback.mSentMessageCount, kMaxMRPTransmits); EXPECT_EQ(loopback.mDroppedMessageCount, kMaxMRPTransmits - 1); EXPECT_EQ(rm->TestGetCountRetransTable(), 1); // We have no ack yet. - EXPECT_TRUE(mockReceiver.IsOnMessageReceivedCalled); // Other side got the message. + ASSERT_TRUE(mockReceiver.IsOnMessageReceivedCalled); // Other side got the message. EXPECT_FALSE(mockSender.IsOnMessageReceivedCalled); // We did not get a response. + // It was not a commissioning CASE session so that is lined-up properly. + EXPECT_FALSE(mockReceiver.mLastSubjectDescriptor.isCommissioning); + // Ensure there will be no more weirdness with acks and that our MRP timer is restarted properly. mockReceiver.SetDropAckResponse(false); diff --git a/src/transport/SecureSession.cpp b/src/transport/SecureSession.cpp index c96f9cdf908756..7694df2e2e3aca 100644 --- a/src/transport/SecureSession.cpp +++ b/src/transport/SecureSession.cpp @@ -160,10 +160,11 @@ Access::SubjectDescriptor SecureSession::GetSubjectDescriptor() const Access::SubjectDescriptor subjectDescriptor; if (IsOperationalNodeId(mPeerNodeId)) { - subjectDescriptor.authMode = Access::AuthMode::kCase; - subjectDescriptor.subject = mPeerNodeId; - subjectDescriptor.cats = mPeerCATs; - subjectDescriptor.fabricIndex = GetFabricIndex(); + subjectDescriptor.authMode = Access::AuthMode::kCase; + subjectDescriptor.subject = mPeerNodeId; + subjectDescriptor.cats = mPeerCATs; + subjectDescriptor.fabricIndex = GetFabricIndex(); + subjectDescriptor.isCommissioning = IsCommissioningSession(); } else if (IsPAKEKeyId(mPeerNodeId)) { @@ -171,9 +172,10 @@ Access::SubjectDescriptor SecureSession::GetSubjectDescriptor() const // Initiator (aka commissioner) leaves subject descriptor unfilled. if (GetCryptoContext().IsResponder()) { - subjectDescriptor.authMode = Access::AuthMode::kPase; - subjectDescriptor.subject = mPeerNodeId; - subjectDescriptor.fabricIndex = GetFabricIndex(); + subjectDescriptor.authMode = Access::AuthMode::kPase; + subjectDescriptor.subject = mPeerNodeId; + subjectDescriptor.fabricIndex = GetFabricIndex(); + subjectDescriptor.isCommissioning = IsCommissioningSession(); } } else @@ -183,6 +185,24 @@ Access::SubjectDescriptor SecureSession::GetSubjectDescriptor() const return subjectDescriptor; } +bool SecureSession::IsCommissioningSession() const +{ + // PASE session is always a commissioning session. + if (IsPASESession()) + { + return true; + } + + // CASE session is a commissioning session if it was marked as such. + // The SessionManager is what keeps track. + if (IsCASESession() && mIsCaseCommissioningSession) + { + return true; + } + + return false; +} + void SecureSession::Retain() { #if CHIP_CONFIG_SECURE_SESSION_REFCOUNT_LOGGING diff --git a/src/transport/SecureSession.h b/src/transport/SecureSession.h index fe70d6714e4a0f..ba586a3095e3dd 100644 --- a/src/transport/SecureSession.h +++ b/src/transport/SecureSession.h @@ -156,6 +156,8 @@ class SecureSession : public Session, public ReferenceCountedStart(), msg->TotalLength())); CHIP_TRACE_MESSAGE_RECEIVED(payloadHeader, packetHeader, secureSession, peerAddress, msg->Start(), msg->TotalLength()); + + // Always recompute whether a message is for a commissioning session based on the latest knowledge of + // the fabric table. + if (secureSession->IsCASESession()) + { + secureSession->SetCaseCommissioningSessionStatus(secureSession->GetFabricIndex() == + mFabricTable->GetPendingNewFabricIndex()); + } mCB->OnMessageReceived(packetHeader, payloadHeader, session.Value(), isDuplicate, std::move(msg)); } else diff --git a/src/transport/tests/TestSessionManager.cpp b/src/transport/tests/TestSessionManager.cpp index 154071a56418c5..dfe3ca4fd5c5ff 100644 --- a/src/transport/tests/TestSessionManager.cpp +++ b/src/transport/tests/TestSessionManager.cpp @@ -28,6 +28,7 @@ #define CHIP_ENABLE_TEST_ENCRYPTED_BUFFER_API // Up here in case some other header // includes SessionManager.h indirectly +#include #include #include #include @@ -112,10 +113,12 @@ class TestSessMgrCallback : public SessionMessageDelegate } ReceiveHandlerCallCount++; + lastSubjectDescriptor = session->GetSubjectDescriptor(); } int ReceiveHandlerCallCount = 0; bool LargeMessageSent = false; + Access::SubjectDescriptor lastSubjectDescriptor{}; }; class TestSessionManager : public ::testing::Test @@ -141,7 +144,7 @@ TEST_F(TestSessionManager, CheckSimpleInitTest) &fabricTableHolder.GetFabricTable(), sessionKeystore)); } -TEST_F(TestSessionManager, CheckMessageTest) +TEST_F(TestSessionManager, CheckMessageOverPaseTest) { uint16_t payload_len = sizeof(PAYLOAD); @@ -213,7 +216,10 @@ TEST_F(TestSessionManager, CheckMessageTest) EXPECT_EQ(err, CHIP_NO_ERROR); mContext.DrainAndServiceIO(); - EXPECT_EQ(callback.ReceiveHandlerCallCount, 1); + ASSERT_EQ(callback.ReceiveHandlerCallCount, 1); + + // This was a PASE session so we expect the subject descriptor to indicate it's for commissioning. + EXPECT_TRUE(callback.lastSubjectDescriptor.isCommissioning); // Let's send the max sized message and make sure it is received chip::System::PacketBufferHandle large_buffer = chip::MessagePacketBuffer::NewWithData(LARGE_PAYLOAD, kMaxAppMessageLen);