graph: fix the intermediate data types in SDPA patterns #2894

Open · wants to merge 9 commits into base: main
6 changes: 3 additions & 3 deletions doc/graph/fusion_patterns/sdpa.md
@@ -128,9 +128,9 @@ platforms follow the general description in @ref dev_guide_data_types.
4. GPU
- Optimized implementation is available for 4D Q/K/V tensors with shape
defined as (N, H, S, D).
- Optimized implementation is available for floating-point SDPA with `f16`
data type and `D <= 256` on Intel Graphics Products with Intel(R) Xe Matrix
Extensions (Intel(R) XMX) support.
- Optimized implementation is available for `f16` or `bf16` SDPA with `f32`
intermediate data type and `D <= 256` on Intel Graphics Products with
Intel(R) Xe Matrix Extensions (Intel(R) XMX) support.

Contributor: What does "intermediate data type" mean when users construct a graph? Do we need to describe it in text or in a picture?
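One way to read "intermediate data type" here: it is the data type users assign to the logical tensors that connect the ops inside the pattern (the Q·K output, the SoftMax output, and so on), while the graph inputs and outputs stay in `f16`/`bf16`. Below is a minimal, hypothetical C++ sketch of that construction with the Graph API; the shapes, tensor/op IDs, and variable names are made up for illustration and it is not part of this PR.

```cpp
// Hypothetical sketch: f16 SDPA inputs with f32 intermediate logical tensors.
// Shapes, tensor/op IDs, and names are illustrative only.
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    const logical_tensor::dims qk_shape {1, 16, 384, 64}; // (N, H, S, D)
    const logical_tensor::dims score_shape {1, 16, 384, 384};

    // Graph inputs stay in f16.
    logical_tensor query {0, logical_tensor::data_type::f16, qk_shape,
            logical_tensor::layout_type::strided};
    logical_tensor key {1, logical_tensor::data_type::f16, qk_shape,
            logical_tensor::layout_type::strided};

    // Intermediate results carry the f32 data type.
    logical_tensor score {2, logical_tensor::data_type::f32, score_shape,
            logical_tensor::layout_type::strided};
    logical_tensor prob {3, logical_tensor::data_type::f32, score_shape,
            logical_tensor::layout_type::strided};

    op qk_matmul {0, op::kind::MatMul, {query, key}, {score}, "qk_matmul"};
    qk_matmul.set_attr<bool>(op::attr::transpose_b, true);

    op softmax {1, op::kind::SoftMax, {score}, {prob}, "softmax"};
    softmax.set_attr<int64_t>(op::attr::axis, -1);

    graph g {dnnl::engine::kind::gpu};
    g.add_op(qk_matmul);
    g.add_op(softmax);
    g.finalize();
    auto partitions = g.get_partitions(); // ask the library for fusible partitions
    (void)partitions;
    return 0;
}
```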

## Example

12 changes: 7 additions & 5 deletions doc/graph/operations/Add.md
@@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is

Add operation supports the following data type combinations.

| Src_0 / Src_1 | Dst |
|:--------------|:-----|
| f32 | f32 |
| bf16 | bf16 |
| f16 | f16 |
| Src_0 | Src_1 | Dst |
|:----------|:----------|:-----|
| f32 | f32 | f32 |
| bf16 | bf16 | bf16 |
| f16 | f16 | f16 |
| f32 | bf16, f16 | f32 |
| bf16, f16 | f32 | f32 |
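The same widened combinations appear for Divide, Multiply, and Subtract below; one illustration covers them all. As a minimal sketch (not taken from this diff; IDs and shapes are invented), an Add with one `f32` and one `bf16` input producing an `f32` output could be declared like this:

```cpp
// Hypothetical sketch: Add with mixed f32 / bf16 inputs and an f32 output.
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    const logical_tensor::dims shape {1, 16, 384, 384};

    logical_tensor src0 {0, logical_tensor::data_type::f32, shape,
            logical_tensor::layout_type::strided};
    logical_tensor src1 {1, logical_tensor::data_type::bf16, shape,
            logical_tensor::layout_type::strided};
    logical_tensor dst {2, logical_tensor::data_type::f32, shape,
            logical_tensor::layout_type::strided};

    // Per the updated table, f32 + bf16 -> f32 is an allowed combination.
    op add {0, op::kind::Add, {src0, src1}, {dst}, "add"};

    graph g {dnnl::engine::kind::cpu};
    g.add_op(add);
    g.finalize();
    return 0;
}
```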
12 changes: 7 additions & 5 deletions doc/graph/operations/Divide.md
@@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is

Divide operation supports the following data type combinations.

| Src_0 / Src_1 | Dst |
|:--------------|:-----|
| f32 | f32 |
| bf16 | bf16 |
| f16 | f16 |
| Src_0 | Src_1 | Dst |
|:----------|:----------|:-----|
| f32 | f32 | f32 |
| bf16 | bf16 | bf16 |
| f16 | f16 | f16 |
| f32 | bf16, f16 | f32 |
| bf16, f16 | f32 | f32 |
10 changes: 5 additions & 5 deletions doc/graph/operations/MatMul.md
@@ -61,8 +61,8 @@ constructing an operation.

MatMul operation supports the following data type combinations.

| Src | Weights | Bias | Dst |
|:-----|:--------|:-----|:-----|
| f32 | f32 | f32 | f32 |
| bf16 | bf16 | bf16 | bf16 |
| f16 | f16 | f16 | f16 |
| Src | Weights | Bias | Dst |
|:-----|:--------|:-----|:----------|
| f32 | f32 | f32 | f32 |
| bf16 | bf16 | bf16 | f32, bf16 |
| f16 | f16 | f16 | f32, f16 |
12 changes: 7 additions & 5 deletions doc/graph/operations/Multiply.md
@@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is

Multiply operation supports the following data type combinations.

| Src_0 / Src_1 | Dst |
|:--------------|:-----|
| f32 | f32 |
| bf16 | bf16 |
| f16 | f16 |
| Src_0 | Src_1 | Dst |
|:----------|:----------|:-----|
| f32 | f32 | f32 |
| bf16 | bf16 | bf16 |
| f16 | f16 | f16 |
| f32 | bf16, f16 | f32 |
| bf16, f16 | f32 | f32 |
10 changes: 5 additions & 5 deletions doc/graph/operations/Softmax.md
@@ -36,8 +36,8 @@ constructing an operation.

SoftMax operation supports the following data type combinations.

| Src | Dst |
|:-----|:-----|
| f32 | f32 |
| bf16 | bf16 |
| f16 | f16 |
| Src | Dst |
|:-----|:----------------|
| f32 | f32, bf16, f16 |
| bf16 | bf16 |
| f16 | f16 |
12 changes: 7 additions & 5 deletions doc/graph/operations/Subtract.md
@@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is

Subtract operation supports the following data type combinations.

| Src_0 / Src_1 | Dst |
|:--------------|:-----|
| f32 | f32 |
| bf16 | bf16 |
| f16 | f16 |
| Src_0 | Src_1 | Dst |
|:----------|:----------|:-----|
| f32 | f32 | f32 |
| bf16 | bf16 | bf16 |
| f16 | f16 | f16 |
| f32 | bf16, f16 | f32 |
| bf16, f16 | f32 | f32 |
3 changes: 1 addition & 2 deletions doc/graph/programming_model/low_precision.md
@@ -52,7 +52,6 @@ Graph operations support bf16 and f16 data types.

A TypeCast operation performing down conversion should be inserted clearly to
indicate the use of low numeric precision. oneDNN Graph implementation fully
honors the API-specified numeric precision and only performs the computation
using the API-specified or higher numeric precision.
honors the API-specified numeric precision.
Contributor: Just to make sure we are aligned: this still allows using f32 values to store f16/bf16 data, as long as we respect rounding to f16/bf16 accuracy, right?

Contributor (author): Yes, in my understanding it's still allowed for backend implementations. From that perspective, it seems I need to keep the original statement. My intention here was to align the implementations, as the original statement sounds like different backends (e.g., DNNL & GC, CPU & GPU) can have different numerical behaviors.


@img{bf16_programming.jpg,Figure 2: Overview of bf16 programming model.,80%,}
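As a rough illustration of the TypeCast statement earlier in this hunk (assumed shapes and IDs, not part of this change), an explicit down conversion from an `f32` intermediate to `bf16` could be expressed as a standalone TypeCast op:

```cpp
// Hypothetical sketch: explicit TypeCast marking an f32 -> bf16 down conversion.
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    const logical_tensor::dims shape {1, 16, 384, 384};

    logical_tensor prob_f32 {0, logical_tensor::data_type::f32, shape,
            logical_tensor::layout_type::strided};
    logical_tensor prob_bf16 {1, logical_tensor::data_type::bf16, shape,
            logical_tensor::layout_type::strided};

    // The cast makes the reduced precision visible in the graph; the
    // implementation honors the data types given on the logical tensors.
    op cast {0, op::kind::TypeCast, {prob_f32}, {prob_bf16}, "cast_down"};

    graph g {dnnl::engine::kind::gpu};
    g.add_op(cast);
    g.finalize();
    return 0;
}
```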