Skip to content

Commit b659080

Browse files
committed
[CPU] FC: fixed primitive caching for sparse decompression case
1 parent bc685ac commit b659080

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

src/plugins/intel_cpu/src/nodes/fullyconnected.cpp

+4-2
Original file line number | Diff line number | Diff line change
@@ -590,6 +590,9 @@ void FullyConnected::createDescriptorInternal(const dnnl::memory::desc &inputDes
590590
// We need to explicitly specify the memory descriptor to use sparse weights decompression
591591
dnnl::memory::desc wgh_candidate;
592592
if (useSparseWeights) {
593+
// If we pass the true nnzCount value every time, then primitive caching will not work in this scenario. But since we don't use this value
594+
// anywhere else (only sparse encoding is important to us), we can pass an arbitrary fixed value (greater than 0) so that caching works
595+
int nnzCount = 1;
593596
wgh_candidate = { DnnlExtensionUtils::convertToDnnlDims(getInputShapeAtPort(WEIGHTS_ID).getStaticDims()),
594597
wdt, memory::desc::packed(nnzCount) };
595598
} else {
@@ -930,10 +933,9 @@ bool FullyConnected::useSparseWeightsDecompression() {
930933
zerosCounts++;
931934
}
932935
}
933-
nnzCount = elementsCount - zerosCounts;
934936

935937
DEBUG_LOG(getName(), ", weightsData.size() = ", elementsCount, ", zerosCounts = ",
936-
zerosCounts, ", nnzCount = ", nnzCount);
938+
zerosCounts, ", nnzCount = ", elementsCount - zerosCounts);
937939

938940
weiSparseRate = static_cast<float>(zerosCounts) / static_cast<float>(elementsCount);
939941

src/plugins/intel_cpu/src/nodes/fullyconnected.h

-1
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,6 @@ class FullyConnected : public Node {
112112

113113
// sparse weights
114114
bool useSparseWeights = false;
115-
int nnzCount = -1;
116115
float minSparseRate = 1.f;
117116
float weiSparseRate = 0.f;
118117
bool useSparseWeightsDecompression();

0 commit comments

Comments
 (0)