Skip to content

Commit b9d98cb

Browse files
authored
[CPU] Weights caching: hash computation fix (openvinotoolkit#25625)
### Details: - *Modify the hash computation logic: take into account not only the dnnl desc format, but all of the desc info. The previous logic was not fully correct, since the hash could be equal for 2 descs with different `compute_compensations` flags -- this led to accuracy issues* - *The weights repacking hash computation logic is moved into one helper, which is reused across the CPU plugin code* ### Tickets: - *CVS-139671*
1 parent bb7f8d3 commit b9d98cb

File tree

7 files changed

+130
-50
lines changed

7 files changed

+130
-50
lines changed

src/plugins/intel_cpu/src/dnnl_extension_utils.cpp

+12-4
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,17 @@
33
//
44

55
#include "dnnl_extension_utils.h"
6-
#include "memory_desc/dnnl_blocked_memory_desc.h"
7-
#include "onednn/iml_type_mapper.h"
8-
#include "utils/general_utils.h"
6+
97
#include <common/primitive_desc.hpp>
108
#include <common/primitive_desc_iface.hpp>
119
#include <oneapi/dnnl/dnnl.hpp>
12-
1310
#include <vector>
1411

12+
#include "cpu_memory.h"
13+
#include "memory_desc/dnnl_blocked_memory_desc.h"
14+
#include "onednn/iml_type_mapper.h"
15+
#include "utils/general_utils.h"
16+
1517
using namespace dnnl;
1618

1719
namespace ov {
@@ -254,5 +256,11 @@ bool DnnlExtensionUtils::isUnarySupportedAsPostOp(Algorithm alg) {
254256
#endif
255257
}
256258

259+
std::string DnnlExtensionUtils::computeWeightsStringHash(const std::shared_ptr<const IMemory> memory,
                                                         const std::shared_ptr<DnnlMemoryDesc> dstDesc) {
    // Hash the complete oneDNN memory descriptor (not just its serialized format string),
    // so that two descriptors differing only in extra fields (e.g. the
    // compute_compensations flag) never collide in the weights cache.
    const auto md_hash = dnnl::impl::primitive_hashing::get_md_hash(*dstDesc->getDnnlDesc().get());
    // Combine with the source-data address to distinguish identical layouts of different weights.
    const auto data_addr = reinterpret_cast<uint64_t>(memory->getData());
    return std::to_string(md_hash) + "_" + std::to_string(data_addr);
}
264+
257265
} // namespace intel_cpu
258266
} // namespace ov

src/plugins/intel_cpu/src/dnnl_extension_utils.h

