Gated Multi-Layer Perceptron (Gated-MLP) {#dev_guide_graph_gated_mlp}
=====================================================================

## Overview

Gated Multi-Layer Perceptron (Gated-MLP) is a variant of MLP that is widely
used as the Feed-Forward Network (FFN) in many Transformer-based Large Language
Models (LLMs).

Typically, the FFN in the Transformer architecture [1] is defined as a two-layer
MLP with a ReLU activation in between, which can be replaced with other
activations.

\f[

   FFN(src,W,V) = ReLU(src \cdot W) \cdot V

\f]

Gated Linear Unit (GLU) is adopted to replace the first linear layer to
improve the quality of Transformer-based models [2]:

\f[

   GLU(src,W_1,W_2) = (src \cdot W_1) \otimes Sigmoid(src \cdot W_2) \\

   FFN(src,W_1,W_2,V) = GLU(src,W_1,W_2) \cdot V

\f]

where \f$ src \cdot W_1 \f$ is usually called "FC (fully-connected) up",
\f$ src \cdot W_2 \f$ is called "FC gate", and the last linear layer is called
"FC down".

The Swish activation is further adopted to replace Sigmoid in the GLU, which
forms swiGLU:

\f[

   Swish(x) = x \otimes Sigmoid(x) \\

   swiGLU(src,W_1,W_2) = (src \cdot W_1) \otimes Swish(src \cdot W_2) \\

   FFN(src,W_1,W_2,V) = swiGLU(src,W_1,W_2) \cdot V

\f]

The Gated-MLP based on swiGLU is adopted in LLMs such as LLaMA [3] and
Qwen [4].

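To make the formulas above concrete, the following is a minimal scalar
reference sketch in plain C++ (it does not use oneDNN). The function name
`gated_mlp_ref`, the row-major layout, and the dimension names `M`, `K`, `N`,
and `O` are illustrative assumptions.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference swiGLU-based FFN: dst = ((src*W1) (x) Swish(src*W2)) * V.
// src: M x K, W1/W2: K x N, V: N x O, all row-major.
std::vector<float> gated_mlp_ref(const std::vector<float> &src,
        const std::vector<float> &W1, const std::vector<float> &W2,
        const std::vector<float> &V, size_t M, size_t K, size_t N, size_t O) {
    auto matmul = [](const std::vector<float> &a, const std::vector<float> &b,
                          size_t m, size_t k, size_t n) {
        std::vector<float> c(m * n, 0.f);
        for (size_t i = 0; i < m; ++i)
            for (size_t p = 0; p < k; ++p)
                for (size_t j = 0; j < n; ++j)
                    c[i * n + j] += a[i * k + p] * b[p * n + j];
        return c;
    };
    std::vector<float> up = matmul(src, W1, M, K, N); // "FC up": src . W1
    std::vector<float> gate = matmul(src, W2, M, K, N); // "FC gate": src . W2
    std::vector<float> glu(M * N);
    for (size_t i = 0; i < M * N; ++i) {
        const float sig = 1.f / (1.f + std::exp(-gate[i])); // Sigmoid(src . W2)
        glu[i] = up[i] * (gate[i] * sig); // (src . W1) (x) Swish(src . W2)
    }
    return matmul(glu, V, M, N, O); // "FC down"
}
```
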
## Gated-MLP patterns

oneDNN supports Gated-MLP and its optimization through Graph API [5] by defining
the graph, getting partitions from the graph, and optimizing the kernels
underneath. In general, a Gated-MLP pattern is defined as a directed acyclic
graph (DAG) using oneDNN Graph API.

### Floating-point Gated-MLP

oneDNN defines floating-point (f32, bf16, and f16) Gated-MLP as follows. The blue
nodes are required when defining a Gated-MLP pattern while the brown nodes are
optional.

1. The first MatMul on the top left calculates "FC up": \f$ src \cdot W_1 \f$.
   See the [MatMul](@ref dev_guide_op_matmul) operation in Graph API.
2. The second MatMul on the top right calculates "FC gate": \f$ src \cdot W_2 \f$.
3. The Activation node is optional. If required, it can be constructed with the
   activation operations in Graph API, for example, [ReLU](@ref dev_guide_op_relu),
   [GELU](@ref dev_guide_op_gelu), [Sigmoid](@ref dev_guide_op_sigmoid), and so on.
   For Swish activation, the node can be constructed with [Sigmoid](@ref dev_guide_op_sigmoid)
   and [Multiply](@ref dev_guide_op_multiply), as shown in the sketch after this
   list. You can also refer to the
   [Gated-MLP example](https://github.com/oneapi-src/oneDNN/tree/main/examples/graph/gated_mlp.cpp)
   for the Swish definition.

4. The last MatMul on the bottom performs the "FC down" operation between the
   GLU output and \f$V\f$.

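Putting the steps above together, here is a minimal sketch of how such a
floating-point Gated-MLP graph could be defined with the oneDNN Graph API,
including the Swish activation decomposed into Sigmoid and Multiply. The tensor
IDs, op names, shapes, and the bf16 data type are illustrative assumptions;
refer to the
[Gated-MLP example](https://github.com/oneapi-src/oneDNN/tree/main/examples/graph/gated_mlp.cpp)
for a complete, runnable program.

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

void build_gated_mlp_graph() {
    // Assumed sizes: M = 32 tokens, K = 4096 hidden, N = 11008 intermediate.
    const logical_tensor::dims src_dims = {32, 4096}; // M x K
    const logical_tensor::dims wei_dims = {4096, 11008}; // K x N
    const logical_tensor::dims hid_dims = {32, 11008}; // M x N
    const logical_tensor::dims wei_down_dims = {11008, 4096}; // N x K
    const logical_tensor::dims dst_dims = {32, 4096}; // M x K

    const auto dt = logical_tensor::data_type::bf16;
    const auto lt = logical_tensor::layout_type::strided;

    // Logical tensors connecting the ops; IDs must be unique within the graph.
    logical_tensor src {0, dt, src_dims, lt};
    logical_tensor w_up {1, dt, wei_dims, lt};
    logical_tensor w_gate {2, dt, wei_dims, lt};
    logical_tensor w_down {3, dt, wei_down_dims, lt}; // corresponds to V above
    logical_tensor fc_up {4, dt, hid_dims, lt};
    logical_tensor fc_gate {5, dt, hid_dims, lt};
    logical_tensor sig_out {6, dt, hid_dims, lt};
    logical_tensor swish_out {7, dt, hid_dims, lt};
    logical_tensor glu_out {8, dt, hid_dims, lt};
    logical_tensor dst {9, dt, dst_dims, lt};

    // "FC up" and "FC gate" MatMuls.
    op mm_up {0, op::kind::MatMul, {src, w_up}, {fc_up}, "fc_up"};
    op mm_gate {1, op::kind::MatMul, {src, w_gate}, {fc_gate}, "fc_gate"};

    // Swish activation on the gate branch, decomposed into Sigmoid + Multiply.
    op sig {2, op::kind::Sigmoid, {fc_gate}, {sig_out}, "sigmoid"};
    op swish_mul {3, op::kind::Multiply, {fc_gate, sig_out}, {swish_out}, "swish_mul"};

    // GLU: elementwise product of the two branches.
    op glu_mul {4, op::kind::Multiply, {fc_up, swish_out}, {glu_out}, "glu_mul"};

    // "FC down" MatMul.
    op mm_down {5, op::kind::MatMul, {glu_out, w_down}, {dst}, "fc_down"};

    // Build the graph and ask for partitions; each returned partition can
    // then be compiled and executed.
    graph g {dnnl::engine::kind::cpu};
    for (const op *o : {&mm_up, &mm_gate, &sig, &swish_mul, &glu_mul, &mm_down})
        g.add_op(*o);
    g.finalize();
    auto partitions = g.get_partitions();
    (void)partitions;
}
```

Depending on the backend and build configuration, the returned partitions may
map the whole pattern to a single fused kernel or fall back to the
primitive-based reference implementation described in the limitations below.
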
## Data Types

oneDNN supports the floating-point Gated-MLP pattern with data types f32, bf16,
and f16. You can specify the data type via the input and output data type fields
of logical tensors for each operation. oneDNN does not support mixing different
floating-point data types in a floating-point Gated-MLP pattern.

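For instance, switching the whole pattern between f32, bf16, and f16 amounts to
choosing a single `data_type` value when creating the logical tensors, as in
this minimal sketch (the IDs and shapes are illustrative assumptions):

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

// Pick one of f32, bf16, or f16 and use it for every tensor in the pattern.
const auto dt = logical_tensor::data_type::f16;

logical_tensor src {0, dt, {32, 4096},
        logical_tensor::layout_type::strided};
logical_tensor w_up {1, dt, {4096, 11008},
        logical_tensor::layout_type::strided};
```
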
The definition of the data types and support status on different CPU and GPU
platforms follow the general description in @ref dev_guide_data_types.

## Implementation limitations

1. oneDNN primitive-based Gated-MLP is implemented as a reference
   implementation on both Intel Architecture Processors and Intel Graphics
   Products. In this case, floating-point Gated-MLP patterns are usually
   implemented with three f32, bf16, or f16 matmul (with binary or eltwise
   post-ops) primitives.
2. The Gated-MLP patterns functionally support all input shapes that meet the
   shape requirements of each operation in the graph. For example, the `MatMul`
   operation requires shape consistency for the `k` dimension, and the `Multiply`
   operation requires the input tensors to have the same shape or shapes that
   can be properly broadcast based on the operation attribute, as illustrated
   in the sketch below.

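As a hedged illustration of these constraints, the following shapes are
mutually consistent for a Gated-MLP pattern; the concrete sizes are
assumptions, not oneDNN requirements.

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

using dims = dnnl::graph::logical_tensor::dims;

// Assumed sizes: batch 1, sequence length 32, hidden size K = 4096,
// intermediate size N = 11008.
const dims src_dims    = {1, 32, 4096};  // (batch, seq, K)
const dims w_up_dims   = {4096, 11008};  // K must match the last dim of src.
const dims w_gate_dims = {4096, 11008};  // Same K and N as the "FC up" weight.
const dims hidden_dims = {1, 32, 11008}; // "FC up" and "FC gate" outputs have
                                         // identical shapes, so Multiply needs
                                         // no broadcasting.
const dims w_down_dims = {11008, 4096};  // k of "FC down" equals N of the GLU
                                         // output.
const dims dst_dims    = {1, 32, 4096};
```
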
## Examples

oneDNN provides a [Gated-MLP
example](https://github.com/oneapi-src/oneDNN/tree/main/examples/graph/gated_mlp.cpp)
demonstrating how to construct a typical floating-point Gated-MLP pattern with
oneDNN Graph API on CPU and GPU with different runtimes.

For applications where the weights of FC up and FC gate are combined into a single
tensor, oneDNN also provides an
[example](https://github.com/oneapi-src/oneDNN/tree/main/examples/graph/gated_mlp_wei_combined.cpp)
demonstrating how to create the weight tensors for the pattern using offsets and
strides into the combined weight tensor.

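Conceptually, the two weights can be exposed as strided views into one
\f$ K \times 2N \f$ buffer. The sketch below illustrates the idea only; the
sizes, tensor IDs, and the assumed column order (FC up first, FC gate second)
are assumptions, and the linked example shows the complete approach, including
how the data handles are provided at execution time.

```cpp
#include <cstdint>

#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

const std::int64_t K = 4096, N = 11008;
const auto dt = logical_tensor::data_type::bf16;

// Both views have shape K x N but a row stride of 2N, so they interleave
// within the combined K x 2N weight buffer.
logical_tensor w_up {1, dt, {K, N}, {2 * N, 1}};
logical_tensor w_gate {2, dt, {K, N}, {2 * N, 1}};

// At execution time, the tensor bound to w_up would point at the start of the
// combined buffer, and the one bound to w_gate at an offset of N elements
// (assuming the FC-up columns come first).
```
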
## References

1. Attention Is All You Need, https://arxiv.org/abs/1706.03762v7
2. GLU Variants Improve Transformer, https://arxiv.org/abs/2002.05202
3. LLaMA: Open and Efficient Foundation Language Models, https://arxiv.org/abs/2302.13971
4. Qwen Technical Report, https://arxiv.org/abs/2309.16609
5. oneDNN Graph API documentation, https://oneapi-src.github.io/oneDNN/graph_extension.html