Skip to content

Commit c3c409e

Browse files
[CPU] Prohibit fc avx2_vnni_2 decompression for bf16 input (openvinotoolkit#23638)
### Details: - The FC changes made in the scope of openvinotoolkit#20486 were missed when rebasing openvinotoolkit#20718. - The context is: even if the system and the node do support bf16 precision, we have to fall back to f32 in/out precision due to the lack of support for decompression with bf16 on avx2_vnni_2 in the oneDNN fork. - To cover this limitation, an additional type mapping parameter in the form of a std::function was introduced for disabling a particular type mapping entry using a runtime check (ISA support in this case). ### Tickets: - 122347 - 136163
1 parent 36d9360 commit c3c409e

File tree

3 files changed

+52
-7
lines changed

3 files changed

+52
-7
lines changed

src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ static const MappingNotation dnnlFCMappingNotation{ARG_SRC, ARG_WEI, ARG_BIAS, A
3838
using LayoutConfig = std::vector<LayoutType>;
3939
static const LayoutConfig dnnlFCLayoutConfig{LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp};
4040

41+
/// Runtime predicate that checks ISA availability through oneDNN.
///
/// Instances are stateless function objects returning bool, so they can be
/// stored in a std::function<bool()> and evaluated lazily at lookup time.
///
/// @tparam ISA oneDNN x64 ISA identifier forwarded to mayiuse().
template <dnnl::impl::cpu::x64::cpu_isa_t ISA>
struct Require {
    /// @return whatever dnnl::impl::cpu::x64::mayiuse(ISA) reports.
    // const-qualified: the check reads no state of the functor, so it must
    // also be callable on const objects.
    bool operator()() const {
        return dnnl::impl::cpu::x64::mayiuse(ISA);
    }
};
47+
4148
// clang-format off
4249
static const TypeMapping dnnlFCTypeMapping {
4350
// {src, wei, bia, dst} pt<src, wei, bias, dst>
@@ -54,7 +61,10 @@ static const TypeMapping dnnlFCTypeMapping {
5461
{{_u8 | _i8, _i8, _any, _f16}, pt(bypass(), bypass(), just<f32>(), just<f32>())},
5562
{{_u8 | _i8, _i8, _any, _u8 | _i8 | _i32 | _bf16 | _f32}, pt(bypass(), bypass(), use<3>(), bypass())},
5663
// compresses int weights (@todo more strict requirements for output precision?)
57-
{{_f32 | _bf16, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>())},
64+
{{_bf16, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>()),
65+
Require<dnnl::impl::cpu::x64::avx512_core_bf16>()}, // Ticket 122347
66+
{{_bf16, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(just<f32>(), bypass(), just<f32>(), just<f32>())},
67+
{{_f32, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>())},
5868
// @todo should we fallback to FPXX instead of _f32?
5969
{{_any, _any, _any, _any}, pt(just<f32>(), just<f32>(), just<f32>(), just<f32>())},
6070
// @todo explicitly cover configuration limitations for oneDNN on ARM

src/plugins/intel_cpu/src/nodes/executors/precision_translation.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@ InOutTypes getTypeConfiguration(const MemoryDescArgs& descriptors, const TypeMap
2121
});
2222

2323
for (const auto& entry : mapping) {
24-
const auto& pattern = entry.first;
24+
if (!entry.enabled())
25+
continue;
26+
27+
const auto& pattern = entry.mask();
2528
if (!match(pattern, types))
2629
continue;
2730

28-
const auto& translator = entry.second;
29-
return translator(types);
31+
return entry.translate(types);
3032
}
3133

3234
OPENVINO_THROW("Failed to create a type configuration for the provided memory descriptors");

src/plugins/intel_cpu/src/nodes/executors/precision_translation.hpp

+36-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
#include <cassert>
88
#include <functional>
9-
#include <utility>
109
#include <vector>
1110

1211
#include "nodes/executors/memory_arguments.hpp"
@@ -82,9 +81,43 @@ struct PortsTranslation {
8281
// pros: should be more efficient and safe
8382
// cons: more template instances (binary size) of the translation utility functions
8483
using InOutTypes = std::vector<ov::element::Type>;
85-
using PortsConfigurationImpl = std::function<InOutTypes(const InOutTypes&)>;
84+
using TypeTranslationFunction = std::function<InOutTypes(const InOutTypes&)>;
8685
using InOutTypeMask = std::vector<TypeMask>;
87-
using TypeMapping = std::vector<std::pair<InOutTypeMask, PortsConfigurationImpl>>;
86+
87+
class TypeMappingEntry {
88+
public:
89+
using EnabledPredicate = std::function<bool(void)>;
90+
91+
TypeMappingEntry(InOutTypeMask mask,
92+
TypeTranslationFunction translation,
93+
EnabledPredicate enabled = {})
94+
: m_mask(std::move(mask)),
95+
m_translation(std::move(translation)),
96+
m_enabled(std::move(enabled)) {}
97+
98+
const InOutTypeMask& mask() const {
99+
return m_mask;
100+
}
101+
102+
InOutTypes translate(const InOutTypes& types) const {
103+
if (m_translation)
104+
return m_translation(types);
105+
return {};
106+
}
107+
108+
bool enabled() const {
109+
if (m_enabled)
110+
return m_enabled();
111+
return true;
112+
}
113+
114+
private:
115+
InOutTypeMask m_mask;
116+
TypeTranslationFunction m_translation;
117+
EnabledPredicate m_enabled;
118+
};
119+
120+
using TypeMapping = std::vector<TypeMappingEntry>;
88121
using MappingNotation = std::vector<int>;
89122
using pt = PortsTranslation;
90123

0 commit comments

Comments
 (0)