+8
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ namespace intel_cpu {
2222
class DnnlMemoryDesc;
2323
class DnnlBlockedMemoryDesc;
2424
class Shape;
25+
class IMemory;
2526

2627
class DnnlExtensionUtils {
2728
public:
@@ -101,6 +102,13 @@ class DnnlExtensionUtils {
101102
static dnnl_memory_desc_t clone_desc(const_dnnl_memory_desc_t cdesc);
102103
static const char* query_pd_info(const_dnnl_primitive_desc_t pd);
103104
static bool isUnarySupportedAsPostOp(Algorithm alg);
105+
/**
106+
* @brief Computes weights string hash based on weights memory and requested descriptor
107+
* @param memory Weights memory pointer
108+
* @param dstDesc descriptor defining weights representation after repacking
109+
* @return string hash
110+
*/
111+
static std::string computeWeightsStringHash(const std::shared_ptr<const IMemory> memory, const std::shared_ptr<DnnlMemoryDesc> dstDesc);
104112
};
105113

106114
} // namespace intel_cpu

src/plugins/intel_cpu/src/node.cpp

+3-14
Original file line numberDiff line numberDiff line change
@@ -831,16 +831,8 @@ void Node::prepareMemory(const DnnlMemoryDescPtr& intDesc, size_t indx) {
831831
MemoryPtr ptr;
832832
auto weightCache = context->getWeightsCache();
833833
if (weightCache != nullptr && memory::format_kind::blocked == intDesc->getDnnlDesc().get_format_kind()) {
834-
const auto& format = intDesc->serializeFormat();
835-
const uint64_t data_hash =
836-
weightCache->GetHashFunc().hash(static_cast<const unsigned char*>(internalBlob->getData()),
837-
internalBlob->getSize());
838-
839-
const std::string string_hash = name + "_" + std::to_string(indx)
840-
+ "_" + format
841-
+ "_" + std::to_string(internalBlob->getSize())
842-
+ "_" + std::to_string(data_hash);
843-
834+
const auto string_hash =
835+
name + "_" + std::to_string(indx) + "_" + DnnlExtensionUtils::computeWeightsStringHash(internalBlob, intDesc);
844836
ptr = *weightCache->findOrCreate(string_hash, create);
845837
} else {
846838
ptr = create();
@@ -905,10 +897,7 @@ MemoryPtr Node::prepareWeightMemory(DnnlMemoryDescPtr dstWeightDesc, DnnlMemoryD
905897

906898
auto weightCache = context->getWeightsCache();
907899
if (weightCache != nullptr) {
908-
const std::string string_hash = getName() + "_" + format
909-
+ "_" + std::to_string(edgeMem->getSize())
910-
+ "_" + std::to_string(*edgeMem->getDataAs<uint64_t>());
911-
900+
const auto string_hash = DnnlExtensionUtils::computeWeightsStringHash(edgeMem, dstWeightDesc);
912901
ptr = *weightCache->findOrCreate(string_hash, create);
913902
} else {
914903
ptr = create();

src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44

55
#include "nodes/executors/dnnl/dnnl_utils.hpp"
66

7+
#include <common/primitive_desc_iface.hpp>
78
#include <oneapi/dnnl/dnnl.hpp>
89

910
#include "cpu_memory.h"
1011
#include "memory_desc/dnnl_memory_desc.h"
12+
#include "memory_desc/cpu_memory_desc_utils.h"
1113
#include "nodes/executors/executor.hpp"
1214
#include "nodes/reorder.h"
15+
#include "utils/cpu_utils.hpp"
1316

1417
namespace ov {
1518
namespace intel_cpu {
@@ -86,8 +89,7 @@ MemoryPtr prepareWeightsMemory(const DnnlMemoryDescPtr srcWeightDesc,
8689
MemoryPtr ptr;
8790
if (globalWeightCache &&
8891
dnnl::memory::format_kind::blocked == dstWeightDesc->getDnnlDesc().get_format_kind()) {
89-
const std::string string_hash = format + "_" + std::to_string(weightsMem->getSize()) + "_" +
90-
std::to_string(reinterpret_cast<uint64_t>(weightsMem->getData()));
92+
const auto string_hash = DnnlExtensionUtils::computeWeightsStringHash(weightsMem, dstWeightDesc);
9193
ptr = *globalWeightCache->findOrCreate(string_hash, create);
9294
} else {
9395
ptr = create();

src/plugins/intel_cpu/src/weights_cache.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
namespace ov {
1111
namespace intel_cpu {
1212

13-
const SimpleDataHash WeightsSharing::simpleCRC;
14-
1513
WeightsSharing::SharedMemory::SharedMemory(
1614
std::unique_lock<std::mutex> && lock,
1715
const MemoryInfo::Ptr & memory,

src/plugins/intel_cpu/src/weights_cache.hpp

-28
Original file line numberDiff line numberDiff line change
@@ -22,31 +22,6 @@
2222

2323
namespace ov {
2424
namespace intel_cpu {
25-
26-
class SimpleDataHash {
27-
public:
28-
SimpleDataHash() {
29-
for (int i = 0; i < kTableSize; i++) {
30-
uint64_t c = i;
31-
for (int j = 0; j < 8; j++)
32-
c = ((c & 1) ? 0xc96c5795d7870f42 : 0) ^ (c >> 1);
33-
table[i] = c;
34-
}
35-
}
36-
// Computes 64-bit "cyclic redundancy check" sum, as specified in ECMA-182
37-
uint64_t hash(const unsigned char* data, size_t size) const {
38-
uint64_t crc = 0;
39-
for (size_t idx = 0; idx < size; idx++)
40-
crc = table[(unsigned char)crc ^ data[idx]] ^ (crc >> 8);
41-
42-
return ~crc;
43-
}
44-
45-
protected:
46-
static constexpr int kTableSize = 256;
47-
uint64_t table[kTableSize];
48-
};
49-
5025
/**
5126
* Caching store of Memory objects
5227
* Will return a cached object or create new one
@@ -94,12 +69,9 @@ class WeightsSharing {
9469

9570
SharedMemory::Ptr get(const std::string& key) const;
9671

97-
static const SimpleDataHash& GetHashFunc () { return simpleCRC; }
98-
9972
protected:
10073
mutable std::mutex guard;
10174
std::unordered_map<std::string, MemoryInfo::Ptr> sharedWeights;
102-
static const SimpleDataHash simpleCRC;
10375
};
10476

10577
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include <regex>
6+
7+
#include "common_test_utils/node_builders/constant.hpp"
8+
#include "common_test_utils/node_builders/fake_quantize.hpp"
9+
#include "common_test_utils/node_builders/reshape.hpp"
10+
#include "openvino/openvino.hpp"
11+
#include "shared_test_classes/base/ov_subgraph.hpp"
12+
13+
namespace ov {
14+
namespace test {
15+
16+
enum class FQInterval { U8, I8 };
17+
inline std::ostream& operator<<(std::ostream& os, FQInterval interval) {
18+
switch (interval) {
19+
case FQInterval::U8:
20+
os << "U8";
21+
break;
22+
case FQInterval::I8:
23+
os << "I8";
24+
break;
25+
default:
26+
OPENVINO_THROW("Unknown FQInterval");
27+
}
28+
return os;
29+
}
30+
31+
typedef std::tuple<InputShape, InputShape, FQInterval, FQInterval> QuantizedMatMulsWithSharedWeightsParans;
32+
33+
/* This test verifies the correctness of the hash function computation for the shared weights.
34+
Specifically, it checks that when one op requires compensations computation and second one does not,
35+
the resulting hashes are not identical, and the weights are repacked for each op separately
36+
*/
37+
class QuantizedMatMulsWithSharedWeightsTest
38+
: public testing::WithParamInterface<QuantizedMatMulsWithSharedWeightsParans>,
39+
virtual public SubgraphBaseTest {
40+
public:
41+
static std::string getTestCaseName(const testing::TestParamInfo<QuantizedMatMulsWithSharedWeightsParans>& obj) {
42+
InputShape shape1;
43+
InputShape shape2;
44+
FQInterval interval1;
45+
FQInterval interval2;
46+
std::tie(shape1, shape2, interval1, interval2) = obj.param;
47+
std::ostringstream result;
48+
result << "IS1=" << shape1 << "IS2=" << shape2 << "FQInterval1=" << interval1 << "FQInterval2=" << interval2;
49+
return result.str();
50+
}
51+
52+
void SetUp() override {
53+
targetDevice = ov::test::utils::DEVICE_CPU;
54+
abs_threshold = 1e-4;
55+
56+
InputShape shape1;
57+
InputShape shape2;
58+
FQInterval interval1;
59+
FQInterval interval2;
60+
std::tie(shape1, shape2, interval1, interval2) = this->GetParam();
61+
init_input_shapes({shape1, shape2});
62+
63+
const auto weights = ov::test::utils::make_constant(ov::element::i8, {16, 16});
64+
const auto convert = std::make_shared<ov::op::v0::Convert>(weights, ov::element::f32);
65+
const auto scale = ov::test::utils::make_constant(ov::element::f32, {16, 1}, ov::test::utils::InputGenerateData(0, 1, 5));
66+
const auto mul = std::make_shared<ov::op::v1::Multiply>(convert, scale);
67+
68+
auto build_fq = [](const ov::Output<ov::Node>& parent, FQInterval interval_type) {
69+
const auto low = interval_type == FQInterval::I8 ? std::vector<float>{-12.8f} : std::vector<float>{0.f};
70+
const auto high = interval_type == FQInterval::I8 ? std::vector<float>{12.7f} : std::vector<float>{25.5f};
71+
return ov::test::utils::make_fake_quantize(parent, ov::element::f32, 256, {1, 1, 1, 1}, low, high, low, high);
72+
};
73+
74+
const auto param1 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, inputDynamicShapes[0]);
75+
const auto fq1 = build_fq(param1, interval1);
76+
const auto mm1 = std::make_shared<ov::op::v0::MatMul>(fq1, mul, false, true);
77+
78+
const auto param2 = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, inputDynamicShapes[1]);
79+
const auto fq2 = build_fq(param2, interval2);
80+
const auto mm2 = std::make_shared<ov::op::v0::MatMul>(fq2, mul, false, true);
81+
82+
function = std::make_shared<ov::Model>(ov::OutputVector{mm1, mm2}, ov::ParameterVector{param1, param2});
83+
}
84+
};
85+
86+
TEST_P(QuantizedMatMulsWithSharedWeightsTest, CompareWithRefs) {
87+
run();
88+
}
89+
90+
namespace {
91+
92+
std::vector<InputShape> shapes1{{{-1, -1, -1, 16}, {{1, 1, 15, 16}, {1, 1, 12, 16}, {1, 1, 15, 16}}}};
93+
std::vector<InputShape> shapes2{{{-1, -1, -1, 16}, {{1, 1, 12, 16}, {1, 1, 15, 16}, {1, 1, 12, 16}}}};
94+
INSTANTIATE_TEST_SUITE_P(smoke_CustomTest, QuantizedMatMulsWithSharedWeightsTest,
95+
::testing::Combine(
96+
::testing::ValuesIn(shapes1),
97+
::testing::ValuesIn(shapes2),
98+
::testing::Values(FQInterval::U8, FQInterval::I8),
99+
::testing::Values(FQInterval::U8, FQInterval::I8)),
100+
QuantizedMatMulsWithSharedWeightsTest::getTestCaseName);
101+
} // namespace
102+
} // namespace test
103+
} // namespace ov

0 commit comments

Comments
 (0)