Commit ce32bfe

[Snippets][Docs] MHA Optimization Guide (#25493)
Co-authored-by: Ivan Novoselov <ivan.novoselov@intel.com>
1 parent 2bcb1e3 commit ce32bfe

1 file changed: +150 -0 lines changed

# MHA Optimization Guide

## Introduction

This guide explores the mechanism of Multi Head Attention (MHA) pattern tokenization and several methods used for MHA performance optimization.
It also provides several recommendations on how to fine-tune the performance of a specific MHA pattern.

## MHA Tokenization

This structure represents the basic MHA pattern that can be tokenized by Snippets:

```mermaid
graph TB
MM0A[Transpose] --> MatMul0
MM0B[Transpose/Eltwise/FakeQuantize] --> MatMul0
MatMul0 --> IntermediateBeforeSM[Transpose/Eltwise/Select/Reshape/FakeQuantize]
IntermediateBeforeSM --> Softmax
Softmax --> IntermediateAfterSM[Transpose/Eltwise/Select/Reshape/FakeQuantize]
IntermediateAfterSM --> MatMul1
MM1B[Transpose] --> MatMul1
MatMul1 --> OpAfterMM2[Transpose/Eltwise/FakeQuantize]
```

The main layers in the MHA pattern are `MatMul0`, `Softmax` and `MatMul1`. Other layers are optional.
Please note that the layers denoted with `/` can represent both single nodes and sequences of nodes.
The code which performs the tokenization is placed in the [TokenizeMHASnippets](../src/pass/mha_tokenization.cpp) transformation.

### CPU Plugin Callback for MHA Tokenization

The tokenization pass can be adjusted via a callback.
In the CPU plugin, the callback disables tokenization in 3 types of cases:

1. Operations that are not supported by the Snippets CPU backend, for example because their fusing is not expected to bring sufficient optimization opportunities.
2. Operations skipped deliberately to allow for plugin-specific fusings.
   For example, elementwise operations that follow Convolution nodes are skipped because the eltwises will be fused into Convolutions by the CPU plugin.
3. Operations that are not tokenized for performance reasons: executing MHA operations one-by-one may be faster in some cases.

The CPU plugin callback for TokenizeMHASnippets is placed in the [transformation_pipeline.cpp](../../../plugins/intel_cpu/src/transformations/transformation_pipeline.cpp) file (please see the code in the `MainSnippets` method).

**Please note that the CPU callback is usually ignored in CPU functional tests: `SnippetsMode::IgnoreCallback` is used for that.**
Currently, `SnippetsMode` has 3 states: `Enable`, `IgnoreCallback` and `Disable`.
For the details, please refer to [ov::intel_cpu::Config](../../../plugins/intel_cpu/src/config.h).
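
For illustration only, the sketch below shows the general mechanism of such a callback: it is registered on the pass config of an `ov::pass::Manager`, and, by the usual OpenVINO convention for this pass, returning `true` from the callback rejects tokenization of the matched pattern. The helper `should_skip_mha_tokenization` and the exact include path are assumptions standing in for the real plugin-specific checks in `MainSnippets`.

```cpp
#include <memory>

#include "openvino/pass/manager.hpp"
#include "snippets/pass/mha_tokenization.hpp"

// Hypothetical stand-in for the plugin-specific checks described above
// (unsupported operations, deliberately skipped fusings, performance heuristics).
bool should_skip_mha_tokenization(const std::shared_ptr<const ov::Node>& node);

void configure_mha_tokenization_callback(ov::pass::Manager& manager) {
    // The callback is queried for a node of the matched MHA pattern;
    // returning "true" disables tokenization of that pattern.
    manager.get_pass_config()->set_callback<ov::snippets::pass::TokenizeMHASnippets>(
        [](const std::shared_ptr<const ov::Node>& node) {
            return should_skip_mha_tokenization(node);
        });
}
```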

## Snippets Common Optimizations

After tokenization, Snippets [common optimizations](../src/pass/common_optimizations.cpp) are applied to the tokenized Subgraphs.
These transformations can modify both the Subgraph's body and its surroundings (e.g. extract constant nodes outside the Subgraph).
Let's explore several transformations that can impact MHA performance.

### ExtractUnsupportedTransposes

[ExtractUnsupportedTransposes](../src/pass/extract_unsupported_transposes.cpp) moves unsupported Transposes out of the Subgraph.

Snippets supports 2 types of Transposes:

1. Transposes which are fused into the Brgemm node (which supports strided read/write) by the [FuseTransposeBrgemm](../src/pass/fuse_transpose_brgemm.cpp) data flow transformation.
   The Transpose orders supported for Brgemm fusion are defined by `TokenizeMHASnippets::get_fusion_transpose_order` in [mha_tokenization.cpp](../src/pass/mha_tokenization.cpp).
2. The rest of the Transposes are decomposed by the [TransposeDecomposition](../src/pass/transpose_decomposition.cpp) data flow transformation.
   The Transpose orders supported by the decomposition are defined by `TokenizeMHASnippets::get_decomposed_transpose_order` in [mha_tokenization.cpp](../src/pass/mha_tokenization.cpp).

**Please note: an "unsupported" Transpose can actually be executed via the Snippets decomposition; however, the CPU plugin implementation is expected to work faster in this particular case.**
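
For intuition on why a Transpose can be fused into Brgemm at all, the sketch below shows how permuting the strides of a dense tensor (with an order such as `{0, 2, 1, 3}`, chosen here purely for illustration) lets a kernel that supports strided reads consume the data as if it were transposed, without materializing the transposed tensor. This is an illustration of strided access, not the actual Brgemm implementation.

```cpp
#include <array>
#include <cstddef>

// Dense row-major strides of a 4D shape.
std::array<size_t, 4> dense_strides(const std::array<size_t, 4>& shape) {
    std::array<size_t, 4> strides{};
    size_t stride = 1;
    for (int i = 3; i >= 0; --i) {
        strides[static_cast<size_t>(i)] = stride;
        stride *= shape[static_cast<size_t>(i)];
    }
    return strides;
}

int main() {
    const std::array<size_t, 4> shape = {1, 12, 128, 64};  // e.g. [batch, heads, seq_len, head_size]
    const std::array<size_t, 4> order = {0, 2, 1, 3};      // illustrative Transpose order

    const auto strides = dense_strides(shape);
    std::array<size_t, 4> transposed_shape{}, transposed_strides{};
    for (size_t i = 0; i < 4; ++i) {
        transposed_shape[i] = shape[order[i]];
        transposed_strides[i] = strides[order[i]];
    }
    // A Brgemm-like kernel with strided read support can iterate over transposed_shape
    // using transposed_strides instead of executing a standalone Transpose first.
    return 0;
}
```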

### SplitDimensionM

[SplitDimensionM](../src/pass/split_dimension_m.cpp) splits the M dimension of the MHA into 2 parts (`batch_m` and `new_m`) by inserting a Reshape on the A input of the first Matmul and on the output of the second Matmul (the rest of the Subgraph's inputs are reshaped by Unsqueeze-like Reshapes in order not to break the subgraph semantics).
This optimization increases the parallel work amount by `batch_m` times, thus enabling more efficient parallel execution in some cases.
The splitting is performed based on a heuristic algorithm which can be found in the `SplitDimensionM::get_splited_dimensions` method.
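
As a rough illustration of the idea only (the real heuristic in `SplitDimensionM::get_splited_dimensions` is more elaborate), the sketch below chooses `batch_m` as the largest divisor of M such that the resulting outer work amount does not exceed the available concurrency:

```cpp
#include <cstddef>
#include <utility>

// Illustrative only: split M into batch_m * new_m so that batch * batch_m
// approaches the target concurrency without exceeding it.
std::pair<size_t, size_t> split_m(size_t m, size_t batch, size_t concurrency) {
    size_t best_batch_m = 1;
    for (size_t batch_m = 1; batch_m <= m; ++batch_m) {
        if (m % batch_m != 0)
            continue;  // keep new_m integral: M == batch_m * new_m
        if (batch * batch_m > concurrency)
            break;     // enough parallel work already; larger splits only shrink new_m
        best_batch_m = batch_m;
    }
    return {best_batch_m, m / best_batch_m};  // {batch_m, new_m}
}

// Example: M = 1024, batch = 1, concurrency = 16  ->  batch_m = 16, new_m = 64.
```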

Let's consider an example of the transformation:

```mermaid
graph LR
subgraph left[" "]
direction TB
MM0A_1[Matmul0 Input A]
MM0B_1[Matmul0 Input B]
MM1B_1[Matmul1 Input B]
S_1["Subgraph"]
MM1C_1[Matmul1 Output]

MM0A_1-->|"[1, M, K1]"|S_1
MM0B_1-->|"[1, K1, N1]"|S_1
MM1B_1-->|"[1, N1, N2]"|S_1
S_1-->|"[1, M, N2]"|MM1C_1
end
subgraph middle[" "]
direction TB
MM0A_2[Matmul0 Input A]
Reshape1["Input SplitM Reshape"]
MM0B_2[Matmul0 Input B]
Reshape2["Unsqueeze-like Reshape 1"]
MM1B_2[Matmul1 Input B]
Reshape3["Unsqueeze-like Reshape 2"]
S_2["Subgraph"]
Reshape4["Output SplitM Reshape"]
MM1C_2[Matmul1 Output]

MM0A_2-->|"[1, M, K1]"|Reshape1
Reshape1-->|"[1, batch_M, new_M, K1]"|S_2
MM0B_2-->|"[1, K1, N1]"|Reshape2
Reshape2-->|"[1, 1, K1, N1]"|S_2
MM1B_2-->|"[1, N1, N2]"|Reshape3
Reshape3-->|"[1, 1, N1, N2]"|S_2
S_2-->|"[1, batch_M, new_M, N2]"|Reshape4
Reshape4-->|"[1, M, N2]"|MM1C_2
end
left-->|SplitDimensionM|middle
%% middle-->|<font size=+1>Attach Add\n to Subgraph</font>|right
classDef no-bg-color fill:none,stroke-width:0px
class left,middle,right no-bg-color
```

**Important notes:**
- Since `SplitDimensionM` depends on the parallel concurrency, the transformation result depends not only on the HW platform, but also on the number of streams used during model inference.
  For instance, this might lead to different results in the throughput and latency hint modes.
- `SplitDimensionM::can_be_optimized` is used in the CPU plugin callback: if this method reports that an appropriate parallel work amount cannot be set for the MHA, the tokenization doesn't happen.

## Brgemm Blocking

Within the Snippets CPU backend, MatMul is executed using the Brgemm primitive.
To enhance the execution efficiency, blocking across the M, K, and N matmul dimensions is used.

### Blocking Parameters

The heuristics for determining the optimal block sizes can be found in [SetBrgemmCPUBlockingParams](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp).

**Please note: blocking by the M dimension is shared between both Brgemms. Please see the [SplitLoops](../include/snippets/lowered/pass/split_loops.hpp) lowered pass for the details.**

### Blocking Order

The lowered pass [BrgemmBlocking](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp) creates the blocking loops on the LinearIR.
Currently, the order of the blocking loops is the following (from the outermost to the innermost): `M->N->K`.
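
As a schematic illustration of this loop order only (the actual pass builds the loops on the LinearIR and also handles tail blocks; `brgemm_block` below is a hypothetical kernel call), the blocking can be pictured as the following nest:

```cpp
#include <cstddef>

// Hypothetical kernel call processing one set of blocks.
void brgemm_block(size_t m, size_t n, size_t k, size_t m_block, size_t n_block, size_t k_block);

// Schematic blocking loop nest, from the outermost to the innermost loop: M -> N -> K.
void blocked_brgemm(size_t M, size_t N, size_t K,
                    size_t m_block, size_t n_block, size_t k_block) {
    for (size_t m = 0; m < M; m += m_block) {          // M blocking loop (shared between both Brgemms)
        for (size_t n = 0; n < N; n += n_block) {      // N blocking loop
            for (size_t k = 0; k < K; k += k_block) {  // K blocking loop (accumulation)
                brgemm_block(m, n, k, m_block, n_block, k_block);
            }
        }
    }
}
```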

## MHA Performance Tuning Recommendations

Based on the information discussed above, we provide the following recommendations for MHA performance fine-tuning:

1. Check if there are MHAs which were not tokenized because of the [CPU plugin callback](#cpu-plugin-callback-for-mha-tokenization) (one possible way to check this is sketched after this list).
2. Check how the graph was changed by [CommonOptimizations](#snippets-common-optimizations).
   In local experiments, it might be worth changing some transformations:
   - Disable the [ExtractUnsupportedTransposes](#extractunsupportedtransposes) transformation in order to benchmark the Snippets Transpose implementation.
   - Adjust the [SplitDimensionM](#splitdimensionm) heuristics in order to benchmark another splitting, or disable the pass altogether.
3. [Blocking parameters](#blocking-parameters): adjust the blocking heuristics in `SetBrgemmCPUBlockingParams`.
   - Please note that there are 2 Matmul nodes inside a single MHA, and each Matmul can have its own optimal K and N blocking parameters.
     It is better to keep the M block the same, since the corresponding blocking loop is shared between both Matmuls.
   - For the BF16/INT8 blocking loops, 2 options are possible: blocking can be done only for the Brgemm node, or for the BrgemmCopyB repacking as well.
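
Regarding recommendation 1, one possible way to check the tokenization result is to inspect the runtime (execution) graph of the compiled model and count the Snippets kernels. The sketch below assumes that tokenized kernels are reported with the `Subgraph` execution-graph layer type and uses a placeholder model path; please verify these assumptions for your OpenVINO version.

```cpp
#include <iostream>
#include <string>

#include "openvino/runtime/core.hpp"
#include "openvino/runtime/exec_model_info.hpp"

int main() {
    ov::Core core;
    // "model.xml" is a placeholder path to the model under investigation.
    auto compiled_model = core.compile_model("model.xml", "CPU");

    // Count execution-graph nodes reported as "Subgraph" (assumed to be Snippets kernels).
    size_t subgraph_count = 0;
    for (const auto& op : compiled_model.get_runtime_model()->get_ordered_ops()) {
        const auto& rt_info = op->get_rt_info();
        const auto it = rt_info.find(ov::exec_model_info::LAYER_TYPE);
        if (it != rt_info.end() && it->second.as<std::string>() == "Subgraph")
            ++subgraph_count;
    }
    std::cout << "Tokenized Subgraphs: " << subgraph_count << std::endl;
    return 0;
}
```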

Following these recommendations, the performance of specific MHA patterns can be fine-tuned.
Additionally, the results of these experiments can be used as a solid foundation for subsequent heuristics adjustments.
