|
5 | 5 | #include "shl_fullyconnected.hpp"
|
6 | 6 |
|
7 | 7 | #include "csinn/csi_nn.h"
|
| 8 | +#include "rvv/rvv.h" |
8 | 9 | #include "nodes/executors/executor.hpp"
|
9 | 10 | #include "nodes/executors/memory_arguments.hpp"
|
| 11 | +#include "nodes/common/cpu_memcpy.h" |
10 | 12 | #include "utils/debug_capabilities.h"
|
11 | 13 |
|
12 | 14 | namespace ov {
|
13 | 15 | namespace intel_cpu {
|
| 16 | +namespace { |
| 17 | +static MemoryPtr prepareWeightMemory(const MemoryPtr weightsMemory, const ExecutorContext::CPtr context) { |
| 18 | + DEBUG_LOG("ShlFCExecutor: prepack weights"); |
| 19 | + |
| 20 | + auto create = [&]() { |
| 21 | + const auto& weiDesc = weightsMemory->getDescPtr(); |
| 22 | + MemoryPtr _ptr = std::make_shared<Memory>(context->getEngine(), |
| 23 | + intel_cpu::CpuBlockedMemoryDesc(ov::element::f32, weightsMemory->getShape())); |
| 24 | + cpu_parallel_memcpy(_ptr->getData(), weightsMemory->getData(), weightsMemory->getSize()); |
| 25 | + DEBUG_LOG("ShlFCExecutor: cache miss, perform packing"); |
| 26 | + const auto repack_wei = ShlTensor(ShlSession(), precisionToShlDataType(weiDesc->getPrecision()), getShlDataLayoutByMemoryDesc(weiDesc, true), |
| 27 | + weiDesc->getShape().getStaticDims(), _ptr->getData()); |
| 28 | + shl_rvv_fc_gemm_reorder_weight_fp32(repack_wei.get()); |
| 29 | + return _ptr; |
| 30 | + }; |
| 31 | + |
| 32 | + auto weightCache = context->getWeightsCache(); |
| 33 | + if (weightCache != nullptr) { |
| 34 | + const auto& wgtDims = weightsMemory->getStaticDims(); |
| 35 | + std::string format = "gemm_shl_" + std::to_string(wgtDims[0]) + "_" + std::to_string(wgtDims[1]); |
| 36 | + const std::string string_hash = format + "_" + std::to_string(weightsMemory->getSize()) + "_" + |
| 37 | + std::to_string(reinterpret_cast<uint64_t>(weightsMemory->getData())); |
| 38 | + DEBUG_LOG("ShlFCExecutor: findOrCreate, string_hash: ", string_hash); |
| 39 | + return *weightCache->findOrCreate(string_hash, create); |
| 40 | + } |
| 41 | + |
| 42 | + DEBUG_LOG("ShlFCExecutor: Weights cache is not available"); |
| 43 | + return create(); |
| 44 | +} |
| 45 | +} // namespace |
14 | 46 |
|
15 | 47 | bool ShlFCExecutor::supports(const FCConfig& config) {
|
16 | 48 | if (config.attrs.weightsNonTransposed) {
|
@@ -53,7 +85,8 @@ bool ShlFCExecutor::supports(const FCConfig& config) {
|
53 | 85 | ShlFCExecutor::ShlFCExecutor(const FCAttrs& attrs,
|
54 | 86 | const PostOps& postOps,
|
55 | 87 | const MemoryArgs& memory,
|
56 |
| - const ExecutorContext::CPtr context) { |
| 88 | + const ExecutorContext::CPtr context) |
| 89 | + : packedWeights(prepareWeightMemory(memory.at(ARG_WEI), context)) { |
57 | 90 | const auto& srcDesc = memory.at(ARG_SRC)->getDescPtr();
|
58 | 91 | const auto& weiDesc = memory.at(ARG_WEI)->getDescPtr();
|
59 | 92 | const auto& dstDesc = memory.at(ARG_DST)->getDescPtr();
|
@@ -93,7 +126,7 @@ bool ShlFCExecutor::update(const MemoryArgs& memory) {
|
93 | 126 |
|
94 | 127 | void ShlFCExecutor::execute(const MemoryArgs& memory) {
|
95 | 128 | src.setData(memory.at(ARG_SRC)->getData());
|
96 |
| - wei.setData(memory.at(ARG_WEI)->getData()); |
| 129 | + wei.setData(packedWeights->getData()); |
97 | 130 | dst.setData(memory.at(ARG_DST)->getData());
|
98 | 131 | if (with_bias) {
|
99 | 132 | bias.setData(memory.at(ARG_BIAS)->getData());
|
|
0 commit comments