MatMul with a scaling factor. It can be constructed by the [Multiply](@ref dev_guide_op_multiply)
or [Divide](@ref dev_guide_op_divide) operation in Graph API. The scaling
factor is given by users as an input of SDPA. \f$\sqrt{d_k}\f$ in the formula
is not considered part of the SDPA pattern because it is a constant.
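As a numeric illustration of the Scale step (a minimal NumPy sketch of the math only, not Graph API code; the shapes and the common choice of \f$\sqrt{d_k}\f$ as the factor are assumptions for the example):

```python
import numpy as np

rng = np.random.default_rng(0)
S, d_k = 4, 8                        # sequence length and head size (illustrative)
Q = rng.standard_normal((S, d_k))    # Query
K = rng.standard_normal((S, d_k))    # Key

scale = np.sqrt(d_k)                 # user-provided scaling factor
scores = Q @ K.T                     # MatMul of Query and transposed Key
scaled = scores / scale              # Divide variant of the Scale node

# The Multiply variant with the reciprocal factor is equivalent.
assert np.allclose(scaled, scores * (1.0 / scale))
```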

3. The Mask node is optional and is used to apply an attention mask to the
   output of the previous Scale node. There are two types of masks that can
   be applied:

   1. Explicit user-generated mask: You can explicitly create a mask tensor
      and pass it to the library for the computation of SDPA. In this case, the
      mask can be constructed by the [Add](@ref dev_guide_op_add)
      or [Select](@ref dev_guide_op_select) operation in Graph API for different
      mask policies (for example, causal mask or padding mask). When the
      Add operation is used to apply the mask, the input mask is usually an upper
      triangular matrix with all the elements above the diagonal filled with
      `-inf` and zeroes elsewhere. The `-inf` entries will become zero probability
      after Softmax is applied in the next step.
      Alternatively, a Select operation may be used. In this case, the
      input is a boolean tensor (for example, with `true` on and below the
      diagonal, and `false` above it). A `false` element in the mask forces the
      corresponding element of the scaled output to `-inf`, while a `true`
      element leaves it unchanged.

      ![SDPA-mask-1](images/sdpa-mask-1.png) ![SDPA-mask-2](images/sdpa-mask-2.png)
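The two explicit mask policies can be sketched in NumPy (an illustration of the semantics only, not Graph API code; the tensor names and the causal pattern are assumptions for the example):

```python
import numpy as np

S = 4
scaled = np.arange(S * S, dtype=np.float64).reshape(S, S)  # stand-in for the scaled output

# Add policy: additive mask with zeros on/below the diagonal, -inf above it.
upper = np.triu(np.ones((S, S), dtype=bool), k=1)
add_mask = np.where(upper, -np.inf, 0.0)
masked_add = scaled + add_mask

# Select policy: boolean mask, True keeps the element, False forces -inf.
bool_mask = ~upper                        # True on and below the diagonal
masked_sel = np.where(bool_mask, scaled, -np.inf)

# Both policies produce the same causal-masked scores.
assert np.array_equal(masked_add, masked_sel)
```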

   2. Implicit library-generated mask: You can use the operations in the library
      to generate a mask by constructing a subgraph. Currently, Graph API supports
      generating an implicit causal mask (top-left aligned) using the
      [GenIndex](@ref dev_guide_op_genindex), [GreaterEqual](@ref dev_guide_op_greaterequal),
      and [Select](@ref dev_guide_op_select) operations.

      ![SDPA-mask-3](images/sdpa-mask-3.png)
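The implicit causal-mask subgraph can be mimicked element-wise in NumPy (a sketch of the semantics, assuming GenIndex produces per-element indices along one axis; the variable names are illustrative):

```python
import numpy as np

S = 4
scaled = np.ones((S, S))              # stand-in for the scaled output

# GenIndex analogue: row and column indices for every element.
row = np.arange(S).reshape(S, 1)      # index along the query (row) axis
col = np.arange(S).reshape(1, S)      # index along the key (column) axis

# GreaterEqual: True on and below the diagonal (top-left aligned causal mask).
keep = row >= col

# Select: keep the scaled score where True, otherwise force -inf.
masked = np.where(keep, scaled, -np.inf)

# Matches an explicit lower-triangular causal mask.
expected = np.where(np.tril(np.ones((S, S), dtype=bool)), scaled, -np.inf)
assert np.array_equal(masked, expected)
```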

4. The SoftMax operation takes the masked output and transforms it into
   probabilities between 0 and 1. See the [SoftMax](@ref dev_guide_op_softmax)
   operation in Graph API.
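A small numeric check of how SoftMax turns masked scores into probabilities, with `-inf` entries becoming exactly zero (NumPy sketch; the input values are made up for the example):

```python
import numpy as np

masked = np.array([[0.5, -np.inf, -np.inf],
                   [0.1,  0.3,    -np.inf],
                   [0.2,  0.2,     0.2]])

# Numerically stable row-wise softmax: shift by the row max before exp.
shifted = masked - masked.max(axis=-1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

# Each row sums to 1, and the -inf entries get zero probability.
assert np.allclose(probs.sum(axis=-1), 1.0)
assert probs[0, 1] == 0.0 and probs[0, 2] == 0.0
```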
softmax primitives. The reference implementation requires memory to store the
intermediate results of the dot products between Query and Key, which takes
\f$O(S^2)\f$ memory. It may lead to an out-of-memory error when computing
inputs with long sequence lengths on platforms with limited memory. For an
implicit causal mask, the reference implementation is only available on CPU.
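A back-of-the-envelope calculation shows why the \f$O(S^2)\f$ buffer matters (the sequence length and f32 element size here are assumptions for the example):

```python
S = 16384                                      # assumed sequence length
bytes_per_elem = 4                             # f32
score_matrix_bytes = S * S * bytes_per_elem    # one Query x Key score matrix

# 16384^2 elements * 4 bytes = 1 GiB per batch item per head,
# before counting any of the other intermediate buffers.
print(score_matrix_bytes / 2**30)  # prints 1.0 (GiB)
```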

2. The SDPA patterns functionally support all input shapes meeting the shape
   requirements of each operation in the graph. For example, the Add, Multiply,
   Divide, and Select operations require the input tensors to have the same