Commit 3a9464d

[Snippets] Support Brgemm with transposed_b via BrgemmCopyB (#24932)
### Details:
- *Support FP32/BF16/I8 matmuls with transpose_b=true via BrgemmCopyB*
- *BrgemmCopyB emitter: handle tail iteration by N before the main body*
- *Remove workaround on LDB and N dim rounding in brgemm emitters and related buffers*

### Tickets:
- *CVS-114487*

## TODO:
- [ ] BufferAllocation test for FP32 brgemm with repacking
- [ ] SetBrgemmCopyBBuffersShape tests
- [ ] MHA with transpose B for low precisions (FP32 already exists)
- [ ] FuseTransposeBrgemm tests
1 parent f48b30a commit 3a9464d
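
For context, a minimal sketch (not part of this commit) of the kind of graph the change targets: a MatMul whose second input is consumed in transposed form (`transpose_b = true`). The function name and shapes below are illustrative only.

```cpp
#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/parameter.hpp"

// Illustrative graph: MatMul with transpose_b = true. With this commit, such a
// pattern can be executed by Snippets through Brgemm, with BrgemmCopyB doing
// the repacking of input B instead of requiring an explicit Transpose.
std::shared_ptr<ov::Model> make_transposed_b_matmul() {
    auto a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 16, 128, 64});
    auto b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 16, 128, 64});
    auto matmul = std::make_shared<ov::op::v0::MatMul>(a, b, /*transpose_a=*/false, /*transpose_b=*/true);
    return std::make_shared<ov::Model>(ov::OutputVector{matmul}, ov::ParameterVector{a, b});
}
```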

File tree

62 files changed (+1161, -965 lines)


src/common/snippets/docs/mha_optimization_guide.md (+2 -2)

```diff
@@ -123,7 +123,7 @@ For enhancing the execution efficiency, blocking across the M, K, and N matmul d
 
 ### Blocking Parameters
 
-The heuristics for determining the optimal block sizes can be found in [SetBrgemmCPUBlockingParams](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp).
+The heuristics for determining the optimal block sizes can be found in [BrgemmCPUBlocking](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp).
 
 **Please note: Blocking by M dimension is shared between both Brgemms. Please see [SplitLoops](../include/snippets/lowered/pass/split_loops.hpp) lowered pass for the details.**
 
@@ -141,7 +141,7 @@ Based on previously discussed information, we provide the following recommendati
 In local experiments, some transformations might be worth to change:
    - Disable [ExtractUnsupportedTransposes](#extractunsupportedtransposes) transformation in order to benchmark Snippets Transpose implementation.
    - Adjust [SplitDimensionM](#splitdimensionm) heuristics in order to benchmark another splitting, or disable the pass at all.
-3. [Blocking parameters](#blocking-parameters): adjust blocking heuristics in `SetBrgemmCPUBlockingParams`.
+3. [Blocking parameters](#blocking-parameters): adjust blocking heuristics in `BrgemmCPUBlocking`.
    - Please note that there are 2 Matmul nodes inside a single MHA, and each Matmul can have his own optimal K, N blocking params.
      M block is better to keep the same since the corresponding blocking loop is shared between both Matmuls.
    - For the BF16/INT8 blocking loops, 2 options are possible: blocking can be done only for Brgemm node, or for BrgemmCopyB repacking too.
```
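
As an illustration of what these blocking parameters control (plain C++, not Snippets code), the sketch below shows the loop nest that M/N/K blocking imposes on a single matmul. The block sizes are placeholders; the real heuristics live in `BrgemmCPUBlocking`, and in Snippets the loops are expressed in the LinearIR rather than written by hand.

```cpp
#include <algorithm>
#include <cstddef>

// Illustrative only: the loop nest that M/N/K blocking imposes on one matmul
// C += A x B (C assumed zero-initialized). The innermost three loops stand for
// a single brgemm call on the [m0, n0, k0] tile; since the K loop accumulates
// into C, the first K iteration uses beta = 0 and the remaining ones beta = 1.
void blocked_matmul(const float* A, const float* B, float* C,
                    size_t M, size_t N, size_t K,
                    size_t m_blk, size_t n_blk, size_t k_blk) {
    for (size_t m0 = 0; m0 < M; m0 += m_blk)
        for (size_t n0 = 0; n0 < N; n0 += n_blk)
            for (size_t k0 = 0; k0 < K; k0 += k_blk)
                for (size_t m = m0; m < std::min(M, m0 + m_blk); ++m)
                    for (size_t n = n0; n < std::min(N, n0 + n_blk); ++n)
                        for (size_t k = k0; k < std::min(K, k0 + k_blk); ++k)
                            C[m * N + n] += A[m * K + k] * B[k * N + n];
}
```
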
New file (+107 lines): `BrgemmBlocking` base header

```cpp
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/itt.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/lowered/specific_loop_iter_handlers.hpp"
#include "snippets/lowered/pass/iter_handler.hpp"
#include "snippets/op/brgemm.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

/**
 * @interface BrgemmBlockingBase
 * @brief Base class for Brgemm blocking, which defines interface for blocking markup,
 *        and contains default implementation
 * @ingroup snippets
 */
class BrgemmBlockingBase {
public:
    static snippets::lowered::SpecificIterationHandlers get_default_blocking_loop_handlers(size_t work_amount, size_t block_size);

protected:
    /**
     * @interface get_blocking_params
     * @brief Computes optimal blocking params for current brgemm expression
     * @param brgemm_expr Brgemm expression
     * @return tuple in format (m_block, n_block, k_block)
     */
    virtual std::tuple<size_t, size_t, size_t> get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr);
    /**
     * @interface mark_blocking_loops
     * @brief Covers brgemm with blocking loops. Also should calculate optimal blocking parameters inside.
     * @param linear_ir LIR that contains brgemm
     * @param brgemm_it iterator on brgemm expression which should be covered with blocking loops
     */
    virtual bool mark_blocking_loops(snippets::lowered::LinearIR& linear_ir,
                                     const snippets::lowered::LinearIR::constExprIt& brgemm_it,
                                     size_t m_block,
                                     size_t n_block,
                                     size_t k_block);

    static bool blocking_loop_exists(const snippets::lowered::LoopManagerPtr& loop_manager,
                                     const ov::snippets::lowered::ExpressionPtr& brgemm_expr);

    static void mark_m_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
                                snippets::lowered::LinearIR::constExprIt loop_begin,
                                snippets::lowered::LinearIR::constExprIt loop_end,
                                const std::vector<snippets::lowered::LoopPort>& entries,
                                const std::vector<snippets::lowered::LoopPort>& exits,
                                size_t block_size_m);

    static void mark_n_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
                                snippets::lowered::LinearIR::constExprIt loop_begin,
                                snippets::lowered::LinearIR::constExprIt loop_end,
                                const std::vector<snippets::lowered::LoopPort>& entries,
                                const std::vector<snippets::lowered::LoopPort>& exits,
                                size_t block_size_n);

    static void mark_k_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
                                snippets::lowered::LinearIR::constExprIt loop_begin,
                                snippets::lowered::LinearIR::constExprIt loop_end,
                                const std::vector<snippets::lowered::LoopPort>& entries,
                                const std::vector<snippets::lowered::LoopPort>& exits,
                                size_t block_size_k);
};

/**
 * @interface BrgemmBlocking
 * @brief Base class for brgemm blocking passes
 * @ingroup snippets
 */
template <typename BRGEMM_TYPE,
          typename std::enable_if<std::is_base_of<ov::snippets::op::Brgemm, BRGEMM_TYPE>::value, bool>::type = true>
class BrgemmBlocking : public snippets::lowered::pass::RangedPass, public BrgemmBlockingBase {
public:
    OPENVINO_RTTI("BrgemmBlocking", "RangedPass")

    bool run(snippets::lowered::LinearIR& linear_ir,
             snippets::lowered::LinearIR::constExprIt begin,
             snippets::lowered::LinearIR::constExprIt end) override final {  // NOLINT
        OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::BrgemmBlocking")
        const auto& loop_manager = linear_ir.get_loop_manager();
        bool modified = false;
        for (auto expr_it = begin; expr_it != end; expr_it++) {
            const auto& brgemm_expr = *expr_it;
            const auto brgemm = ov::as_type_ptr<BRGEMM_TYPE>(brgemm_expr->get_node());
            if (!brgemm)
                continue;
            OPENVINO_ASSERT(!blocking_loop_exists(loop_manager, brgemm_expr),
                            "Brgemm mustn't be covered in loops before blocking pass");
            size_t m_block, n_block, k_block;
            std::tie(m_block, n_block, k_block) = get_blocking_params(brgemm_expr);
            modified = mark_blocking_loops(linear_ir, expr_it, m_block, n_block, k_block);
        }
        return modified;
    }
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
```
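
A hypothetical sketch (not part of the commit) of how a concrete pass could build on this interface: instantiate the template for a Brgemm type and override `get_blocking_params()` with custom heuristics. The include path, class name, and block sizes are assumptions.

```cpp
#include <tuple>

#include "snippets/lowered/pass/brgemm_blocking.hpp"  // assumed include path for the header above

// Hypothetical specialization: the default mark_blocking_loops() then wraps the
// matched expression into M/N/K blocking loops with the returned block sizes.
class ExampleBrgemmBlocking : public ov::snippets::lowered::pass::BrgemmBlocking<ov::snippets::op::Brgemm> {
protected:
    std::tuple<size_t, size_t, size_t> get_blocking_params(
            const ov::snippets::lowered::ExpressionPtr& brgemm_expr) override {
        // Placeholder (m_block, n_block, k_block); real heuristics (see
        // BrgemmCPUBlocking) take shapes, precision and cache sizes into account.
        return std::make_tuple(32, 64, 512);
    }
};
```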

src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp (+34)

```diff
@@ -64,6 +64,40 @@ class TransformInnerSplitLoop : public pass::RangedPass {
     size_t m_tail_size;
 };
 
+/**
+ * @interface SetEvaluateOnce
+ * @brief The pass set `evaluate once = true` only to ExpandedLoopInfo which is mapped on LoopEnd in the passed iterator `end`.
+ *        The pointer arithmetic should be updated in the separate optimization `OptimizeLoopSingleEvaluation`
+ * @ingroup snippets
+ */
+class SetEvaluateOnce : public snippets::lowered::pass::RangedPass {
+public:
+    SetEvaluateOnce() = default;
+    OPENVINO_RTTI("SetEvaluateOnce", "RangedPass")
+    bool run(snippets::lowered::LinearIR& linear_ir,
+             snippets::lowered::LinearIR::constExprIt begin,
+             snippets::lowered::LinearIR::constExprIt end) override;
+    std::shared_ptr<snippets::lowered::pass::PassBase> merge(const std::shared_ptr<snippets::lowered::pass::PassBase>& other) override;
+};
+
+/**
+ * @interface SetBrgemmBeta
+ * @brief The pass updates all CPUBrgemm nodes with a new beta value
+ * @param m_beta - beta which must be set
+ * @ingroup snippets
+ */
+class SetBrgemmBeta : public snippets::lowered::pass::RangedPass {
+public:
+    SetBrgemmBeta(float beta);
+    OPENVINO_RTTI("SetBrgemmBeta", "RangedPass")
+    bool run(snippets::lowered::LinearIR& linear_ir,
+             snippets::lowered::LinearIR::constExprIt begin,
+             snippets::lowered::LinearIR::constExprIt end) override;
+    std::shared_ptr<snippets::lowered::pass::PassBase> merge(const std::shared_ptr<snippets::lowered::pass::PassBase>& other) override;
+
+private:
+    float m_beta = 0;
+};
 } // namespace pass
 } // namespace lowered
 } // namespace snippets
```
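
A hedged usage sketch of the two new passes, based only on the declarations above: both are `RangedPass`es, so they can be run over a sub-range of the LinearIR (for example, the body of one blocking loop). The variables `linear_ir`, `body_begin` and `body_end` are assumed to be provided by the surrounding pipeline.

```cpp
// Assumed context: linear_ir is a snippets::lowered::LinearIR, and
// body_begin/body_end delimit one blocking-loop body inside it.
ov::snippets::lowered::pass::SetBrgemmBeta set_beta(0.f);  // first K iteration overwrites the accumulator
set_beta.run(linear_ir, body_begin, body_end);

ov::snippets::lowered::pass::SetEvaluateOnce evaluate_once;
evaluate_once.run(linear_ir, body_begin, body_end);
```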

src/common/snippets/include/snippets/lowered/port_descriptor.hpp (+6 -6)

```diff
@@ -82,14 +82,14 @@ class PortDescriptor {
 
 class PortDescriptorUtils {
 public:
-    static void set_port_descriptor_ptr(const ov::Input<ov::Node>& n, const PortDescriptorPtr& desc);
-    static void set_port_descriptor_ptr(const ov::Output<ov::Node>& n, const PortDescriptorPtr& desc);
-    static void set_port_descriptor(const ov::Input<ov::Node>& n, std::vector<size_t> subtensor, std::vector<size_t> layout = {});
-    static void set_port_descriptor(const ov::Output<ov::Node>& n, std::vector<size_t> subtensor, std::vector<size_t> layout = {});
+    static void set_port_descriptor_ptr(const ov::Input<ov::Node>& in, const PortDescriptorPtr& desc);
+    static void set_port_descriptor_ptr(const ov::Output<ov::Node>& out, const PortDescriptorPtr& desc);
+    static void set_port_descriptor(const ov::Input<ov::Node>& in, std::vector<size_t> subtensor, std::vector<size_t> layout = {});
+    static void set_port_descriptor(const ov::Output<ov::Node>& out, std::vector<size_t> subtensor, std::vector<size_t> layout = {});
 
     static PortDescriptorPtr get_port_descriptor_ptr(const ov::Input<ov::Node>& in);
-    static PortDescriptorPtr get_port_descriptor_ptr(const ov::Input<const ov::Node>& out);
-    static PortDescriptorPtr get_port_descriptor_ptr(const ov::Output<ov::Node>& in);
+    static PortDescriptorPtr get_port_descriptor_ptr(const ov::Input<const ov::Node>& in);
+    static PortDescriptorPtr get_port_descriptor_ptr(const ov::Output<ov::Node>& out);
     static PortDescriptorPtr get_port_descriptor_ptr(const ov::Output<const ov::Node>& out);
 
     static void clean(const std::shared_ptr<ov::Node>& node);
```
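
A hedged usage sketch of these utilities with the renamed parameters, assuming `PortDescriptorUtils` lives in `ov::snippets::lowered` and `node` is an existing `std::shared_ptr<ov::Node>`; the `{32, 64}` subtensor is a placeholder.

```cpp
using ov::snippets::lowered::PortDescriptorUtils;

PortDescriptorUtils::set_port_descriptor(node->input(0), {32, 64});           // default (planar) layout
PortDescriptorUtils::set_port_descriptor(node->output(0), {32, 64}, {0, 1});  // explicit layout
const auto desc = PortDescriptorUtils::get_port_descriptor_ptr(node->input(0));
```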

src/common/snippets/include/snippets/op/brgemm.hpp (+2 -15)

```diff
@@ -22,26 +22,17 @@ class Brgemm : virtual public modifier::MemoryAccess, public ov::op::Op {
     OPENVINO_OP("Brgemm", "SnippetsOpset");
     Brgemm(const Output<Node>& A, const Output<Node>& B,
            const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu,
-           std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
-           size_t blk_size_m = 0, size_t blk_size_k = 0, size_t blk_size_n = 0);
+           std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {});
     Brgemm(const Output<Node>& A, const Output<Node>& B,
            const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_c,
-           std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
-           size_t blk_size_m = 0, size_t blk_size_k = 0, size_t blk_size_n = 0);
+           std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {});
     Brgemm() = default;
 
     size_t get_offset_a() const { return get_input_offset(0); }
     size_t get_offset_b() const { return get_input_offset(1); }
     size_t get_offset_c() const { return get_output_offset(0); }
 
-    size_t get_m_block_size() const { return m_M_blk; }
-    size_t get_k_block_size() const { return m_K_blk; }
-    size_t get_n_block_size() const { return m_N_blk; }
     float get_beta() const { return m_beta; }
-
-    void set_m_block_size(size_t block_size) { m_M_blk = block_size; }
-    void set_k_block_size(size_t block_size) { m_K_blk = block_size; }
-    void set_n_block_size(size_t block_size) { m_N_blk = block_size; }
     void set_beta(float beta) { m_beta = beta; }
 
     static ov::element::Type get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1);
@@ -57,10 +48,6 @@ class Brgemm : virtual public modifier::MemoryAccess, public ov::op::Op {
     std::vector<ov::PartialShape> get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const;
     ov::PartialShape infer_output_partial_shape(const std::vector<ov::PartialShape>& input_shapes) const;
     ov::PartialShape get_planar_output_shape(const ov::PartialShape& output_shape) const;
-    void set_block_size_values(size_t blk_size_m, size_t blk_size_k, size_t blk_size_n);
-    size_t m_M_blk = 0;
-    size_t m_K_blk = 0;
-    size_t m_N_blk = 0;
     float m_beta = 0.f;
 
 private:
```
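
A hedged construction sketch reflecting the simplified interface: the op no longer carries block sizes, which are assigned later by the lowered blocking pass. `a` and `b` are assumed `ov::Output<ov::Node>` values produced earlier.

```cpp
// With this change the snippets Brgemm op is created without block-size arguments.
auto brgemm = std::make_shared<ov::snippets::op::Brgemm>(a, b);  // offsets and layouts keep their defaults
brgemm->set_beta(0.f);                                           // beta stays a node attribute
```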

src/common/snippets/include/snippets/pass/matmul_to_brgemm.hpp (-3)

```diff
@@ -22,9 +22,6 @@ class MatMulToBrgemm: public ov::pass::MatcherPass {
 public:
     OPENVINO_RTTI("MatMulToBrgemm", "0");
     MatMulToBrgemm();
-
-private:
-    void init_ports(const std::shared_ptr<op::Brgemm>& brgemm) const;
 };
 
 
```

src/common/snippets/include/snippets/utils/utils.hpp (+9)

```diff
@@ -127,6 +127,10 @@ bool broadcast_merge_dim(size_t& dst, const size_t& d1, const size_t& d2);
 VectorDims pshape_to_vdims(const PartialShape&);
 ov::PartialShape vdims_to_pshape(const VectorDims&);
 
+inline size_t dimension_to_size_t(const ov::Dimension& dim) {
+    return dim.is_dynamic() ? snippets::utils::get_dynamic_value<VectorDims::value_type>() : static_cast<size_t>(dim.get_length());
+}
+
 // dim_idx starts from the layout end: dim_idx = 0 -> last element in layout (layout.back())
 inline size_t get_input_dim_idx(const std::vector<size_t>& layout, size_t dim_idx) {
     OPENVINO_ASSERT(dim_idx < layout.size(), "Incorrect dim_idx");
@@ -214,6 +218,11 @@ VectorDims get_planar_vdims(const snippets::lowered::ExpressionPort& expr_port);
  * @return preordered shape: `shape[i]` = `planar_shape[order[i]]` where `shape` is shape before applying the order.
  */
 VectorDims get_preordered_vdims(const snippets::lowered::ExpressionPort& expr_port);
+/**
+ * @brief Returns subtensor projected on current shape: FULL_DIM subtensor values are replaced with actual shape value
+ * @param expr_port Port whose subtensor should be processed
+ */
+VectorDims get_projected_subtensor(const snippets::lowered::ExpressionPort& expr_port);
 /* --------------------------- */
 
 /**
```
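
A hedged usage sketch for the new `dimension_to_size_t` helper, assuming it is declared in `ov::snippets::utils` as the surrounding header suggests; the wrapper function name is hypothetical.

```cpp
#include <vector>

#include "openvino/core/partial_shape.hpp"
#include "snippets/utils/utils.hpp"  // declares dimension_to_size_t (path from the diff above)

// Convert each dimension of a PartialShape to size_t; dynamic dimensions map to
// the special "dynamic value" marker used throughout Snippets.
std::vector<size_t> to_size_t_dims(const ov::PartialShape& pshape) {
    std::vector<size_t> dims;
    for (const auto& dim : pshape)
        dims.push_back(ov::snippets::utils::dimension_to_size_t(dim));
    return dims;
}
```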
