# MHA Optimization Guide

## Introduction

This guide explores the Multi-Head Attention (MHA) pattern tokenization mechanism and several methods used for MHA performance optimization.
It also provides recommendations on how to fine-tune the performance of a specific MHA pattern.

## MHA Tokenization

This structure represents the basic MHA pattern that can be tokenized by Snippets:

```mermaid
graph TB
    MM0A[Transpose] --> MatMul0
    MM0B[Transpose/Eltwise/FakeQuantize] --> MatMul0
    MatMul0 --> IntermediateBeforeSM[Transpose/Eltwise/Select/Reshape/FakeQuantize]
    IntermediateBeforeSM --> Softmax
    Softmax --> IntermediateAfterSM[Transpose/Eltwise/Select/Reshape/FakeQuantize]
    IntermediateAfterSM --> MatMul1
    MM1B[Transpose] --> MatMul1
    MatMul1 --> OpAfterMM2[Transpose/Eltwise/FakeQuantize]
```

The main layers in the MHA pattern are `MatMul0`, `Softmax` and `MatMul1`; the other layers are optional.
Please note that the layers denoted by `/` can represent both single nodes and sequences of nodes.
The code that performs the tokenization is located in the [TokenizeMHASnippets](../src/pass/mha_tokenization.cpp) transformation.
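
For illustration, the following minimal sketch (assuming the OpenVINO C++ API; the shapes, element type and function name are arbitrary illustration values) builds a graph that matches the core of the pattern above, `MatMul0 -> Softmax -> MatMul1`, and is therefore a tokenization candidate:

```cpp
// Minimal MHA-like graph: MatMul0 -> Softmax -> MatMul1.
#include <openvino/openvino.hpp>
#include <openvino/opsets/opset8.hpp>

std::shared_ptr<ov::Model> make_minimal_mha() {
    using namespace ov::opset8;
    const auto shape = ov::Shape{1, 16, 64, 64};  // arbitrary [batch, heads, seq_len, head_size]
    auto q = std::make_shared<Parameter>(ov::element::f32, shape);
    auto k = std::make_shared<Parameter>(ov::element::f32, shape);
    auto v = std::make_shared<Parameter>(ov::element::f32, shape);
    // MatMul0: Q x K^T produces the attention scores
    auto mm0 = std::make_shared<MatMul>(q, k, false, true);
    // Softmax over the last axis normalizes the scores
    auto softmax = std::make_shared<Softmax>(mm0, -1);
    // MatMul1: attention weights x V
    auto mm1 = std::make_shared<MatMul>(softmax, v, false, false);
    return std::make_shared<ov::Model>(ov::OutputVector{mm1}, ov::ParameterVector{q, k, v});
}
```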

### CPU Plugin Callback for MHA Tokenization

The tokenization pass can be adjusted via a callback.
In the CPU plugin, the callback disables tokenization in 3 types of cases:

1. Operations that are not supported by the Snippets CPU backend, for example because fusing them is not expected to bring sufficient optimization opportunities.
2. Operations skipped deliberately to allow for plugin-specific fusings.
For example, elementwise operations that follow Convolution nodes are skipped because these eltwises will be fused into the Convolutions by the CPU plugin.
3. Operations that are not tokenized for performance reasons: in some cases, executing MHA operations one-by-one may be faster.

The CPU plugin callback for TokenizeMHASnippets is located in the [transformation_pipeline.cpp](../../../plugins/intel_cpu/src/transformations/transformation_pipeline.cpp) file (please see the code in the `MainSnippets` method).

**Please note that the CPU callback is usually ignored in CPU functional tests: `SnippetsMode::IgnoreCallback` is used for that.**
Currently, `SnippetsMode` has 3 states: `Enable`, `IgnoreCallback` and `Disable`.
For the details, please refer to [ov::intel_cpu::Config](../../../plugins/intel_cpu/src/config.h).
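
The three states can be pictured as follows (an illustrative sketch only; see the `Config` header above for the actual definition):

```cpp
// Illustrative sketch of the three SnippetsMode states; refer to
// plugins/intel_cpu/src/config.h for the actual definition.
enum class SnippetsMode {
    Enable,          // tokenize, respecting the CPU plugin callback
    IgnoreCallback,  // tokenize, bypassing the callback (used in functional tests)
    Disable,         // no Snippets tokenization at all
};
```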

## Snippets Common Optimizations

After tokenization, Snippets [common optimizations](../src/pass/common_optimizations.cpp) are applied to the tokenized Subgraphs.
These transformations can modify both the Subgraph's body and its surroundings (e.g. extract constant nodes outside the Subgraph).
Let's explore several transformations that can impact MHA performance.

### ExtractUnsupportedTransposes

[ExtractUnsupportedTransposes](../src/pass/extract_unsupported_transposes.cpp) moves unsupported Transposes out of the Subgraph.

Snippets supports 2 types of Transposes:

1. Transposes that are fused into a Brgemm node (which supports strided read/write) by the [FuseTransposeBrgemm](../src/pass/fuse_transpose_brgemm.cpp) data flow transformation.
The Transpose orders supported by Brgemm fusion are defined by `TokenizeMHASnippets::get_fusion_transpose_order` in [mha_tokenization.cpp](../src/pass/mha_tokenization.cpp).
2. The remaining Transposes are decomposed by the [TransposeDecomposition](../src/pass/transpose_decomposition.cpp) data flow transformation.
The Transpose orders supported by decomposition are defined by `TokenizeMHASnippets::get_decomposed_transpose_order` in [mha_tokenization.cpp](../src/pass/mha_tokenization.cpp).

**Please note: an "unsupported" Transpose can actually be executed via the Snippets decomposition; however, the CPU plugin implementation is expected to work faster in this particular case.**

### SplitDimensionM

[SplitDimensionM](../src/pass/split_dimension_m.cpp) splits the M dimension of MHA into 2 parts (`batch_m` and `new_m`) by inserting a Reshape on the A input of the first MatMul and on the output of the second MatMul (the remaining Subgraph inputs are reshaped by Unsqueeze-like Reshapes in order not to break the Subgraph semantics).
This optimization increases the parallel work amount by a factor of `batch_m`, thus enabling more efficient parallel execution in some cases.
The splitting is performed based on a heuristic algorithm, which can be found in the `SplitDimensionM::get_splited_dimensions` method.

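The general idea can be sketched as follows (a simplified model with hypothetical names, not the actual `SplitDimensionM::get_splited_dimensions` logic): search for a divisor of `M` such that the new batch work amount covers the available concurrency.

```cpp
// Simplified sketch of an M-splitting heuristic: find batch_m such that
// batch_m * new_m == m and the resulting parallel work amount
// (current_parallel_work * batch_m) covers the available concurrency.
// NOT the actual SplitDimensionM::get_splited_dimensions implementation.
#include <cstddef>
#include <optional>
#include <utility>

std::optional<std::pair<size_t, size_t>> split_m(size_t m,
                                                 size_t current_parallel_work,
                                                 size_t concurrency) {
    for (size_t batch_m = 2; batch_m <= m / 2; ++batch_m) {
        if (m % batch_m != 0)
            continue;  // only exact divisors keep the shapes consistent
        if (current_parallel_work * batch_m >= concurrency)
            return std::make_pair(batch_m, m / batch_m);  // {batch_m, new_m}
    }
    return std::nullopt;  // no suitable split found
}
```
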
Let's consider an example of the transformation:

```mermaid
graph LR
    subgraph left[" "]
    direction TB
    MM0A_1[Matmul0 Input A]
    MM0B_1[Matmul0 Input B]
    MM1B_1[Matmul1 Input B]
    S_1["Subgraph"]
    MM1C_1[Matmul1 Output]

    MM0A_1-->|"[1, M, K1]"|S_1
    MM0B_1-->|"[1, K1, N1]"|S_1
    MM1B_1-->|"[1, N1, N2]"|S_1
    S_1-->|"[1, M, N2]"|MM1C_1
    end
    subgraph middle[" "]
    direction TB
    MM0A_2[Matmul0 Input A]
    Reshape1["Input SplitM Reshape"]
    MM0B_2[Matmul0 Input B]
    Reshape2["Unsqueeze-like Reshape 1"]
    MM1B_2[Matmul1 Input B]
    Reshape3["Unsqueeze-like Reshape 2"]
    S_2["Subgraph"]
    Reshape4["Output SplitM Reshape"]
    MM1C_2[Matmul1 Output]

    MM0A_2-->|"[1, M, K1]"|Reshape1
    Reshape1-->|"[1, batch_M, new_M, K1]"|S_2
    MM0B_2-->|"[1, K1, N1]"|Reshape2
    Reshape2-->|"[1, 1, K1, N1]"|S_2
    MM1B_2-->|"[1, N1, N2]"|Reshape3
    Reshape3-->|"[1, 1, N1, N2]"|S_2
    S_2-->|"[1, batch_M, new_M, N2]"|Reshape4
    Reshape4-->|"[1, M, N2]"|MM1C_2
    end
    left-->|SplitDimensionM|middle
    classDef no-bg-color fill:none,stroke-width:0px
    class left,middle no-bg-color
```

**Important notes:**
- Since `SplitDimensionM` depends on the parallel concurrency, the transformation result depends not only on the HW platform, but also on the number of streams used during model inference.
For instance, this might lead to different results in the throughput and latency hint modes.
- `SplitDimensionM::can_be_optimized` is used in the CPU plugin callback: if this method reports that an appropriate parallel work amount cannot be set for the MHA, the tokenization doesn't happen.

## Brgemm Blocking

Within the Snippets CPU backend, MatMul is executed using the Brgemm primitive.
To enhance the execution efficiency, blocking across the M, K, and N matmul dimensions is used.
| 123 | + |
| 124 | +### Blocking Parameters |
| 125 | + |
| 126 | +The heuristics for determining the optimal block sizes can be found in [SetBrgemmCPUBlockingParams](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp). |
| 127 | + |
| 128 | +**Please note: Blocking by M dimension is shared between both Brgemms. Please see [SplitLoops](../include/snippets/lowered/pass/split_loops.hpp) lowered pass for the details.** |
| 129 | + |
| 130 | +### Blocking Order |
| 131 | + |
| 132 | +The lowered pass [BrgemmBlocking](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp) performs blocking loops creation on LinearIR. |
| 133 | +Currently, the order of blocking loops is following (from outer to inner): `M->N->K`. |
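
The following plain-C++ sketch is a conceptual model of this loop order only; the actual implementation builds these loops on the LinearIR and invokes a Brgemm kernel for each tile:

```cpp
// Conceptual M->N->K blocked matmul: C[M x N] = A[M x K] * B[K x N].
// The innermost scalar loops stand in for a single Brgemm kernel call
// that computes an m_blk x n_blk tile over a k_blk slice of K.
#include <algorithm>
#include <cstddef>

void blocked_matmul(const float* A, const float* B, float* C,
                    size_t M, size_t N, size_t K,
                    size_t m_blk, size_t n_blk, size_t k_blk) {
    for (size_t m0 = 0; m0 < M; m0 += m_blk) {          // M blocking loop (outermost)
        for (size_t n0 = 0; n0 < N; n0 += n_blk) {      // N blocking loop
            for (size_t k0 = 0; k0 < K; k0 += k_blk) {  // K blocking loop (innermost)
                const size_t m_end = std::min(m0 + m_blk, M);
                const size_t n_end = std::min(n0 + n_blk, N);
                const size_t k_end = std::min(k0 + k_blk, K);
                for (size_t m = m0; m < m_end; ++m) {
                    for (size_t n = n0; n < n_end; ++n) {
                        // Initialize the accumulator on the first K block,
                        // then accumulate partial sums across K blocks.
                        float acc = (k0 == 0) ? 0.0f : C[m * N + n];
                        for (size_t k = k0; k < k_end; ++k)
                            acc += A[m * K + k] * B[k * N + n];
                        C[m * N + n] = acc;
                    }
                }
            }
        }
    }
}
```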

## MHA Performance Tuning Recommendations

Based on the information discussed above, we provide the following recommendations for MHA performance fine-tuning:

1. Check if there are MHAs that were not tokenized because of the [CPU plugin callback](#cpu-plugin-callback-for-mha-tokenization).
2. Check how the graph was changed by [CommonOptimizations](#snippets-common-optimizations).
In local experiments, some transformations might be worth changing:
    - Disable the [ExtractUnsupportedTransposes](#extractunsupportedtransposes) transformation in order to benchmark the Snippets Transpose implementation.
    - Adjust the [SplitDimensionM](#splitdimensionm) heuristics in order to benchmark another splitting, or disable the pass altogether.
3. [Blocking parameters](#blocking-parameters): adjust the blocking heuristics in `SetBrgemmCPUBlockingParams`.
    - Please note that there are 2 MatMul nodes inside a single MHA, and each MatMul can have its own optimal K and N blocking parameters.
    The M block is better kept the same, since the corresponding blocking loop is shared between both MatMuls.
    - For the BF16/INT8 blocking loops, 2 options are possible: blocking can be done only for the Brgemm node, or for the BrgemmCopyB repacking as well.

Following these recommendations, the performance of some specific MHA patterns can be fine-tuned.
Additionally, the results of these experiments can be used as a solid foundation for subsequent heuristics adjustments.