From 972a62fe106b379aa35055fac20448b54f461e6a Mon Sep 17 00:00:00 2001
From: Alexander Nesterov
Date: Tue, 6 Aug 2024 19:16:17 +0200
Subject: [PATCH] Add FullyConnected ACL executor

---
 .../executors/acl/acl_common_executor.cpp     | 134 +++++++
 .../executors/acl/acl_common_executor.hpp     |  58 +++
 .../src/nodes/executors/acl/acl_eltwise.cpp   | 119 +-----
 .../src/nodes/executors/acl/acl_eltwise.hpp   |   2 +-
 .../executors/acl/acl_fullyconnected.cpp      | 355 ++++++++++++++++++
 .../executors/acl/acl_fullyconnected.hpp      |  95 +++++
 .../src/nodes/executors/acl/acl_utils.cpp     |  61 ++-
 .../src/nodes/executors/acl/acl_utils.hpp     |  38 +-
 .../src/nodes/executors/debug_messages.hpp    |   2 +
 .../src/nodes/executors/executor.hpp          |  20 +-
 .../fullyconnected_implementations.cpp        |  45 +++
 .../nodes/executors/implementation_utils.hpp  |   5 +
 .../single_layer_tests/classes/matmul.cpp     |   7 +
 .../instances/arm/matmul.cpp                  | 147 ++++++++
 .../src/common/matmul_decompress_convert.cpp  |   4 +-
 .../subgraph_tests/src/common/reshape_fc.cpp  |   4 +-
 16 files changed, 969 insertions(+), 127 deletions(-)
 create mode 100644 src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.cpp
 create mode 100644 src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.hpp
 create mode 100644 src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp
 create mode 100644 src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp
 create mode 100644 src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/arm/matmul.cpp
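The core of this patch is the new ACLCommonExecutor base class: it factors every ACL
operator into three hooks, while the base class owns the shared flow. update() collapses
shapes, builds arm_compute::TensorInfo objects, validates them and configures the
function; execute() imports the CPU buffers into the ACL tensors and runs the configured
arm_compute::IFunction. A minimal sketch of a concrete subclass follows; the AclAbsSketch
class is illustrative only and is not part of this patch (it mirrors the existing
NEAbsLayer path in acl_eltwise.cpp):

// Illustrative sketch only: a concrete executor on top of the new contract.
class AclAbsSketch : public ACLCommonExecutor {
public:
    void updateTensorsShapes(ACLShapes& aclMemoryShapes) override {
        // nothing to do: NEAbsLayer consumes the collapsed shapes as-is
    }
    arm_compute::Status validateTensorsInfo(const ACLInfos& aclMemoryInfos) override {
        return arm_compute::NEAbsLayer::validate(aclMemoryInfos[ACLArgs::ACL_SRC_0].get(),
                                                 aclMemoryInfos[ACLArgs::ACL_DST].get());
    }
    ACLFunction configureFunction(const ACLTensors& aclMemoryTensors) override {
        auto aclOp = std::make_unique<arm_compute::NEAbsLayer>();
        aclOp->configure(aclMemoryTensors[ACLArgs::ACL_SRC_0].get(),
                         aclMemoryTensors[ACLArgs::ACL_DST].get());
        return aclOp;
    }
};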
diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.cpp
new file mode 100644
index 00000000000000..5779147a5b3352
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.cpp
@@ -0,0 +1,134 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "acl_common_executor.hpp"
+#include "acl_utils.hpp"
+#include "nodes/executors/memory_arguments.hpp"
+#include "utils/debug_capabilities.h"
+
+namespace ov {
+namespace intel_cpu {
+
+static const std::unordered_map<int, ACLArgs> argConvert = {
+    {ARG_SRC_0, ACL_SRC_0},
+    {ARG_SRC_1, ACL_SRC_1},
+    {ARG_SRC_2, ACL_SRC_2},
+    {ARG_BIAS,  ACL_BIAS},
+    {ARG_WEI,   ACL_WEI},
+    {ARG_DST,   ACL_DST},
+};
+
+using ACLTypes   = std::array<arm_compute::DataType, ACLArgs::COUNT_OF_ARGS>;
+using ACLLayouts = std::array<arm_compute::DataLayout, ACLArgs::COUNT_OF_ARGS>;
+
+static void initACLTensorParams(const MemoryPtr& memoryPtr,
+                                const ACLTensorAttrs& attrs,
+                                arm_compute::TensorShape& tensorShape,
+                                arm_compute::DataType& dataType,
+                                arm_compute::DataLayout& dataLayout) {
+    dataType = precisionToAclDataType(memoryPtr->getPrecision());
+    dataLayout = getAclDataLayoutByMemoryDesc(memoryPtr->getDescPtr());
+    if (dataType != arm_compute::DataType::UNKNOWN) {
+        auto collapsed_dims = collapse_dims_to_max_rank(memoryPtr->getStaticDims(), attrs.maxDimsShape);
+        tensorShape = shapeCast(collapsed_dims);
+        if (attrs.hasLayoutTypeNHWC) {
+            changeLayoutToNH_C({&tensorShape});
+        }
+    }
+}
+
+static std::shared_ptr<arm_compute::TensorInfo> initTensorInfo(const arm_compute::TensorShape& tensorShape,
+                                                               const arm_compute::DataType& dataType,
+                                                               const arm_compute::DataLayout& dataLayout) {
+    std::shared_ptr<arm_compute::TensorInfo> aclMemoryInfo = nullptr;
+    if (dataType != arm_compute::DataType::UNKNOWN) {
+        aclMemoryInfo = std::make_shared<arm_compute::TensorInfo>(
+                tensorShape, 1,
+                dataType,
+                dataLayout);
+    }
+    return aclMemoryInfo;
+}
+
+static std::shared_ptr<arm_compute::Tensor> initTensor(const std::shared_ptr<arm_compute::TensorInfo>& aclMemoryInfo) {
+    std::shared_ptr<arm_compute::Tensor> aclMemory = nullptr;
+    if (aclMemoryInfo) {
+        aclMemory = std::make_shared<arm_compute::Tensor>();
+        aclMemory->allocator()->init(*aclMemoryInfo);
+    }
+    return aclMemory;
+}
+
+ACLCommonExecutor::ACLCommonExecutor() {
+    for (int i = 0; i < ACLArgs::COUNT_OF_ARGS; ++i) {
+        aclTensorAttrs.memoryUsageIndicator[i] = false;
+    }
+}
+
+bool ACLCommonExecutor::update(const MemoryArgs &memory) {
+    // Initialize ACL tensors params
+    ACLShapes aclMemoryShapes;
+    ACLTypes aclDataType{};
+    ACLLayouts aclDataLayout{};
+    for (auto& cpu_mem_ptr : memory) {
+        const ACLArgs index = argConvert.at(cpu_mem_ptr.first);
+        initACLTensorParams(cpu_mem_ptr.second, aclTensorAttrs,
+                            aclMemoryShapes[index],
+                            aclDataType[index],
+                            aclDataLayout[index]);
+    }
+
+    // Update ACL tensors shapes
+    updateTensorsShapes(aclMemoryShapes);
+
+    // Initialize arm_compute::TensorInfo objects
+    ACLInfos aclMemoryInfos;
+    for (int i = 0; i < ACLArgs::COUNT_OF_ARGS; i++) {
+        aclMemoryInfos[i] = initTensorInfo(aclMemoryShapes[i], aclDataType[i], aclDataLayout[i]);
+    }
+
+    // Validate arm_compute::TensorInfo objects for specific ACL function
+    auto tensorsInfoValidateStatus = validateTensorsInfo(aclMemoryInfos);
+    if (!tensorsInfoValidateStatus) {
+        DEBUG_LOG("ACL operator validation failed: ", tensorsInfoValidateStatus.error_description());
+        return false;
+    }
+
+    // Initialize arm_compute::Tensor objects
+    for (int i = 0; i < ACLArgs::COUNT_OF_ARGS; i++) {
+        aclMemoryTensors[i] = initTensor(aclMemoryInfos[i]);
+        // Indicate that arm_compute::Tensor object can use import_memory function
+        if (aclMemoryTensors[i]) {
+            aclTensorAttrs.memoryUsageIndicator[i] = true;
+        }
+    }
+
+    // Configure arm_compute::IFunction object
+    configureThreadSafe([&] {
+        iFunction = configureFunction(aclMemoryTensors);
+    });
+    return true;
+}
+
+void ACLCommonExecutor::execute(const MemoryArgs &memory) {
+    // TODO: Move import_memory() to update() function - CVS-145871
+    for (auto& cpu_mem_ptr : memory) {
+        const ACLArgs index = argConvert.at(cpu_mem_ptr.first);
+        if (aclTensorAttrs.memoryUsageIndicator[index]) {
+            aclMemoryTensors[index]->allocator()->import_memory(memory.at(cpu_mem_ptr.first)->getData());
+        }
+    }
+    iFunction->run();
+}
+
+ACLCommonExecutor::~ACLCommonExecutor() {
+    for (int i = 0; i < ACLArgs::COUNT_OF_ARGS; i++) {
+        if (aclTensorAttrs.memoryUsageIndicator[i]) {
+            aclMemoryTensors[i]->allocator()->free();
+        }
+    }
+}
+
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.hpp
new file mode 100644
index 00000000000000..854130d6f884bb
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_common_executor.hpp
@@ -0,0 +1,58 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "cpu_memory.h"
+#include "nodes/executors/executor.hpp"
+#include "arm_compute/runtime/NEON/NEFunctions.h"
+
+namespace ov {
+namespace intel_cpu {
+
+enum ACLArgs {
+    ACL_SRC_0,
+    ACL_SRC_1,
+    ACL_SRC_2,
+    ACL_BIAS,
+    ACL_WEI,
+    ACL_DST,
+    COUNT_OF_ARGS
+};
+
+using ACLFunction = std::unique_ptr<arm_compute::IFunction>;
+using ACLShapes   = std::array<arm_compute::TensorShape, ACLArgs::COUNT_OF_ARGS>;
+using ACLInfos    = std::array<std::shared_ptr<arm_compute::TensorInfo>, ACLArgs::COUNT_OF_ARGS>;
+using ACLTensors  = std::array<std::shared_ptr<arm_compute::Tensor>, ACLArgs::COUNT_OF_ARGS>;
+
+struct ACLTensorAttrs {
+    bool hasLayoutTypeNHWC = false;
+    size_t maxDimsShape = arm_compute::MAX_DIMS;
+    std::array<bool, ACLArgs::COUNT_OF_ARGS> memoryUsageIndicator;
+};
+
+class ACLCommonExecutor : public Executor {
+public:
+    ACLCommonExecutor();
+    virtual void updateTensorsShapes(ACLShapes& aclMemoryShapes) = 0;
+    virtual arm_compute::Status validateTensorsInfo(const ACLInfos& aclMemoryInfos) = 0;
+    virtual ACLFunction configureFunction(const ACLTensors& aclMemoryTensors) = 0;
+    impl_desc_type implType() const override {
+        return impl_desc_type::acl;
+    }
+    void execute(const MemoryArgs& memory) override;
+    bool update(const MemoryArgs& memory) override;
+    ~ACLCommonExecutor();
+
+protected:
+    ACLTensorAttrs aclTensorAttrs;
+private:
+    ACLTensors aclMemoryTensors;
+    ACLFunction iFunction = nullptr;
+};
+
+using ACLCommonExecutorPtr = std::shared_ptr<ACLCommonExecutor>;
+
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.cpp
index e22b1493f36ae2..942bacd91349ff 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.cpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.cpp
@@ -361,66 +361,6 @@ bool AclEltwiseExecutor::init(const EltwiseAttrs &eltwiseAttrs, const std::vecto
             return acl_op;
         };
         break;
-    case Algorithm::EltwiseRelu:
-        if (aclEltwiseAttrs.alpha == 0) {
-            if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0],
-                                             ActivationLayerInfo::ActivationFunction::RELU))
-                return false;
-        } else {
-            if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0],
-                                             {ActivationLayerInfo::ActivationFunction::LEAKY_RELU, aclEltwiseAttrs.alpha}))
-                return false;
-        }
-        exec_func = [this]() -> std::unique_ptr<IFunction> {
-            auto acl_op = std::make_unique<NEActivationLayer>();
-            if (aclEltwiseAttrs.alpha == 0) {
-                acl_op->configure(&srcTensors[0], &dstTensors[0], ActivationLayerInfo::ActivationFunction::RELU);
-            } else {
-                acl_op->configure(&srcTensors[0], &dstTensors[0],
-                                  {ActivationLayerInfo::ActivationFunction::LEAKY_RELU, aclEltwiseAttrs.alpha});
-            }
-            return acl_op;
-        };
-        break;
-    case Algorithm::EltwiseGeluErf:
-        if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], ActivationLayerInfo::ActivationFunction::GELU))
-            return false;
-        exec_func = [this]() -> std::unique_ptr<IFunction> {
-            auto acl_op = std::make_unique<NEActivationLayer>();
-            acl_op->configure(&srcTensors[0], &dstTensors[0], ActivationLayerInfo::ActivationFunction::GELU);
-            return acl_op;
-        };
-        break;
-    case Algorithm::EltwiseElu:
-        if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0],
-                                         {ActivationLayerInfo::ActivationFunction::ELU, aclEltwiseAttrs.alpha}))
-            return false;
-        exec_func = [this]() -> std::unique_ptr<IFunction> {
-            auto acl_op = std::make_unique<NEActivationLayer>();
-            acl_op->configure(&srcTensors[0], &dstTensors[0], {ActivationLayerInfo::ActivationFunction::ELU, aclEltwiseAttrs.alpha});
-            return acl_op;
-        };
-        break;
-    case Algorithm::EltwiseTanh:
-        if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0],
-                                         {ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f}))
-            return false;
-        exec_func = [this]() -> std::unique_ptr<IFunction> {
-            auto acl_op = std::make_unique<NEActivationLayer>();
-            acl_op->configure(&srcTensors[0], &dstTensors[0],
-                              {ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f});
-            return acl_op;
-        };
-        break;
-    case Algorithm::EltwiseSigmoid:
-        if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], ActivationLayerInfo::ActivationFunction::LOGISTIC))
-            return false;
-        exec_func = [this]() -> std::unique_ptr<IFunction> {
-            auto acl_op = std::make_unique<NEActivationLayer>();
-            acl_op->configure(&srcTensors[0], &dstTensors[0], ActivationLayerInfo::ActivationFunction::LOGISTIC);
-            return acl_op;
-        };
-        break;
     case Algorithm::EltwiseAbs:
         if (!NEAbsLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0]))
             return false;
@@ -430,24 +370,6 @@ bool AclEltwiseExecutor::init(const EltwiseAttrs &eltwiseAttrs, const std::vecto
             return acl_op;
         };
         break;
-    case Algorithm::EltwiseSqrt:
-        if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], ActivationLayerInfo::ActivationFunction::SQRT))
-            return false;
-        exec_func = [this]() -> std::unique_ptr<IFunction> {
-            auto acl_op = std::make_unique<NEActivationLayer>();
-            acl_op->configure(&srcTensors[0], &dstTensors[0], ActivationLayerInfo::ActivationFunction::SQRT);
-            return acl_op;
-        };
-        break;
-    case Algorithm::EltwiseSoftRelu:
-        if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], ActivationLayerInfo::ActivationFunction::SOFT_RELU))
-            return false;
-        exec_func = [this]() -> std::unique_ptr<IFunction> {
-            auto acl_op = std::make_unique<NEActivationLayer>();
-            acl_op->configure(&srcTensors[0], &dstTensors[0], ActivationLayerInfo::ActivationFunction::SOFT_RELU);
-            return acl_op;
-        };
-        break;
     case Algorithm::EltwiseExp:
         if (!NEExpLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0]))
             return false;
@@ -457,28 +379,6 @@ bool AclEltwiseExecutor::init(const EltwiseAttrs &eltwiseAttrs, const std::vecto
             return acl_op;
         };
         break;
-    case Algorithm::EltwiseClamp:
-        if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0],
-                                         {ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, aclEltwiseAttrs.beta, aclEltwiseAttrs.alpha}))
-            return false;
-        exec_func = [this]() -> std::unique_ptr<IFunction> {
-            auto acl_op = std::make_unique<NEActivationLayer>();
-            acl_op->configure(&srcTensors[0], &dstTensors[0],
-                              {ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, aclEltwiseAttrs.beta, aclEltwiseAttrs.alpha});
-            return acl_op;
-        };
-        break;
-    case Algorithm::EltwiseSwish:
-        if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0],
-                                         {ActivationLayerInfo::ActivationFunction::SWISH, aclEltwiseAttrs.alpha}))
-            return false;
-        exec_func = [this]() -> std::unique_ptr<IFunction> {
-            auto acl_op = std::make_unique<NEActivationLayer>();
-            acl_op->configure(&srcTensors[0], &dstTensors[0],
-                              {ActivationLayerInfo::ActivationFunction::SWISH, aclEltwiseAttrs.alpha});
-            return acl_op;
-        };
-        break;
     case Algorithm::EltwisePrelu:
         if (!NEPReluLayer::validate(&srcTensorsInfo[0], &srcTensorsInfo[1], &dstTensorsInfo[0]))
             return false;
@@ -488,12 +388,27 @@ bool AclEltwiseExecutor::init(const EltwiseAttrs &eltwiseAttrs, const std::vecto
             return acl_op;
         };
         break;
+    case Algorithm::EltwiseRelu:
+    case Algorithm::EltwiseGeluErf:
+    case Algorithm::EltwiseElu:
+    case Algorithm::EltwiseTanh:
+    case Algorithm::EltwiseSigmoid:
+    case Algorithm::EltwiseSqrt:
+    case Algorithm::EltwiseSoftRelu:
+    case Algorithm::EltwiseClamp:
+    case Algorithm::EltwiseSwish:
     case Algorithm::EltwiseHswish:
-        if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], ActivationLayerInfo::ActivationFunction::HARD_SWISH))
+        if (!NEActivationLayer::validate(&srcTensorsInfo[0], &dstTensorsInfo[0], getActivationLayerInfo(aclEltwiseAttrs.algorithm,
+                                                                                                        aclEltwiseAttrs.alpha,
+                                                                                                        aclEltwiseAttrs.beta,
+                                                                                                        aclEltwiseAttrs.gamma)))
             return false;
         exec_func = [this]() -> std::unique_ptr<IFunction> {
             auto acl_op = std::make_unique<NEActivationLayer>();
-            acl_op->configure(&srcTensors[0], &dstTensors[0], ActivationLayerInfo::ActivationFunction::HARD_SWISH);
+            acl_op->configure(&srcTensors[0], &dstTensors[0], getActivationLayerInfo(aclEltwiseAttrs.algorithm,
+                                                                                     aclEltwiseAttrs.alpha,
+                                                                                     aclEltwiseAttrs.beta,
+                                                                                     aclEltwiseAttrs.gamma));
             return acl_op;
         };
         break;
diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.hpp
index 44af939d2c0474..6daf9e606c461b 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.hpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_eltwise.hpp
@@ -16,7 +16,7 @@ class AclEltwiseExecutor : public EltwiseExecutor {
     explicit AclEltwiseExecutor(const ExecutorContext::CPtr context);
     static bool isEltwiseAlgorithmSupported(Algorithm algorithm);
 
-    bool init(const EltwiseAttrs& eltwiseAttrs,
+    bool init(const EltwiseAttrs& attrs,
               const std::vector<MemoryDescPtr>& srcDescs,
               const std::vector<MemoryDescPtr>& dstDescs,
               const std::vector<EltwisePostOp>& postOps) override;
diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp
new file mode 100644
index 00000000000000..99acb9070550dc
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.cpp
@@ -0,0 +1,355 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "acl_fullyconnected.hpp"
+#include "acl_utils.hpp"
+#include "nodes/executors/executor.hpp"
+#include "nodes/executors/memory_arguments.hpp"
+#include "utils/debug_capabilities.h"
+#include "nodes/executors/debug_messages.hpp"
+#include "nodes/executors/implementation_utils.hpp"
+#include "nodes/common/cpu_convert.h"
+#include "memory_desc/cpu_memory_desc_utils.h"
+
+namespace ov {
+namespace intel_cpu {
+
+static VectorDims makeDummyInputDims(const Shape& inShape, const Shape& wShape) {
+    const auto& weightDims = wShape.getStaticDims();
+
+    auto inMinDims = inShape.getMinDims();
+    auto inMaxDims = inShape.getMaxDims();
+    inMinDims.back() = weightDims.back();
+    inMaxDims.back() = weightDims.back();
+
+    return MemoryDescUtils::makeDummyShape(Shape(inMinDims, inMaxDims)).getStaticDims();
+}
+
+static VectorDims makeDummyOutputDims(const VectorDims& inShape, const VectorDims& wShape, const size_t out_rank) {
+    size_t activationRank = inShape.size();
+    size_t channelRank = wShape.size() - 1;
+    // activation   weight    output_shape
+    // NCHW         CoCHW     NCo
+    // TNC          CoC       TNCo
+    // NC           CoC       NCo
+    VectorDims outputShape(out_rank, 1);
+    // set Co
+    outputShape.back() = wShape[0];
+    // set batch dims
+    size_t batchRank = activationRank - channelRank;
+    size_t startIdx = out_rank - batchRank - 1;
+    for (size_t i = 0; i < batchRank; i++) {
+        outputShape[i + startIdx] = inShape[i];
+    }
+
+    return outputShape;
+}
+
+static MemoryPtr prepareWeightMemory(const MemoryArgs &memory,
+                                     const ExecutorContext::CPtr context,
+                                     const FCAttrs &attrs,
+                                     const ACLFCAttrs& aclfcAttrs,
+                                     const PostOps &postOps) {
+    DEBUG_LOG("ACLFullyConnectedExecutor: prepack weights");
+    const auto& wgtDims = memory.at(ARG_WEI)->getStaticDims();
+    const auto N = wgtDims[0];
+    const auto K = wgtDims[1];
+
+    auto create = [&]() {
+        MemoryPtr final_ptr = memory.at(ARG_WEI);
+        // Convert weights precision
+        if (aclfcAttrs.isConvertedWeights) {
+            MemoryArgs memoryArgs;
+            memoryArgs[ARG_SRC_0] = memory.at(ARG_WEI);
+            memoryArgs[ARG_DST]   = std::make_shared<Memory>(context->getEngine(),
+                                                             memoryArgs[ARG_SRC_0]->getDescPtr()->cloneWithNewPrecision(
+                                                                     aclfcAttrs.inputPrecision));
+            auto aclWeightsConverter = std::make_shared<acl_fc_executor::ACLWeightsConverter>();
+            if (aclWeightsConverter->update(memoryArgs)) {
+                aclWeightsConverter->execute(memoryArgs);
+            } else {
+                auto count_wei_elem = std::accumulate(memoryArgs[ARG_SRC_0]->getStaticDims().begin(),
+                                                      memoryArgs[ARG_SRC_0]->getStaticDims().end(),
+                                                      1,
+                                                      std::multiplies<>());
+                cpu_convert(memoryArgs[ARG_SRC_0]->getData(),
+                            memoryArgs[ARG_DST]->getData(),
+                            memoryArgs[ARG_SRC_0]->getPrecision(),
+                            memoryArgs[ARG_DST]->getPrecision(),
+                            count_wei_elem);
+            }
+            final_ptr = memoryArgs[ARG_DST];
+        }
+        // Packed weights
+        {
+            arm_compute::WeightFormat expectedWeightFormat;
+            bool isNeededReorder;
+            {
+                MemoryArgs memoryArgs;
+                memoryArgs[ARG_BIAS] = memory.at(ARG_BIAS);
+                memoryArgs[ARG_WEI]  = final_ptr;
+                if (memory.at(ARG_SRC_0)->getShape().isDynamic()) {
+                    const auto& inShape = memory.at(ARG_SRC_0)->getShape();
+                    const auto& wShape = final_ptr->getShape();
+                    const auto& inDummyDims = makeDummyInputDims(inShape, wShape);
+                    const auto& outDummyDims = makeDummyOutputDims(inDummyDims, wShape.getStaticDims(), memory.at(ARG_DST)->getShape().getRank());
+                    memoryArgs[ARG_SRC_0] = std::make_shared<Memory>(context->getEngine(),
+                                                                     memory.at(ARG_SRC_0)->getDescPtr()->cloneWithNewDims(inDummyDims));
+                    memoryArgs[ARG_DST]   = std::make_shared<Memory>(context->getEngine(),
+                                                                     memory.at(ARG_DST)->getDescPtr()->cloneWithNewDims(outDummyDims));
+                } else {
+                    memoryArgs[ARG_SRC_0] = memory.at(ARG_SRC_0);
+                    memoryArgs[ARG_DST]   = memory.at(ARG_DST);
+                }
+                auto aclWeightsRepack = std::make_shared<acl_fc_executor::ACLWeightFormatGenerator>(attrs, postOps, memoryArgs);
+                isNeededReorder = aclWeightsRepack->update(memoryArgs);
+                expectedWeightFormat = aclWeightsRepack->getOptImplWeightFormat();
+            }
+            if (isNeededReorder) {
+                MemoryArgs memoryArgs;
+                memoryArgs[ARG_SRC_0] = final_ptr;
+                memoryArgs[ARG_DST]   = std::make_shared<Memory>(context->getEngine(),
+                                                                 memoryArgs[ARG_SRC_0]->getDescPtr()->clone());
+                auto aclWeightsReorder = std::make_shared<acl_fc_executor::ACLWeightsReorder>(
+                        arm_compute::WeightFormat::OHWI, expectedWeightFormat);
+                if (aclWeightsReorder->update(memoryArgs)) {
+                    aclWeightsReorder->execute(memoryArgs);
+                    final_ptr = memoryArgs[ARG_DST];
+                }
+            }
+        }
+        // Transpose weights
+        if (!aclfcAttrs.weightsNonTransposed) {
+            auto reverse_weights_dims = memory.at(ARG_WEI)->getStaticDims();
+            if (reverse_weights_dims.size() == 3) {
+                reverse_weights_dims = VectorDims(
+                        {reverse_weights_dims[0] * reverse_weights_dims[1], reverse_weights_dims[2]});
+            }
+            std::reverse(reverse_weights_dims.begin(), reverse_weights_dims.end());
+            MemoryArgs memoryArgs;
+            memoryArgs[ARG_SRC_0] = final_ptr;
+            memoryArgs[ARG_DST]   = std::make_shared<Memory>(context->getEngine(),
+                                                             CpuBlockedMemoryDesc(final_ptr->getPrecision(),
+                                                                                  intel_cpu::Shape(reverse_weights_dims)));
+            auto aclWeightsTranspose = std::make_shared<acl_fc_executor::ACLWeightsTranspose>();
+            if (aclWeightsTranspose->update(memoryArgs)) {
+                aclWeightsTranspose->execute(memoryArgs);
+                final_ptr = memoryArgs[ARG_DST];
+            }
+        }
+        DEBUG_LOG("ACLFullyConnectedExecutor: cache miss, perform packing");
+        return final_ptr;
+    };
+
+    auto weightCache = context->getWeightsCache();
+    if (weightCache != nullptr) {
+        std::string format = "fc_acl_" + std::to_string(N) + "_" + std::to_string(K);
+        const std::string string_hash = format + "_" + std::to_string(memory.at(ARG_WEI)->getSize()) + "_" +
+                                        std::to_string(reinterpret_cast<uint64_t>(memory.at(ARG_WEI)->getData()));
+        DEBUG_LOG("ACLFullyConnectedExecutor: findOrCreate, string_hash: ", string_hash);
+        return *weightCache->findOrCreate(string_hash, create);
+    }
+
+    DEBUG_LOG("ACLFullyConnectedExecutor: Weights cache is not available");
+    return create();
+}
+
+static bool checkPostOps(const PostOps &postOps) {
+    // Only a single activation post-op is supported
+    if (postOps.size() == 1) {
+        if (const auto activation = std::dynamic_pointer_cast<ActivationPostOp>(postOps[0])) {
+            if (checkActivationLayerInfo(convertToEltwiseAlgorithm(activation->type()))) {
+                return true;
+            }
+        }
+    }
+    return false;
+}
+
+static void initFCAttrs(const FCAttrs &attrs,
+                        ACLTensorAttrs& aclTensorAttrs,
+                        ACLFCAttrs& aclfcAttrs,
+                        const MemoryArgs &memory,
+                        arm_compute::FullyConnectedLayerInfo& fullyConnectedLayerInfo,
+                        const PostOps &postOps) {
+    aclTensorAttrs.hasLayoutTypeNHWC = memory.at(ARG_SRC)->getDescPtr()->hasLayoutType(LayoutType::nspc);
+    fullyConnectedLayerInfo.weights_trained_layout = getAclDataLayoutByMemoryDesc(memory.at(ARG_WEI)->getDescPtr());
+    aclfcAttrs.inputPrecision = memory.at(ARG_SRC)->getDescPtr()->getPrecision();
+    fullyConnectedLayerInfo.transpose_weights = false;
+    aclfcAttrs.weightsNonTransposed = attrs.weightsNonTransposed;
+
+    if (checkPostOps(postOps)) {
+        auto activation = std::dynamic_pointer_cast<ActivationPostOp>(postOps[0]);
+        fullyConnectedLayerInfo.activation_info = getActivationLayerInfo(
+                convertToEltwiseAlgorithm(activation->type()),
+                activation->alpha(), activation->beta(), activation->gamma());
+    }
+
+    if (memory.at(ARG_SRC)->getPrecision() != memory.at(ARG_WEI)->getPrecision()) {
+        aclfcAttrs.isConvertedWeights = true;
+    }
+}
+
+ACLFullyConnectedExecutor::ACLFullyConnectedExecutor(const FCAttrs &attrs,
+                                                     const PostOps &postOps,
+                                                     const MemoryArgs &memory,
+                                                     const ExecutorContext::CPtr context) {
+    initFCAttrs(attrs, aclTensorAttrs, aclfcAttrs, memory, fullyConnectedLayerInfo, postOps);
+    packedWeights = prepareWeightMemory(memory, context, attrs, aclfcAttrs, postOps);
+}
+
+bool ACLFullyConnectedExecutor::supports(const FCConfig &config) {
+    VERIFY(one_of(srcType(config), ov::element::f16, ov::element::f32), UNSUPPORTED_SRC_PRECISIONS);
+    VERIFY(one_of(weiType(config), ov::element::f16, ov::element::f32), UNSUPPORTED_WEI_PRECISIONS);
+    VERIFY(postOpsNumbers(config) < 2, UNSUPPORTED_NUMBER_OF_POSTOPS);
+    // On 32-bit ARM the postOps passed to this function are empty (CVS-149013),
+    // so the post-ops type check is skipped there
+#ifndef OPENVINO_ARCH_ARM
+    VERIFY(checkPostOps(config.postOps), UNSUPPORTED_TYPE_OF_POSTOPS);
+#endif
+    VERIFY(one_of(srcRank(config), 2U, 3U, 4U), UNSUPPORTED_SRC_RANK);
+    VERIFY(one_of(weiRank(config), 2U, 3U), UNSUPPORTED_WEI_RANK);
+    return true;
+}
+
+static void updateFCTensorsShapes(ACLShapes& aclMemoryShapes) {
+    if (aclMemoryShapes[ACLArgs::ACL_WEI].num_dimensions() == 3U) {
+        aclMemoryShapes[ACLArgs::ACL_WEI] = arm_compute::TensorShape(
+                {aclMemoryShapes[ACLArgs::ACL_WEI][0] * aclMemoryShapes[ACLArgs::ACL_WEI][1],
+                 aclMemoryShapes[ACLArgs::ACL_WEI][2]});
+    }
+
+    if (one_of(aclMemoryShapes[ACLArgs::ACL_SRC_0].num_dimensions(), 3U, 4U)) {
+        aclMemoryShapes[ACLArgs::ACL_SRC_0] = arm_compute::TensorShape({
+                aclMemoryShapes[ACLArgs::ACL_WEI][0],
+                aclMemoryShapes[ACLArgs::ACL_SRC_0].total_size() / aclMemoryShapes[ACLArgs::ACL_WEI][0]});
+    }
+
+    if (one_of(aclMemoryShapes[ACLArgs::ACL_DST].num_dimensions(), 3U, 4U)) {
+        aclMemoryShapes[ACLArgs::ACL_DST] = arm_compute::TensorShape({
+                aclMemoryShapes[ACLArgs::ACL_WEI][1],
+                aclMemoryShapes[ACLArgs::ACL_SRC_0][1]});
+    }
+
+    std::swap(aclMemoryShapes[ACLArgs::ACL_WEI][0], aclMemoryShapes[ACLArgs::ACL_WEI][1]);
+}
+
+void ACLFullyConnectedExecutor::updateTensorsShapes(ACLShapes& aclMemoryShapes) {
+    updateFCTensorsShapes(aclMemoryShapes);
+}
+
+arm_compute::Status ACLFullyConnectedExecutor::validateTensorsInfo(const ACLInfos& aclMemoryInfos) {
+    if (aclfcAttrs.isConvertedWeights) {
+        aclMemoryInfos[ACLArgs::ACL_WEI]->set_data_type(aclMemoryInfos[ACLArgs::ACL_SRC_0]->data_type());
+    }
+    return arm_compute::NEFullyConnectedLayer::validate(
+            aclMemoryInfos[ACLArgs::ACL_SRC_0].get(),
+            aclMemoryInfos[ACLArgs::ACL_WEI].get(),
+            aclMemoryInfos[ACLArgs::ACL_BIAS].get(),
+            aclMemoryInfos[ACLArgs::ACL_DST].get(),
+            fullyConnectedLayerInfo,
+            weightsInfo);
+}
+
+ACLFunction ACLFullyConnectedExecutor::configureFunction(const ACLTensors& aclMemoryTensors) {
+    auto neFC = std::make_unique<arm_compute::NEFullyConnectedLayer>();
+    neFC->configure(
+            aclMemoryTensors[ACLArgs::ACL_SRC_0].get(),
+            aclMemoryTensors[ACLArgs::ACL_WEI].get(),
+            aclMemoryTensors[ACLArgs::ACL_BIAS].get(),
+            aclMemoryTensors[ACLArgs::ACL_DST].get(),
+            fullyConnectedLayerInfo,
+            weightsInfo);
+
+    if (aclfcAttrs.isConvertedWeights || !aclfcAttrs.weightsNonTransposed) {
+        aclTensorAttrs.memoryUsageIndicator[ACLArgs::ACL_WEI] = false;
+        aclMemoryTensors[ACLArgs::ACL_WEI]->allocator()->import_memory(packedWeights->getData());
+    }
+    return neFC;
+}
+
+arm_compute::Status acl_fc_executor::ACLWeightsConverter::validateTensorsInfo(const ACLInfos &aclMemoryInfos) {
+    return arm_compute::NECast::validate(aclMemoryInfos[ACLArgs::ACL_SRC_0].get(),
+                                         aclMemoryInfos[ACLArgs::ACL_DST].get(),
+                                         arm_compute::ConvertPolicy::SATURATE);
+}
+
+ACLFunction acl_fc_executor::ACLWeightsConverter::configureFunction(const ACLTensors &aclMemoryTensors) {
+    auto neCast = std::make_unique<arm_compute::NECast>();
+    neCast->configure(aclMemoryTensors[ACLArgs::ACL_SRC_0].get(),
+                      aclMemoryTensors[ACLArgs::ACL_DST].get(),
+                      arm_compute::ConvertPolicy::SATURATE);
+    return neCast;
+}
+
+
+arm_compute::Status acl_fc_executor::ACLWeightsTranspose::validateTensorsInfo(const ACLInfos &aclMemoryInfos) {
+    return arm_compute::NETranspose::validate(aclMemoryInfos[ACLArgs::ACL_SRC_0].get(),
+                                              aclMemoryInfos[ACLArgs::ACL_DST].get());
+}
+
+ACLFunction acl_fc_executor::ACLWeightsTranspose::configureFunction(const ACLTensors &aclMemoryTensors) {
+    auto neTranspose = std::make_unique<arm_compute::NETranspose>();
+    neTranspose->configure(aclMemoryTensors[ACLArgs::ACL_SRC_0].get(),
+                           aclMemoryTensors[ACLArgs::ACL_DST].get());
+    return neTranspose;
+}
+
+acl_fc_executor::ACLWeightFormatGenerator::ACLWeightFormatGenerator(const FCAttrs &attrs,
+                                                                    const PostOps &postOps,
+                                                                    const MemoryArgs &memory) {
+    initFCAttrs(attrs, aclTensorAttrs, aclfcAttrs, memory, fullyConnectedLayerInfo, postOps);
+}
+
+void acl_fc_executor::ACLWeightFormatGenerator::updateTensorsShapes(ACLShapes &aclMemoryShapes) {
+    updateFCTensorsShapes(aclMemoryShapes);
+}
+
+arm_compute::Status acl_fc_executor::ACLWeightFormatGenerator::validateTensorsInfo(const ACLInfos &aclMemoryInfos) {
+    if (aclfcAttrs.isConvertedWeights) {
+        aclMemoryInfos[ACLArgs::ACL_WEI]->set_data_type(aclMemoryInfos[ACLArgs::ACL_SRC_0]->data_type());
+    }
+    return arm_compute::NEFullyConnectedLayer::has_opt_impl(
+            expectedWeightFormat,
+            aclMemoryInfos[ACLArgs::ACL_SRC_0].get(),
+            aclMemoryInfos[ACLArgs::ACL_WEI].get(),
+            aclMemoryInfos[ACLArgs::ACL_BIAS].get(),
+            aclMemoryInfos[ACLArgs::ACL_DST].get(),
+            fullyConnectedLayerInfo,
+            weightsInfo);
+}
+
+ACLFunction acl_fc_executor::ACLWeightFormatGenerator::configureFunction(const ACLTensors &aclMemoryTensors) {
+    return std::make_unique<arm_compute::NEFullyConnectedLayer>();
+}
+
+arm_compute::Status acl_fc_executor::ACLWeightsReorder::validateTensorsInfo(const ACLInfos &aclMemoryInfos) {
+#if defined(OPENVINO_ARCH_ARM64)
+    return arm_compute::NEReorderLayer::validate(aclMemoryInfos[ACLArgs::ACL_SRC_0].get(),
+                                                 aclMemoryInfos[ACLArgs::ACL_DST].get(),
+                                                 inWeightFormat,
+                                                 outWeightFormat);
+#else
+    return arm_compute::NECopy::validate(aclMemoryInfos[ACLArgs::ACL_SRC_0].get(),
+                                         aclMemoryInfos[ACLArgs::ACL_DST].get());
+#endif
+}
+
+ACLFunction acl_fc_executor::ACLWeightsReorder::configureFunction(const ACLTensors &aclMemoryTensors) {
+#if defined(OPENVINO_ARCH_ARM64)
+    auto neReorderLayer = std::make_unique<arm_compute::NEReorderLayer>();
+    neReorderLayer->configure(aclMemoryTensors[ACLArgs::ACL_SRC_0].get(),
+                              aclMemoryTensors[ACLArgs::ACL_DST].get(),
+                              inWeightFormat,
+                              outWeightFormat);
+    return neReorderLayer;
+#else
+    auto neCopy = std::make_unique<arm_compute::NECopy>();
+    neCopy->configure(aclMemoryTensors[ACLArgs::ACL_SRC_0].get(),
+                      aclMemoryTensors[ACLArgs::ACL_DST].get());
+    return neCopy;
+#endif
+}
+
+}  // namespace intel_cpu
+}  // namespace ov
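The weight cache key built in prepareWeightMemory() above combines the weight matrix
geometry with the buffer size and address, so re-created nodes that share the same
constant weights reuse one packed blob. A standalone illustration of the key shape
(the concrete values below are made up, not taken from this patch):

// Toy reproduction of the "fc_acl_<N>_<K>_<bytes>_<address>" cache key.
#include <cstdint>
#include <iostream>
#include <string>

int main() {
    const size_t N = 512, K = 1024;                        // wgtDims[0], wgtDims[1]
    const size_t byteSize = N * K * 4;                     // getSize() for f32 weights
    const auto dataPtr = reinterpret_cast<uintptr_t>(&N);  // stands in for getData()

    const std::string key = "fc_acl_" + std::to_string(N) + "_" + std::to_string(K) +
                            "_" + std::to_string(byteSize) + "_" + std::to_string(dataPtr);
    std::cout << key << "\n";  // e.g. fc_acl_512_1024_2097152_<address>
}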
diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp
new file mode 100644
index 00000000000000..4d7f2e5ef91480
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_fullyconnected.hpp
@@ -0,0 +1,95 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "acl_common_executor.hpp"
+#include "nodes/executors/fullyconnected_config.hpp"
+
+namespace ov {
+namespace intel_cpu {
+
+struct ACLFCAttrs {
+    ov::element::Type inputPrecision;
+    bool isConvertedWeights = false;
+    bool weightsNonTransposed;
+};
+
+namespace acl_fc_executor {
+
+class ACLWeightsConverter : public ACLCommonExecutor {
+public:
+    ACLWeightsConverter() = default;
+    void updateTensorsShapes(ACLShapes& aclMemoryShapes) override {}
+    arm_compute::Status validateTensorsInfo(const ACLInfos& aclMemoryInfos) override;
+    ACLFunction configureFunction(const ACLTensors& aclMemoryTensors) override;
+};
+
+class ACLWeightsTranspose : public ACLCommonExecutor {
+public:
+    ACLWeightsTranspose() = default;
+    void updateTensorsShapes(ACLShapes& aclMemoryShapes) override {}
+    arm_compute::Status validateTensorsInfo(const ACLInfos& aclMemoryInfos) override;
+    ACLFunction configureFunction(const ACLTensors& aclMemoryTensors) override;
+};
+
+class ACLWeightFormatGenerator : public ACLCommonExecutor {
+public:
+    ACLWeightFormatGenerator(const FCAttrs& attrs,
+                             const PostOps& postOps,
+                             const MemoryArgs& memory);
+    void updateTensorsShapes(ACLShapes& aclMemoryShapes) override;
+    arm_compute::Status validateTensorsInfo(const ACLInfos& aclMemoryInfos) override;
+    ACLFunction configureFunction(const ACLTensors& aclMemoryTensors) override;
+    arm_compute::WeightFormat getOptImplWeightFormat() {
+        return expectedWeightFormat;
+    }
+private:
+    arm_compute::FullyConnectedLayerInfo fullyConnectedLayerInfo;
+    arm_compute::WeightsInfo weightsInfo;
+    ACLFCAttrs aclfcAttrs;
+    arm_compute::WeightFormat expectedWeightFormat;
+};
+
+class ACLWeightsReorder : public ACLCommonExecutor {
+public:
+    ACLWeightsReorder(arm_compute::WeightFormat inWeightFormat,
+                      arm_compute::WeightFormat outWeightFormat)
+            : inWeightFormat(inWeightFormat), outWeightFormat(outWeightFormat) {}
+    void updateTensorsShapes(ACLShapes& aclMemoryShapes) override {}
+    arm_compute::Status validateTensorsInfo(const ACLInfos& aclMemoryInfos) override;
+    ACLFunction configureFunction(const ACLTensors& aclMemoryTensors) override;
+private:
+    arm_compute::WeightFormat inWeightFormat;
+    arm_compute::WeightFormat outWeightFormat;
+};
+
+}  // namespace acl_fc_executor
+
+class ACLFullyConnectedExecutor : public ACLCommonExecutor {
+public:
+    ACLFullyConnectedExecutor(const FCAttrs& attrs,
+                              const PostOps& postOps,
+                              const MemoryArgs& memory,
+                              const ExecutorContext::CPtr context);
+
+    static bool supports(const FCConfig& config);
+
+    void updateTensorsShapes(ACLShapes& aclMemoryShapes) override;
+
+    arm_compute::Status validateTensorsInfo(const ACLInfos& aclMemoryInfos) override;
+
+    ACLFunction configureFunction(const ACLTensors& aclMemoryTensors) override;
+
+private:
+    arm_compute::FullyConnectedLayerInfo fullyConnectedLayerInfo;
+    arm_compute::WeightsInfo weightsInfo;
+    MemoryCPtr packedWeights;
+    ACLFCAttrs aclfcAttrs;
+};
+
+using ACLFullyConnectedExecutorPtr = std::shared_ptr<ACLFullyConnectedExecutor>;
+
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.cpp
index c2dfecbf57106c..df57d29f4a44ec 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.cpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.cpp
@@ -4,10 +4,69 @@
 
 #include "acl_utils.hpp"
 #include "support/Mutex.h"
+#include "utils/debug_capabilities.h"
 
-void ov::intel_cpu::configureThreadSafe(const std::function<void(void)>& config) {
+namespace ov {
+namespace intel_cpu {
+
+void configureThreadSafe(const std::function<void(void)>& config) {
     // Issue: CVS-123514
     static arm_compute::Mutex mtx_config;
     arm_compute::lock_guard<arm_compute::Mutex> _lock{mtx_config};
     config();
 }
+
+arm_compute::ActivationLayerInfo getActivationLayerInfo(Algorithm algorithm,
+                                                        float alpha = 0.0,
+                                                        float beta = 0.0,
+                                                        float gamma = 0.0) {
+    switch (algorithm) {
+        case Algorithm::EltwiseRelu:
+            if (alpha == 0) {
+                return arm_compute::ActivationLayerInfo::ActivationFunction::RELU;
+            } else {
+                return {arm_compute::ActivationLayerInfo::ActivationFunction::LEAKY_RELU, alpha};
+            }
+        case Algorithm::EltwiseGeluErf:
+            return arm_compute::ActivationLayerInfo::ActivationFunction::GELU;
+        case Algorithm::EltwiseElu:
+            return {arm_compute::ActivationLayerInfo::ActivationFunction::ELU, alpha};
+        case Algorithm::EltwiseTanh:
+            return {arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f};
+        case Algorithm::EltwiseSigmoid:
+            return arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC;
+        case Algorithm::EltwiseSqrt:
+            return arm_compute::ActivationLayerInfo::ActivationFunction::SQRT;
+        case Algorithm::EltwiseSoftRelu:
+            return arm_compute::ActivationLayerInfo::ActivationFunction::SOFT_RELU;
+        case Algorithm::EltwiseClamp:
+            return {arm_compute::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, beta, alpha};
+        case Algorithm::EltwiseSwish:
+            return {arm_compute::ActivationLayerInfo::ActivationFunction::SWISH, alpha};
+        case Algorithm::EltwiseHswish:
+            return arm_compute::ActivationLayerInfo::ActivationFunction::HARD_SWISH;
+        default:
+            OPENVINO_THROW("Unsupported operation type for ACL Eltwise executor: ", static_cast<int>(algorithm));
+    }
+}
+
+bool checkActivationLayerInfo(Algorithm algorithm) {
+    switch (algorithm) {
+        case Algorithm::EltwiseRelu:
+        case Algorithm::EltwiseGeluErf:
+        case Algorithm::EltwiseElu:
+        case Algorithm::EltwiseTanh:
+        case Algorithm::EltwiseSigmoid:
+        case Algorithm::EltwiseSqrt:
+        case Algorithm::EltwiseSoftRelu:
+        case Algorithm::EltwiseClamp:
+        case Algorithm::EltwiseSwish:
+        case Algorithm::EltwiseHswish:
+            return true;
+        default:
+            return false;
+    }
+}
+
+}  // namespace intel_cpu
+}  // namespace ov
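Two subtleties in the mapping above are worth calling out: EltwiseRelu dispatches on
alpha (plain RELU vs. LEAKY_RELU), and for EltwiseClamp the arguments are intentionally
swapped, since ActivationLayerInfo takes the upper bound first for LU_BOUNDED_RELU while
Clamp carries (min = alpha, max = beta). A small illustration, assuming the function
defined above and the ov::intel_cpu namespace:

// Illustration only, not part of this patch.
void activationMappingDemo() {
    using arm_compute::ActivationLayerInfo;
    // alpha == 0 maps to plain RELU
    const ActivationLayerInfo relu  = getActivationLayerInfo(Algorithm::EltwiseRelu, 0.0f, 0.0f, 0.0f);
    // alpha != 0 maps to LEAKY_RELU with slope alpha
    const ActivationLayerInfo leaky = getActivationLayerInfo(Algorithm::EltwiseRelu, 0.1f, 0.0f, 0.0f);
    // Clamp(min = alpha = -1, max = beta = 6) maps to LU_BOUNDED_RELU(upper = 6, lower = -1)
    const ActivationLayerInfo clamp = getActivationLayerInfo(Algorithm::EltwiseClamp, -1.0f, 6.0f, 0.0f);
    (void)relu; (void)leaky; (void)clamp;
}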
diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp
index a981c112e5e19d..a3d151192e601b 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp
@@ -5,6 +5,7 @@
 
 #include "memory_desc/cpu_memory_desc.h"
 #include "arm_compute/core/Types.h"
+#include "cpu_types.h"
 
 namespace ov {
 namespace intel_cpu {
@@ -15,15 +16,14 @@ namespace intel_cpu {
 * @param dims vector of dimensions to squash
 * @return vector of dimensions that complies to ACL
 */
-inline VectorDims collapse_dims_to_max_rank(VectorDims dims) {
-    const size_t MAX_NUM_SHAPE = arm_compute::MAX_DIMS;
-    VectorDims result_dims(MAX_NUM_SHAPE - 1);
-    if (dims.size() >= MAX_NUM_SHAPE) {
-        for (size_t i = 0; i < MAX_NUM_SHAPE - 1; i++) {
+inline VectorDims collapse_dims_to_max_rank(VectorDims dims, size_t max_num_shape = arm_compute::MAX_DIMS) {
+    VectorDims result_dims(max_num_shape - 1);
+    if (dims.size() >= max_num_shape) {
+        for (size_t i = 0; i < max_num_shape - 1; i++) {
             result_dims[i] = dims[i];
         }
-        for (size_t i = MAX_NUM_SHAPE - 1; i < dims.size(); i++) {
-            result_dims[MAX_NUM_SHAPE - 2] *= dims[i];
+        for (size_t i = max_num_shape - 1; i < dims.size(); i++) {
+            result_dims[max_num_shape - 2] *= dims[i];
         }
     } else {
         result_dims = dims;
@@ -51,7 +51,7 @@ inline void changeLayoutToNH_C(const std::vector<arm_compute::TensorShape*> &_li
 }
 
 /**
-* @brief Return ComputeLibrary TensorShape with reverted layout schema used in ACL 
+* @brief Return ComputeLibrary TensorShape with reverted layout schema used in ACL
 * @param dims vector of dimensions to convert
 * @return ComputeLibrary TensorShape object
 */
@@ -105,13 +105,6 @@ inline int axisCast(const std::size_t axis, const std::size_t shapeSize, ACLAxis
     }
 }
 
-inline Dim vectorProduct(const VectorDims& vec, size_t size) {
-    Dim prod = 1;
-    for (size_t i = 0; i < size; ++i)
-        prod *= vec[i];
-    return prod;
-}
-
 /**
 * @brief Return ComputeLibrary DataType that corresponds to the given precision
 * @param precision precision to be converted
@@ -159,5 +152,20 @@ inline arm_compute::DataLayout getAclDataLayoutByMemoryDesc(MemoryDescCPtr desc)
 */
 void configureThreadSafe(const std::function<void(void)>& config);
 
+/**
+* @brief Get ARM Compute Library ActivationLayerInfo for Eltwise or PostOps
+* @param algorithm activation function in OpenVINO representation
+* @param alpha alpha coefficient for algorithm
+* @param beta beta coefficient for algorithm
+* @param gamma gamma coefficient for algorithm
+*/
+arm_compute::ActivationLayerInfo getActivationLayerInfo(Algorithm algorithm, float alpha, float beta, float gamma);
+
+/**
+* @brief Check whether the algorithm is supported by ACL ActivationLayerInfo for Eltwise or PostOps
+* @param algorithm activation function in OpenVINO representation
+*/
+bool checkActivationLayerInfo(Algorithm algorithm);
+
 }   // namespace intel_cpu
 }   // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp b/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp
index 2ee407564bf957..26ae6ace59631b 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp
@@ -7,6 +7,8 @@
 #define UNSUPPORTED_SPARSE_WEIGHTS        " sparse weights are not supported"
 #define UNSUPPORTED_WEIGHTS_DECOMPRESSION " weights decompression is not supported"
 #define UNSUPPORTED_POST_OPS              " post ops are not supported"
+#define UNSUPPORTED_NUMBER_OF_POSTOPS     " the number of post ops is not supported"
+#define UNSUPPORTED_TYPE_OF_POSTOPS       " the type of post ops is not supported"
 #define UNSUPPORTED_SRC_PRECISIONS        " unsupported src precisions"
 #define UNSUPPORTED_WEI_PRECISIONS        " unsupported wei precisions"
 #define UNSUPPORTED_DST_PRECISIONS        " unsupported dst precisions"
diff --git a/src/plugins/intel_cpu/src/nodes/executors/executor.hpp b/src/plugins/intel_cpu/src/nodes/executors/executor.hpp
index 6b66c41e67b241..5b9df5a6e77a55 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/executor.hpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/executor.hpp
@@ -24,9 +24,25 @@ namespace intel_cpu {
 #endif
 
 #if defined(OV_CPU_WITH_ACL)
-#    define OV_CPU_INSTANCE_ACL(...) {__VA_ARGS__},
+#    if defined(OPENVINO_ARCH_ARM)
+#        define OV_CPU_INSTANCE_ACL32(...) {__VA_ARGS__},
+#    else
+#        define OV_CPU_INSTANCE_ACL32(...)
+#    endif
+#    if defined(OPENVINO_ARCH_ARM64)
+#        define OV_CPU_INSTANCE_ACL64(...) {__VA_ARGS__},
+#    else
+#        define OV_CPU_INSTANCE_ACL64(...)
+#    endif
+#    if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
+#        define OV_CPU_INSTANCE_ACL(...) {__VA_ARGS__},
+#    else
+#        define OV_CPU_INSTANCE_ACL(...)
+#    endif
 #else
-#    define OV_CPU_INSTANCE_ACL(...)
+#    define OV_CPU_INSTANCE_ACL32(...)
+#    define OV_CPU_INSTANCE_ACL64(...)
+#    define OV_CPU_INSTANCE_ACL(...)
 #endif
 
 #if defined(OV_CPU_WITH_DNNL)
diff --git a/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp b/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp
index d9caa23c9b10ec..5cf790fea50670 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp
@@ -26,6 +26,10 @@
 #include "ov_optional.hpp"
 #include "utils/cpp/maybe_unused.hpp"
 
+#if defined(OV_CPU_WITH_ACL)
+#include "nodes/executors/acl/acl_fullyconnected.hpp"
+#endif
+
 #if defined(OV_CPU_WITH_SHL)
 #    include "nodes/executors/shl/shl_fullyconnected.hpp"
 #endif
@@ -41,6 +45,7 @@ static const MappingNotation dnnlFCMappingNotation{ARG_SRC, ARG_WEI, ARG_BIAS, A
 
 using LayoutConfig = std::vector<LayoutType>;
 static const LayoutConfig dnnlFCLayoutConfig{LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp};
+static const LayoutConfig aclFCLayoutConfig{LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp};
 
 template <dnnl::impl::cpu::x64::cpu_isa_t ISA>
 struct Require {
@@ -74,10 +79,20 @@ static const TypeMapping dnnlFCTypeMapping {
     // @todo explicitly cover configuration limitations for oneDNN on ARM
 };
 
+static const TypeMapping aclFCTypeMapping {
+    // {src, wei, bia, dst}                  pt<src, wei, bias, dst>
+    {{_f32 | _f16, _f32 | _f16, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>())},
+    {{_any, _any, _any, _any},               pt(just<ov::element::f32>(), just<ov::element::f32>(), just<ov::element::f32>(), just<ov::element::f32>())}
+};
+
 static const MappingNotation dnnlConvolutionMappingNotation {
     ARG_SRC, ARG_WEI, ARG_BIAS, ARG_DST
 };
 
+static const MappingNotation aclFullyConnectedMappingNotation {
+    ARG_SRC, ARG_WEI, ARG_BIAS, ARG_DST
+};
+
 static const TypeMapping dnnlConvolutionTypeMapping {
     // {src, wei, bia, dst}                     pt<src, wei, bias, dst>
     {{_bf16, _bf16 | _f32, _any, _bf16 | _f32}, pt(bypass(), bypass(), use<3>(), bypass())},
@@ -305,6 +320,36 @@ const std::vector<ExecutorImplementation<FCAttrs>>& getImplementations() {
                 context,
                 false);
         })
+        OV_CPU_INSTANCE_ACL(
+            "fullyconnected_acl",
+            ExecutorType::Acl,
+            OperationType::FullyConnected,
+            ShapeTolerance::Agnostic,
+            // supports
+            [](const FCConfig& config) -> bool {
+                VERIFY(noSparseDecompression(config), UNSUPPORTED_SPARSE_WEIGHTS);
+                VERIFY(noWeightsDecompression(config), UNSUPPORTED_WEIGHTS_DECOMPRESSION);
+                return ACLFullyConnectedExecutor::supports(config);
+            },
+            // requiresFallback
+            [](const FCConfig& config) -> ov::optional<executor::Config<FCAttrs>> {
+                return requiresFallbackCommon(config,
+                                              aclFCTypeMapping,
+                                              aclFCLayoutConfig,
+                                              aclFullyConnectedMappingNotation);
+            },
+            // acceptsShapes
+            [](const MemoryArgs& memory) -> bool {
+                // @todo create syntactic sugar (functor) for shape agnostic lambda
+                return true;
+            },
+            // create
+            [](const FCAttrs& attrs,
+               const PostOps& postOps,
+               const MemoryArgs& memory,
+               const ExecutorContext::CPtr context) {
+                return std::make_shared<ACLFullyConnectedExecutor>(attrs, postOps, memory, context);
+            })
         OV_CPU_INSTANCE_SHL(
             "fullyconnected_shl",
             ExecutorType::Shl,
diff --git a/src/plugins/intel_cpu/src/nodes/executors/implementation_utils.hpp b/src/plugins/intel_cpu/src/nodes/executors/implementation_utils.hpp
index 2382f5e4091a9f..cd029283a09c50 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/implementation_utils.hpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/implementation_utils.hpp
@@ -83,5 +83,10 @@ size_t weiMemSize(const Config& config) {
     return memSize<Config, ARG_WEI>(config);
 }
 
+template <typename Config>
+size_t postOpsNumbers(const Config& config) {
+    return config.postOps.size();
+}
+
 }   // namespace intel_cpu
 }   // namespace ov
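Since the type-mapping DSL is terse, here is my reading of the two aclFCTypeMapping rows
added above: when src and wei are already f16/f32 they pass through unchanged and
bias/dst inherit the src type (use<0>), while any other precision combination falls back
to an all-f32 configuration (just<f32>). A standalone paraphrase, illustrative only and
not the plugin's actual DSL:

// Toy restatement of the fallback decision encoded by aclFCTypeMapping.
#include <array>

enum class Prec { f32, f16, i8, other };

std::array<Prec, 4> aclFcFallbackTypes(Prec src, Prec wei, Prec bia, Prec dst) {
    const bool floats = (src == Prec::f32 || src == Prec::f16) &&
                        (wei == Prec::f32 || wei == Prec::f16);
    if (floats)
        return {src, wei, src, src};  // bypass, bypass, use<0>, use<0>
    return {Prec::f32, Prec::f32, Prec::f32, Prec::f32};  // just<f32> for every argument
}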
diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/matmul.cpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/matmul.cpp
index dd58da28c8ad05..50b6255b1cfae6 100644
--- a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/matmul.cpp
+++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/matmul.cpp
@@ -125,7 +125,14 @@ void MatMulLayerCPUTest::SetUp() {
         rel_threshold = abs_threshold = 1e-2f;
     } else if (inference_precision == ov::element::f16) {
         inType = outType = netType = ElementType::f16;
+#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
+        // rel_threshold = abs_threshold = 1e-2f;
+        // The threshold is temporarily relaxed because of CVS-144523 and
+        // https://github.com/ARM-software/ComputeLibrary/issues/1112
+        rel_threshold = abs_threshold = 3e-1f;
+#else
         rel_threshold = abs_threshold = 1e-4f;
+#endif
     } else {
         inType = outType = netType;
         rel_threshold = 1e-4f;
diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/arm/matmul.cpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/arm/matmul.cpp
new file mode 100644
index 00000000000000..6d827614f80c54
--- /dev/null
+++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/arm/matmul.cpp
@@ -0,0 +1,147 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "custom/single_layer_tests/classes/matmul.hpp"
+#include "utils/cpu_test_utils.hpp"
+#include "utils/filter_cpu_info.hpp"
+#include "utils/fusing_test_utils.hpp"
+
+using namespace CPUTestUtils;
+
+namespace ov {
+namespace test {
+namespace MatMul {
+/* ============= MatMul ============= */
+namespace matmul {
+
+static const std::vector<CPUSpecificParams>& filterSpecificParamsFC() {
+    static const std::vector<CPUSpecificParams> specificParams = {CPUSpecificParams{{}, {}, {"acl"}, "acl"}};
+    return specificParams;
+}
+
+std::vector<fusingSpecificParams> fusingParamsSet2D_smoke {
+    emptyFusingSpec,
+    fusingBias,
+    fusingMultiplyPerChannel,
+    fusingRelu,
+    fusingTanh
+};
+
+const auto testParams2D_smoke = ::testing::Combine(::testing::Combine(::testing::ValuesIn(IS2D_smoke()),
+                                                                      ::testing::Values(ElementType::f32),
+                                                                      ::testing::Values(ElementType::undefined),
+                                                                      ::testing::Values(ElementType::undefined),
+                                                                      ::testing::Values(utils::InputLayerType::CONSTANT),
+                                                                      ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                                                      ::testing::Values(emptyAdditionalConfig())),
+                                                   ::testing::Values(MatMulNodeType::FullyConnected),
+                                                   ::testing::ValuesIn(fusingParamsSet2D_smoke),
+                                                   ::testing::ValuesIn(filterCPUInfo(filterSpecificParamsFC())));
+INSTANTIATE_TEST_SUITE_P(smoke_FC_2D, MatMulLayerCPUTest, testParams2D_smoke, MatMulLayerCPUTest::getTestCaseName);
+
+
+std::vector<fusingSpecificParams> fusingParamsSet2D_smoke_f16 {
+    emptyFusingSpec,
+    fusingBias,
+    fusingRelu
+};
+const auto testParams2D_smoke_f16 = ::testing::Combine(::testing::Combine(::testing::ValuesIn(IS2D_smoke()),
+                                                                          ::testing::Values(ElementType::f16),
+                                                                          ::testing::Values(ElementType::undefined),
+                                                                          ::testing::Values(ElementType::undefined),
+                                                                          ::testing::Values(utils::InputLayerType::CONSTANT),
+                                                                          ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                                                          ::testing::Values(
+                                                                              ov::AnyMap({ov::hint::inference_precision(ov::element::f16)}))),
+                                                       ::testing::Values(MatMulNodeType::FullyConnected),
+                                                       ::testing::ValuesIn(fusingParamsSet2D_smoke_f16),
+                                                       ::testing::ValuesIn(filterCPUInfo(filterSpecificParamsFC())));
+INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_f16, MatMulLayerCPUTest, testParams2D_smoke_f16, MatMulLayerCPUTest::getTestCaseName);
+
+std::vector<fusingSpecificParams> fusingParamsSet3D_smoke {
+    emptyFusingSpec,
+    fusingBias,
+    fusingMultiplyPerChannel,
+    fusingRelu,
+    fusingTanh
+};
+const auto fullyConnectedParams3D_smoke = ::testing::Combine(::testing::ValuesIn(IS3D_smoke()),
+                                                             ::testing::Values(ElementType::f32),
+                                                             ::testing::Values(ElementType::undefined),
+                                                             ::testing::Values(ElementType::undefined),
+                                                             ::testing::Values(utils::InputLayerType::CONSTANT),
+                                                             ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                                             ::testing::Values(emptyAdditionalConfig()));
+std::vector<fusingSpecificParams> fusingParamsSet3D_smoke_f16 {
+    emptyFusingSpec,
+    fusingBias,
+    fusingRelu
+};
+const auto fullyConnectedParams3D_smoke_f16 = ::testing::Combine(::testing::ValuesIn(IS3D_smoke()),
+                                                                 ::testing::Values(ElementType::f16),
+                                                                 ::testing::Values(ElementType::undefined),
+                                                                 ::testing::Values(ElementType::undefined),
+                                                                 ::testing::Values(utils::InputLayerType::CONSTANT),
+                                                                 ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                                                 ::testing::Values(
+                                                                     ov::AnyMap({ov::hint::inference_precision(ov::element::f16)})));
+const auto testParams3D_smoke = ::testing::Combine(fullyConnectedParams3D_smoke,
+                                                   ::testing::Values(MatMulNodeType::FullyConnected),
+                                                   ::testing::ValuesIn(fusingParamsSet3D_smoke),
+                                                   ::testing::ValuesIn(filterCPUInfo(filterSpecificParamsFC())));
+const auto testParams3D_smoke_f16 = ::testing::Combine(fullyConnectedParams3D_smoke_f16,
+                                                       ::testing::Values(MatMulNodeType::FullyConnected),
+                                                       ::testing::ValuesIn(fusingParamsSet3D_smoke_f16),
+                                                       ::testing::ValuesIn(filterCPUInfo(filterSpecificParamsFC())));
+INSTANTIATE_TEST_SUITE_P(smoke_FC_3D, MatMulLayerCPUTest, testParams3D_smoke, MatMulLayerCPUTest::getTestCaseName);
+INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_f16, MatMulLayerCPUTest, testParams3D_smoke_f16, MatMulLayerCPUTest::getTestCaseName);
+
+const std::vector<ShapeRelatedParams> IS = {
+    {static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {false, false}},
+    {static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {true, false}},
+    {static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {false, true}},
+    {static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {true, true}},
+};
+
+std::vector<fusingSpecificParams> fusingParamsSet4D_smoke {
+    emptyFusingSpec,
+    fusingMultiplyPerChannel,
+    fusingRelu,
+    fusingTanh
+};
+
+const auto testParams4D_smoke = ::testing::Combine(::testing::Combine(::testing::ValuesIn(IS),
+                                                                      ::testing::Values(ElementType::f32),
+                                                                      ::testing::Values(ElementType::undefined),
+                                                                      ::testing::Values(ElementType::undefined),
+                                                                      ::testing::Values(utils::InputLayerType::CONSTANT),
+                                                                      ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                                                      ::testing::Values(emptyAdditionalConfig())),
+                                                   ::testing::Values(MatMulNodeType::FullyConnected),
+                                                   ::testing::ValuesIn(fusingParamsSet4D_smoke),
+                                                   ::testing::ValuesIn(filterCPUInfo(filterSpecificParamsFC())));
+INSTANTIATE_TEST_SUITE_P(smoke_FC_4D, MatMulLayerCPUTest, testParams4D_smoke, MatMulLayerCPUTest::getTestCaseName);
+
+std::vector<fusingSpecificParams> fusingParamsSet4D_smoke_f16 {
+    emptyFusingSpec,
+    fusingRelu
+};
+
+const auto testParams4D_smoke_f16 = ::testing::Combine(::testing::Combine(::testing::ValuesIn(IS),
+                                                                          ::testing::Values(ElementType::f16),
+                                                                          ::testing::Values(ElementType::undefined),
+                                                                          ::testing::Values(ElementType::undefined),
+                                                                          ::testing::Values(utils::InputLayerType::CONSTANT),
+                                                                          ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                                                          ::testing::Values(
+                                                                              ov::AnyMap({ov::hint::inference_precision(ov::element::f16)}))),
+                                                       ::testing::Values(MatMulNodeType::FullyConnected),
+                                                       ::testing::ValuesIn(fusingParamsSet4D_smoke_f16),
+                                                       ::testing::ValuesIn(filterCPUInfo(filterSpecificParamsFC())));
+INSTANTIATE_TEST_SUITE_P(smoke_FC_4D_f16, MatMulLayerCPUTest, testParams4D_smoke_f16, MatMulLayerCPUTest::getTestCaseName);
+
+}  // namespace matmul
+}  // namespace MatMul
+}  // namespace test
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/matmul_decompress_convert.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/matmul_decompress_convert.cpp
index e72fbcf41a9b21..383385e9e5c1db 100644
--- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/matmul_decompress_convert.cpp
+++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/matmul_decompress_convert.cpp
@@ -277,9 +277,7 @@ std::vector<ov::AnyMap> filter_additional_config_bf16() {
 
 std::vector<CPUSpecificParams> filter_specific_params(bool trySetMlas) {
     std::vector<CPUSpecificParams> specificParams;
-#if defined(OPENVINO_ARCH_ARM)
-    specificParams.push_back(CPUSpecificParams{{}, {}, {"gemm_ref"}, {"gemm_ref"}});
-#elif defined(OPENVINO_ARCH_ARM64)
+#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
     specificParams.push_back(CPUSpecificParams{{}, {}, {"acl"}, "acl"});
 #else
     if (trySetMlas) {
diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/reshape_fc.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/reshape_fc.cpp
index 91b3da3b937713..86507093ca65bb 100644
--- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/reshape_fc.cpp
+++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/common/reshape_fc.cpp
@@ -113,9 +113,7 @@ static std::vector<fusingSpecificParams> filterFusingParams(const std::vector<f
 
 std::vector<fusingSpecificParams> fusingParamsSet{emptyFusingSpec, fusingBias, fusingMultiplyPerChannel};
 
-#if defined(OPENVINO_ARCH_ARM)
-const auto gemmParam = CPUSpecificParams{{}, {}, {"ref_any"}, "ref_any"};
-#elif defined(OPENVINO_ARCH_ARM64)
+#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
 const auto gemmParam = CPUSpecificParams{{}, {}, {"acl"}, "acl"};
 #elif OV_CPU_WITH_MLAS
 const auto gemmParam = CPUSpecificParams{{}, {}, {"gemm_mlas"}, "gemm_mlas"};