Skip to content

Commit e8b25ab

Browse files
committed
SDPA support bf16 type on avx2_vnni_2 platforms with CPU ref impl
1 parent 9b01da1 commit e8b25ab

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/plugins/intel_cpu/src/nodes/scaled_attn.cpp

+11-4
Original file line numberDiff line numberDiff line change
@@ -1238,10 +1238,17 @@ void ScaledDotProductAttention::createPrimitive() {
12381238
std::shared_ptr<Executor> executor = nullptr;
12391239
#ifdef OPENVINO_ARCH_X86_64
12401240
if (rtPrecision == ov::element::bf16) {
1241-
executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(context,
1242-
m_key_quant_param.groupSize,
1243-
m_value_quant_param.groupSize,
1244-
m_key_quant_param.isByChannel);
1241+
if (ov::with_cpu_x86_bfloat16()) {
1242+
executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(context,
1243+
m_key_quant_param.groupSize,
1244+
m_value_quant_param.groupSize,
1245+
m_key_quant_param.isByChannel);
1246+
} else {
1247+
executor = std::make_shared<AttentionExecutor<KT_REF, ov::bfloat16>>(context,
1248+
m_key_quant_param.groupSize,
1249+
m_value_quant_param.groupSize,
1250+
m_key_quant_param.isByChannel);
1251+
}
12451252
} else if (rtPrecision == ov::element::f16) {
12461253
if (with_cpu_x86_avx512_core_fp16()) {
12471254
executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::float16>>(context,

0 commit comments

Comments
 (0)