
Commit 2d14fa9

[CPU][ARM][x64]Snippets MatMul via brgemm emitter and executor (#28304)
### Details:
- *Snippets MatMul via block-wise brgemm emitter and executor on aarch64 with TPP*
- *Snippets MatMul via block-wise brgemm emitter and executor on x64 with TPP*

### Tickets:
- *CVS-151344*
1 parent 43fb02a commit 2d14fa9


48 files changed (+946, -420 lines)
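The Details above describe moving Snippets MatMul onto a block-wise brgemm emitter and executor, where each output block is either overwritten or accumulated depending on `beta`. A minimal stand-alone sketch of that accumulation semantics follows (hypothetical names and a plain reference loop, not the plugin code or the TPP kernel):

```cpp
// Illustration only: how a block-wise GEMM uses `beta` to either overwrite
// (beta = 0) or accumulate into (beta = 1) the C block when K is split into blocks.
#include <cstdint>
#include <vector>

// Reference micro-kernel: C[M x N] = beta * C + A[M x K_blk] * B[K_blk x N]
static void gemm_block(const float* A, const float* B, float* C,
                       int64_t M, int64_t N, int64_t K_blk,
                       int64_t LDA, int64_t LDB, int64_t LDC, float beta) {
    for (int64_t m = 0; m < M; ++m) {
        for (int64_t n = 0; n < N; ++n) {
            float acc = beta * C[m * LDC + n];
            for (int64_t k = 0; k < K_blk; ++k)
                acc += A[m * LDA + k] * B[k * LDB + n];
            C[m * LDC + n] = acc;
        }
    }
}

int main() {
    const int64_t M = 4, N = 4, K = 8, K_blk = 2;
    std::vector<float> A(M * K, 1.0f), B(K * N, 1.0f), C(M * N, 0.0f);
    // Loop over K blocks: the first executed block gets beta = 0, the rest beta = 1,
    // which is the decision the new BrgemmKernelExecutorHelper::get_beta() encodes
    // by inspecting previously executed loops.
    for (int64_t k0 = 0; k0 < K; k0 += K_blk) {
        const float beta = (k0 == 0) ? 0.0f : 1.0f;
        gemm_block(A.data() + k0, B.data() + k0 * N, C.data(), M, N, K_blk, K, N, N, beta);
    }
    return 0;  // each C element now equals K (= 8)
}
```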

cmake/features.cmake  (+1 -1)

@@ -52,7 +52,7 @@ ov_dependent_option (ENABLE_GPU_DEBUG_CAPS "enable GPU debug capabilities at run
 ov_dependent_option (ENABLE_CPU_DEBUG_CAPS "enable CPU debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS;ENABLE_INTEL_CPU" OFF)
 ov_dependent_option (ENABLE_SNIPPETS_DEBUG_CAPS "enable Snippets debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS" OFF)
 
-ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU AND X86_64" OFF)
+ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU AND (X86_64 OR AARCH64)" OFF)
 
 ov_option (ENABLE_PROFILING_ITT "Build with ITT tracing. Optionally configure pre-built ittnotify library though INTEL_VTUNE_DIR variable." OFF)
 

src/plugins/intel_cpu/CMakeLists.txt  (+5 -1)

@@ -206,7 +206,9 @@ if(NOT X86_64)
                          ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/x64/*
                          ${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/plugin/x64/*
                          ${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/snippets/x64/*
-                         ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/x64/*)
+                         ${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/tpp/x64/*
+                         ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/x64/*
+                         ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/tpp/x64/*)
 endif()
 
 if (AARCH64)
@@ -218,7 +220,9 @@ endif()
 
 if(NOT (AARCH64 OR ARM))
     list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/arm/*
+                              ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/tpp/aarch64/*
                               ${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/plugin/aarch64/*
+                              ${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/tpp/aarch64/*
                               ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/executors/aarch64/*
                               ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/aarch64/*)
 endif()

src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp  (+15 -3)

@@ -1,4 +1,4 @@
-// Copyright (C) 2024 Intel Corporation
+// Copyright (C) 2024-2025 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
@@ -10,6 +10,7 @@
 #include "emitters/snippets/aarch64/jit_kernel_emitter.hpp"
 #include "emitters/snippets/aarch64/jit_loop_emitters.hpp"
 #include "emitters/snippets/aarch64/jit_memory_emitters.hpp"
+#include "emitters/snippets/cpu_kernel_executor_table.hpp"
 #include "emitters/snippets/cpu_runtime_configurator.hpp"
 #include "emitters/utils.hpp"
 #include "jit_snippets_emitters.hpp"
@@ -24,12 +25,17 @@
 #include "transformations/cpu_opset/common/op/swish_cpu.hpp"
 #include "transformations/snippets/common/op/fused_mul_add.hpp"
 
+#ifdef SNIPPETS_LIBXSMM_TPP
+# include "emitters/tpp/aarch64/jit_brgemm_emitter.hpp"
+# include "transformations/tpp/common/op/brgemm.hpp"
+#endif
+
 namespace ov {
 
-#define CREATE_SNIPPETS_EMITTER(e_type) \
+#define CREATE_SNIPPETS_EMITTER(e_type, ...) \
     { \
         [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
-            return std::make_shared<e_type>(h.get(), isa, expr); \
+            return std::make_shared<e_type>(h.get(), isa, expr, ##__VA_ARGS__); \
        }, \
        [](const std::shared_ptr<ov::Node>& n) -> std::set<std::vector<element::Type>> { \
            return e_type::get_supported_precisions(n); \
@@ -201,6 +207,12 @@ CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
     jitters[ov::intel_cpu::SwishNode::get_type_info_static()] = CREATE_CPU_EMITTER(jit_swish_emitter);
     jitters[ov::op::v0::Tanh::get_type_info_static()] = CREATE_CPU_EMITTER(jit_tanh_emitter);
 
+#ifdef SNIPPETS_LIBXSMM_TPP
+    // brgemm
+    jitters[ov::intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] =
+        CREATE_SNIPPETS_EMITTER(jit_brgemm_emitter, configurator->get_kernel_executor_table(), compiled_kernel_cache);
+#endif
+
     // control flow
     jitters[snippets::op::KernelStatic::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_kernel_static_emitter);
     jitters[snippets::op::KernelDynamic::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_kernel_dynamic_emitter);
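The `CREATE_SNIPPETS_EMITTER` macro above becomes variadic so a registration entry can forward extra constructor arguments (here the kernel executor table and the compiled-kernel cache) to emitters that need them, while argument-free emitters keep using the same macro. A condensed sketch of the forwarding mechanism, with hypothetical stand-in types (`##__VA_ARGS__` is the same GNU/Clang/MSVC preprocessor extension the original macro relies on):

```cpp
#include <functional>
#include <iostream>
#include <memory>

struct Expr {};                                   // stand-in for snippets::lowered::ExpressionPtr
struct Emitter { virtual ~Emitter() = default; };

struct SimpleEmitter : Emitter {
    explicit SimpleEmitter(const Expr&) {}
};
struct BrgemmLikeEmitter : Emitter {
    BrgemmLikeEmitter(const Expr&, int executor_table, int kernel_cache) {
        std::cout << "extra args: " << executor_table << ", " << kernel_cache << "\n";
    }
};

// `...` plus `##__VA_ARGS__` lets one macro cover both zero and several extra arguments.
#define CREATE_EMITTER(e_type, ...)                              \
    [](const Expr& expr) -> std::shared_ptr<Emitter> {           \
        return std::make_shared<e_type>(expr, ##__VA_ARGS__);    \
    }

int main() {
    auto make_simple = CREATE_EMITTER(SimpleEmitter);
    auto make_brgemm = CREATE_EMITTER(BrgemmLikeEmitter, /*executor_table=*/1, /*kernel_cache=*/2);
    Expr e;
    auto a = make_simple(e);   // constructed with only the expression
    auto b = make_brgemm(e);   // constructed with the forwarded extra arguments
    (void)a; (void)b;
    return 0;
}
```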
brgemm_generic.cpp (new file)  (+218)

@@ -0,0 +1,218 @@
// Copyright (C) 2020-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "brgemm_generic.hpp"

#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
#include "utils/general_utils.h"

#define PRINT(X) ss << #X << " = " << X << "\n"
#define EQ(X) X == rhs.X
#define HASH(X) seed = dnnl::impl::hash_combine(seed, X)

namespace ov::intel_cpu {

bool BrgemmGenericKernelConfig::is_completed() const {
    return !one_of(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC) || is_empty();
}

bool BrgemmGenericKernelConfig::is_empty() const {
    return everyone_is(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC, m_beta);
}

bool BrgemmGenericKernelConfig::operator==(const BrgemmGenericKernelConfig& rhs) const {
    return EQ(m_beta) && EQ(m_M) && EQ(m_N) && EQ(m_K) && EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC);
}

void BrgemmGenericKernelConfig::update(int64_t M,
                                       int64_t N,
                                       int64_t K,
                                       int64_t LDA,
                                       int64_t LDB,
                                       int64_t LDC,
                                       float beta) {
    // If M/N/K is zero, it means that Brgemm won't be executed (in a Loop with work_amount = 0, for example)
    // To process this case, we have to make this Config empty (nullify runtime parameters)
    if (one_of(0, M, N, K)) {
        m_M = 0;
        m_N = 0;
        m_K = 0;
        m_LDA = 0;
        m_LDB = 0;
        m_LDC = 0;
        m_beta = 0;
    } else {
        m_M = M;
        m_N = N;
        m_K = K;
        m_LDA = LDA;
        m_LDB = LDB;
        m_LDC = LDC;
        m_beta = beta;
    }
}

size_t BrgemmGenericKernelConfig::compute_hash() const {
    size_t seed = 0;
    HASH(m_M);
    HASH(m_N);
    HASH(m_K);
    HASH(m_LDA);
    HASH(m_LDB);
    HASH(m_LDC);
    HASH(m_beta);
    return seed;
}

#ifdef SNIPPETS_DEBUG_CAPS
std::string BrgemmGenericKernelConfig::to_string() const {
    std::stringstream ss;
    PRINT(m_M);
    PRINT(m_N);
    PRINT(m_K);
    PRINT(m_LDA);
    PRINT(m_LDB);
    PRINT(m_LDC);
    PRINT(m_beta);
    return ss.str();
}
#endif

float BrgemmKernelExecutorHelper::get_beta(
    const ov::snippets::lowered::LoopManagerPtr& loop_manager,
    int loop_id,
    const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info) {
    // Find all Expanded loops with the same Unified loop information -> they were decomposed from this Unified Loop.
    // Note that LoopInfo are normalized and sorted (due to the NormalizedLoopIDs pass).
    // It means that previously executed Loops have Loop IDs less than the current Loop ID.
    // - If there is an executed Loop (work_amount > 0) evaluated before the current one -> the current Brgemm should
    //   have `beta = 1`.
    // - If there is no such Loop -> the current executed Brgemm should have `beta = 0`.
    if (loop_id > 0) {
        const auto& current_unified_loop_info = current_expanded_loop_info->get_unified_loop_info();
        // Check the previous Loops
        --loop_id;
        while (loop_id >= 0) {
            const auto& expanded_loop_info =
                loop_manager->get_loop_info<ov::snippets::lowered::ExpandedLoopInfo>(loop_id);
            if (expanded_loop_info->get_unified_loop_info() != current_unified_loop_info) {
                return 0;
            }
            if (expanded_loop_info->get_work_amount() > 0) {
                // there is a previously executed Brgemm with `beta = 0` -> the current Brgemm should have `beta = 1`
                return 1;
            }
            --loop_id;
        }
    }
    return 0;
}

std::tuple<int64_t, int64_t, int64_t, float> BrgemmKernelExecutorHelper::get_runtime_brgemm_params(
    const ov::snippets::lowered::ExpressionPtr& expr,
    const ov::snippets::lowered::LinearIRCPtr& linear_ir) {
    const auto& input_pds = expr->get_input_port_descriptors();
    const auto& output_pds = expr->get_output_port_descriptors();
    OV_CPU_JIT_EMITTER_ASSERT((input_pds.size() == 2 || input_pds.size() == 3) && output_pds.size() == 1,
                              "Invalid number of in/out port descriptors");

    const auto& in0_shape = snippets::utils::get_planar_vdims(input_pds[0]->get_shape(), input_pds[0]->get_layout());
    const auto& in1_shape = snippets::utils::get_planar_vdims(input_pds[1]->get_shape(), input_pds[1]->get_layout());
    const auto& in0_subtensor = input_pds[0]->get_subtensor();
    const auto& in1_subtensor = input_pds[1]->get_subtensor();

    // Need to update M, K, N
    // 1. If the original value in the subtensor is `FULL_DIM`, it means that the
    //    Brgemm block should process the full tensor by this dim -> take the dimension from the shape
    // 2. Otherwise, the Brgemm block processes part of the tensor by this dim
    //    (there is blocking by this dimension) -> take it from the Loop increment

    auto M = *++in0_subtensor.rbegin();
    auto K = *in0_subtensor.rbegin();
    auto N = *in1_subtensor.rbegin();

    size_t loop_idx = 0;
    const auto& loop_ids = expr->get_loop_ids();
    const auto& loop_manager = linear_ir->get_loop_manager();
    auto get_loop_info = [&]() {
        OPENVINO_ASSERT(loop_idx < loop_ids.size(), "Loop is missed");
        return loop_manager->get_loop_info<ov::snippets::lowered::ExpandedLoopInfo>(loop_ids[loop_idx++]);
    };

    /* ------- Dimension M ----------*/
    if (ov::snippets::utils::is_full_dim_value(M)) {
        M = *++in0_shape.rbegin();
    } else {
        const auto& current_expanded_loop_info = get_loop_info();
        const auto& in_ports = current_expanded_loop_info->get_input_ports();
        const auto& out_ports = current_expanded_loop_info->get_output_ports();
        // Quick validation check: Should we check that the port is really a Brgemm port?
        // If BrgemmCopyB is in the Loop by M -> the first input port will be BrgemmCopyB with `incremented=false`;
        // to avoid extra checks, we validate only the first input port
        auto check_port = [&](const ov::snippets::lowered::LoopPort& p) {
            return p.get_dim_idx() == 1 && p.is_processed();
        };
        OPENVINO_ASSERT(
            in_ports.size() > 1 && check_port(in_ports[0]) && out_ports.size() == 1 && check_port(out_ports[0]),
            "Incorrect Loop by Brgemm dimension M");
        M = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0;
        input_pds[0]->set_subtensor_dim(1, M);
        output_pds[0]->set_subtensor_dim(1, M);
    }

    /* ------- Dimension N ----------*/
    if (ov::snippets::utils::is_full_dim_value(N)) {
        N = *in1_shape.rbegin();
    } else {
        const auto& current_expanded_loop_info = get_loop_info();
        const auto& in_ports = current_expanded_loop_info->get_input_ports();
        const auto& out_ports = current_expanded_loop_info->get_output_ports();
        // Quick validation check: Should we check that the port is really a Brgemm port?
        auto check_port = [&](const ov::snippets::lowered::LoopPort& p) {
            return p.get_dim_idx() == 0 && p.is_processed();
        };
        OPENVINO_ASSERT(in_ports.size() >= 2 && !in_ports.front().is_processed() &&
                            std::all_of(in_ports.cbegin() + 1, in_ports.cend(), check_port) && out_ports.size() == 1 &&
                            check_port(out_ports.back()),
                        "Incorrect Loop by Brgemm dimension N");
        N = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0;
        input_pds[1]->set_subtensor_dim(0, N);
        output_pds[0]->set_subtensor_dim(0, N);
    }

    /* ------- Dimension K ----------*/
    // 1. If the Brgemm block processes the full dimension K -> `beta = 0`
    // 2. If the Brgemm block processes part of the dimension K (there is blocking), we need to find
    //    the first executed Brgemm block among the Loops which iterate through dimension K (work_amount > 0).
    //    The first of them will have `beta = 0`, the others `beta = 1`
    float beta = 0;
    if (ov::snippets::utils::is_full_dim_value(K)) {
        K = *in0_shape.rbegin();
    } else {
        const auto& current_expanded_loop_info = get_loop_info();
        const auto& in_ports = current_expanded_loop_info->get_input_ports();
        const auto& out_ports = current_expanded_loop_info->get_output_ports();
        // Quick validation check: Should we check that the port is really a Brgemm port?
        OPENVINO_ASSERT(in_ports.size() >= 2 && in_ports.front().get_dim_idx() == 0 &&
                            in_ports.front().is_processed() && in_ports.back().get_dim_idx() == 1 &&
                            in_ports.back().is_processed() && out_ports.size() == 1 &&
                            !out_ports.front().is_processed(),
                        "Incorrect Loop by Brgemm dimension K");
        K = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0;
        input_pds[0]->set_subtensor_dim(0, K);
        input_pds[1]->set_subtensor_dim(1, K);
        if (K > 0) {
            beta = get_beta(loop_manager, static_cast<int>(loop_ids.back()), current_expanded_loop_info);
        }
    }

    return std::make_tuple(M, N, K, beta);
}

#undef PRINT
#undef EQ
#undef HASH

}  // namespace ov::intel_cpu
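In `get_runtime_brgemm_params()`, each of M/N/K is resolved in one of two ways: when the subtensor value is `FULL_DIM` the dimension is not blocked and is taken from the planar shape; otherwise the dimension is blocked and the effective size is the increment of the surrounding loop (or zero when that loop has no work). A toy model of that rule, using hypothetical simplified types rather than the Snippets loop-info API:

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Assumed sentinel standing in for the Snippets FULL_DIM marker.
constexpr int64_t FULL_DIM = std::numeric_limits<int64_t>::max();

struct LoopInfoLite {
    int64_t work_amount;  // how many iterations the expanded loop performs
    int64_t increment;    // block size processed per iteration
};

int64_t resolve_dim(int64_t subtensor_value, int64_t shape_dim, const LoopInfoLite* blocking_loop) {
    if (subtensor_value == FULL_DIM)
        return shape_dim;                      // no blocking: process the whole dimension at once
    assert(blocking_loop != nullptr);          // a blocked dimension must have a surrounding loop
    return blocking_loop->work_amount > 0 ? blocking_loop->increment : 0;
}

int main() {
    LoopInfoLite k_loop{/*work_amount=*/512, /*increment=*/64};
    int64_t M = resolve_dim(FULL_DIM, 128, nullptr);  // M not blocked -> 128 (from shape)
    int64_t K = resolve_dim(64, 512, &k_loop);        // K blocked     -> 64 (loop increment)
    (void)M; (void)K;
    return 0;
}
```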
brgemm_generic.hpp (new file)  (+76)

@@ -0,0 +1,76 @@
// Copyright (C) 2020-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "emitters/snippets/cpu_kernel_executor_table.hpp"
#include "emitters/utils.hpp"
#include "snippets/lowered/loop_info.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "utils/general_utils.h"

namespace ov::intel_cpu {

struct BrgemmGenericKernelConfig : public snippets::KernelExecutorBase::GenericConfig {
public:
    BrgemmGenericKernelConfig() = default;

    bool is_completed() const override;
    bool is_empty() const;

    virtual void update(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, float beta);

    bool operator==(const BrgemmGenericKernelConfig& rhs) const;
    bool operator!=(const BrgemmGenericKernelConfig& rhs) const {
        return !(*this == rhs);
    }

    int64_t get_M() const {
        return m_M;
    }
    int64_t get_N() const {
        return m_N;
    }
    int64_t get_K() const {
        return m_K;
    }
    float get_beta() const {
        return m_beta;
    }
    int64_t get_LDA() const {
        return m_LDA;
    }
    int64_t get_LDB() const {
        return m_LDB;
    }
    int64_t get_LDC() const {
        return m_LDC;
    }

#ifdef SNIPPETS_DEBUG_CAPS
    std::string to_string() const override;
#endif

protected:
    size_t compute_hash() const;

    int64_t m_M{0}, m_N{0}, m_K{0}, m_LDA{0}, m_LDB{0}, m_LDC{0};
    float m_beta{0};
};

class BrgemmKernelExecutorHelper {
public:
    virtual ~BrgemmKernelExecutorHelper() = default;

    static float get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager,
                          int loop_id,
                          const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info);

    // This function returns the M, N, K dimensions and beta of brgemm as a tuple, based on loop info in linear_ir.
    static std::tuple<int64_t, int64_t, int64_t, float> get_runtime_brgemm_params(
        const ov::snippets::lowered::ExpressionPtr& expr,
        const ov::snippets::lowered::LinearIRCPtr& linear_ir);
};

}  // namespace ov::intel_cpu
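`BrgemmGenericKernelConfig` bundles the runtime brgemm parameters with `operator==` and `compute_hash()`, which lets a kernel executor reuse an already-compiled kernel whenever the same configuration reoccurs and recompile only on a real change of M/N/K/LD*/beta. A simplified illustration of such config-keyed caching, with hypothetical types rather than the plugin's kernel-executor-table API:

```cpp
#include <cstdint>
#include <functional>
#include <memory>
#include <unordered_map>

struct MiniBrgemmConfig {
    int64_t M{0}, N{0}, K{0}, LDA{0}, LDB{0}, LDC{0};
    float beta{0};

    bool operator==(const MiniBrgemmConfig& rhs) const {
        return M == rhs.M && N == rhs.N && K == rhs.K &&
               LDA == rhs.LDA && LDB == rhs.LDB && LDC == rhs.LDC && beta == rhs.beta;
    }
};

struct MiniBrgemmConfigHash {
    size_t operator()(const MiniBrgemmConfig& c) const {
        size_t seed = 0;
        auto combine = [&seed](size_t v) { seed ^= v + 0x9e3779b9 + (seed << 6) + (seed >> 2); };
        combine(std::hash<int64_t>{}(c.M));
        combine(std::hash<int64_t>{}(c.N));
        combine(std::hash<int64_t>{}(c.K));
        combine(std::hash<int64_t>{}(c.LDA));
        combine(std::hash<int64_t>{}(c.LDB));
        combine(std::hash<int64_t>{}(c.LDC));
        combine(std::hash<float>{}(c.beta));
        return seed;
    }
};

struct CompiledKernel {};  // placeholder for a JIT-compiled brgemm kernel

int main() {
    std::unordered_map<MiniBrgemmConfig, std::shared_ptr<CompiledKernel>, MiniBrgemmConfigHash> cache;
    MiniBrgemmConfig cfg{32, 64, 128, 128, 64, 64, 0.0f};
    if (cache.find(cfg) == cache.end())
        cache[cfg] = std::make_shared<CompiledKernel>();  // compile once per unique config
    return 0;
}
```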

src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp  (+4 -3)

@@ -44,10 +44,10 @@
 # include "emitters/tpp/x64/jit_eltwise_emitters.hpp"
 # include "emitters/tpp/x64/jit_equation_emitter.hpp"
 # include "emitters/tpp/x64/jit_scalar_emitter.hpp"
-# include "transformations/tpp/x64/op/brgemm.hpp"
+# include "transformations/tpp/common/op/brgemm.hpp"
+# include "transformations/tpp/common/op/modifiers.hpp"
 # include "transformations/tpp/x64/op/eltwise.hpp"
 # include "transformations/tpp/x64/op/equation.hpp"
-# include "transformations/tpp/x64/op/modifiers.hpp"
 # include "transformations/tpp/x64/op/reduce.hpp"
 # include "transformations/tpp/x64/op/scalar.hpp"
 // Note: for reference implementations
@@ -295,7 +295,8 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho
 #endif
 
 #ifdef SNIPPETS_LIBXSMM_TPP
-    jitters[intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BrgemmTppEmitter);
+    jitters[intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] =
+        CREATE_SNIPPETS_EMITTER(BrgemmTppEmitter, configurator->get_kernel_executor_table(), compiled_kernel_cache);
     jitters[intel_cpu::tpp::op::Add::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter);
     jitters[intel_cpu::tpp::op::Subtract::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter);
     jitters[intel_cpu::tpp::op::Multiply::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter);
