
Commit e5f3811

doc: graph: add document for gated mlp fusion

File tree

3 files changed, +123 -0 lines changed


doc/graph/complex_fusion/gated_mlp.md

@@ -0,0 +1,123 @@
Gated Multi-Layer Perceptron (Gated-MLP) {#dev_guide_graph_gated_mlp}
=====================================================================

## Overview

Gated Multi-Layer Perceptron (Gated-MLP) is a variant of MLP that is widely
used as the Feed-Forward Network (FFN) in many Transformer-based Large Language
Models (LLMs).

Typically, the FFN in the Transformer architecture [1] is defined as a
two-layer MLP with a ReLU activation in between, which can be replaced with
other activations:

\f[

   FFN(src,W,V) = ReLU(src \cdot W) \cdot V

\f]

A Gated Linear Unit (GLU) is adopted to replace the first linear layer to
improve the quality of Transformer-based models [2]:

\f[

   GLU(src,W_1,W_2) = (src \cdot W_1) \otimes Sigmoid(src \cdot W_2) \\
   FFN(src,W_1,W_2,V) = GLU(src,W_1,W_2) \cdot V

\f]

Here, \f$ src \cdot W_1 \f$ is usually called "FC (fully-connected) up",
\f$ src \cdot W_2 \f$ is called "FC gate", and the last linear layer is called
"FC down".

The Swish activation is further adopted to replace Sigmoid in the GLU, forming
SwiGLU:

\f[

   Swish(x) = x \otimes Sigmoid(x) \\
   SwiGLU(src,W_1,W_2) = (src \cdot W_1) \otimes Swish(src \cdot W_2) \\
   FFN(src,W_1,W_2,V) = SwiGLU(src,W_1,W_2) \cdot V

\f]

The Gated-MLP based on SwiGLU is also adopted in LLMs such as LLaMA [3],
Qwen [4], etc.

## Gated-MLP patterns

oneDNN supports Gated-MLP and its optimization through the Graph API [5] by
defining the graph, getting partitions from the graph, and optimizing the
kernels underneath. In general, a Gated-MLP pattern is defined as a directed
acyclic graph (DAG) using the oneDNN Graph API. The overall workflow is
sketched below.
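
The following is a minimal sketch of that workflow, with a single MatMul op
standing in for the full pattern described in the next subsection. The tensor
ids, shapes, and op names are illustrative assumptions, not values taken from
the library; refer to the shipped examples for complete programs.

```cpp
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    // Logical tensors describe op inputs/outputs: id, data type, shape,
    // and layout. The shapes below are illustrative only.
    logical_tensor src {0, logical_tensor::data_type::f32, {32, 4096},
            logical_tensor::layout_type::strided};
    logical_tensor wei {1, logical_tensor::data_type::f32, {4096, 11008},
            logical_tensor::layout_type::strided};
    logical_tensor dst {2, logical_tensor::data_type::f32, {32, 11008},
            logical_tensor::layout_type::strided};

    // 1. Define the graph (DAG) by adding ops and finalizing it.
    op matmul(0, op::kind::MatMul, {src, wei}, {dst}, "fc_up");
    graph g(dnnl::engine::kind::cpu);
    g.add_op(matmul);
    g.finalize();

    // 2. Get partitions from the graph; a fused pattern comes back as a
    //    single partition.
    std::vector<partition> partitions = g.get_partitions();

    // 3. Compile a partition for an engine; the optimized kernel is
    //    selected underneath.
    dnnl::engine eng(dnnl::engine::kind::cpu, 0);
    compiled_partition cp = partitions[0].compile({src, wei}, {dst}, eng);

    // Execution would follow with a dnnl::stream and dnnl::graph::tensor
    // objects bound to user buffers.
    return 0;
}
```
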
### Floating-point Gated-MLP

oneDNN defines the floating-point (f32, bf16, and f16) Gated-MLP pattern as
follows. The blue nodes are required when defining a Gated-MLP pattern, while
the brown nodes are optional.

![Gated-MLP pattern](images/fp-gated-mlp.png)

1. The first MatMul on the top left calculates "FC up": \f$ src \cdot W_1 \f$.
   See the [MatMul](@ref dev_guide_op_matmul) operation in the Graph API.
2. The second MatMul on the top right calculates "FC gate": \f$ src \cdot W_2 \f$.
3. The Activation node is optional. If required, it can be constructed with the
   activation operations in the Graph API, for example, [ReLU](@ref dev_guide_op_relu),
   [GELU](@ref dev_guide_op_gelu), [Sigmoid](@ref dev_guide_op_sigmoid), and so on.
   For the Swish activation, the node can be constructed with [Sigmoid](@ref dev_guide_op_sigmoid)
   and [Multiply](@ref dev_guide_op_multiply) as shown below; a code sketch of
   the full pattern construction follows this list. You can also refer to the
   [Gated-MLP example](https://github.com/oneapi-src/oneDNN/tree/main/examples/graph/gated_mlp.cpp)
   for the Swish definition.

   ![Swish Activation](images/gated-mlp-swish.png)

4. The last MatMul on the bottom performs the "FC down" operation between the
   GLU output and \f$ V \f$.
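
The sketch below shows how such a pattern can be described with the Graph API,
assuming illustrative shapes (M = 32, K = 4096, N = 11008), bf16 data, and
hypothetical tensor ids and op names. It only defines the graph and requests
partitions; compilation and execution follow the workflow sketched earlier.

```cpp
#include <cstdint>
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;
using dt = logical_tensor::data_type;
using lt = logical_tensor::layout_type;

int main() {
    const int64_t M = 32, K = 4096, N = 11008;

    // Logical tensors; shared ids connect the ops into a DAG.
    logical_tensor src {0, dt::bf16, {M, K}, lt::strided};
    logical_tensor w1 {1, dt::bf16, {K, N}, lt::strided}; // FC up weight
    logical_tensor w2 {2, dt::bf16, {K, N}, lt::strided}; // FC gate weight
    logical_tensor v {3, dt::bf16, {N, K}, lt::strided};  // FC down weight

    logical_tensor fc_up_out {4, dt::bf16, {M, N}, lt::strided};
    logical_tensor fc_gate_out {5, dt::bf16, {M, N}, lt::strided};
    logical_tensor sigmoid_out {6, dt::bf16, {M, N}, lt::strided};
    logical_tensor swish_out {7, dt::bf16, {M, N}, lt::strided};
    logical_tensor glu_out {8, dt::bf16, {M, N}, lt::strided};
    logical_tensor dst {9, dt::bf16, {M, K}, lt::strided};

    // 1. "FC up": src x W1.
    op fc_up(0, op::kind::MatMul, {src, w1}, {fc_up_out}, "fc_up");
    // 2. "FC gate": src x W2.
    op fc_gate(1, op::kind::MatMul, {src, w2}, {fc_gate_out}, "fc_gate");
    // 3. Swish on the FC gate branch, built from Sigmoid and Multiply.
    op sigmoid(2, op::kind::Sigmoid, {fc_gate_out}, {sigmoid_out}, "sigmoid");
    op swish(3, op::kind::Multiply, {fc_gate_out, sigmoid_out}, {swish_out},
            "swish_mul");
    // Gating: elementwise product of "FC up" and the activated "FC gate".
    op gate(4, op::kind::Multiply, {fc_up_out, swish_out}, {glu_out},
            "gate_mul");
    // 4. "FC down": GLU output x V.
    op fc_down(5, op::kind::MatMul, {glu_out, v}, {dst}, "fc_down");

    graph g(dnnl::engine::kind::cpu);
    g.add_op(fc_up);
    g.add_op(fc_gate);
    g.add_op(sigmoid);
    g.add_op(swish);
    g.add_op(gate);
    g.add_op(fc_down);
    g.finalize();

    // If the backend recognizes the pattern, the whole subgraph is returned
    // as a single partition; otherwise several smaller partitions come back.
    std::vector<partition> parts = g.get_partitions();
    return 0;
}
```
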
## Data Types

oneDNN supports the floating-point Gated-MLP pattern with the f32, bf16, and
f16 data types. You can specify the data type via the input and output data
type fields of the logical tensors for each operation. oneDNN does not support
mixing different floating-point data types in a single Gated-MLP pattern.

The definition of the data types and their support status on different CPU and
GPU platforms follow the general description in @ref dev_guide_data_types.

## Implementation limitations

1. oneDNN primitive-based Gated-MLP is implemented as the reference
   implementation on both Intel Architecture Processors and Intel Graphics
   Products. In this case, floating-point Gated-MLP patterns are usually
   implemented with three f32, bf16, or f16 matmul (with binary or eltwise
   post-ops) primitives.
2. The Gated-MLP pattern functionally supports all input shapes that meet the
   shape requirements of each operation in the graph. For example, the `MatMul`
   operation requires shape consistency for the `k` dimension. The `Multiply`
   operation requires the input tensors to have the same shape, or shapes that
   can be properly broadcast based on the operation attribute.

## Examples

oneDNN provides a [Gated-MLP
example](https://github.com/oneapi-src/oneDNN/tree/main/examples/graph/gated_mlp.cpp)
demonstrating how to construct a typical floating-point Gated-MLP pattern with
the oneDNN Graph API on CPU and GPU with different runtimes.

For applications where the weights of FC up and FC gate are combined into a
single tensor, oneDNN also provides an
[example](https://github.com/oneapi-src/oneDNN/tree/main/examples/graph/gated_mlp_wei_combined.cpp)
demonstrating how to create the weight tensors for the pattern with offsets and
strides into the combined weight tensor.
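
The idea can be illustrated with the following minimal sketch, which assumes
the FC up and FC gate weights are stored side by side in one row-major buffer
of shape {K, 2*N}; the shapes, ids, and variable names are assumptions for
illustration and are not taken from the shipped example.

```cpp
#include <cstdint>
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;
using dt = logical_tensor::data_type;

int main() {
    const int64_t K = 4096, N = 11008;

    // Combined weight buffer: row i holds W1[i, :] followed by W2[i, :].
    std::vector<float> wei_combined(K * 2 * N);

    // Each view keeps the {K, N} shape but uses a row stride of 2*N so it
    // skips over the other half of every row.
    logical_tensor w1 {1, dt::f32, {K, N}, {2 * N, 1}};
    logical_tensor w2 {2, dt::f32, {K, N}, {2 * N, 1}};

    // At execution time the two tensors are bound to the same buffer with
    // different starting offsets (W2 starts N elements into each row).
    dnnl::engine eng(dnnl::engine::kind::cpu, 0);
    tensor w1_t(w1, eng, wei_combined.data());
    tensor w2_t(w2, eng, wei_combined.data() + N);

    // w1_t and w2_t can then be passed as the FC up and FC gate weights
    // when executing the compiled Gated-MLP partition.
    return 0;
}
```
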
## References

1. Attention Is All You Need, https://arxiv.org/abs/1706.03762v7
2. GLU Variants Improve Transformer, https://arxiv.org/abs/2002.05202
3. LLaMA: Open and Efficient Foundation Language Models, https://arxiv.org/abs/2302.13971
4. Qwen Technical Report, https://arxiv.org/abs/2309.16609
5. oneDNN Graph API documentation, https://oneapi-src.github.io/oneDNN/graph_extension.html