
graph: fix the intermediate data types in SDPA patterns #2894

Open · wants to merge 9 commits into main from lvtao/main/fix-sdpa-intermediates
Conversation

@TaoLv TaoLv commented Mar 17, 2025

Addresses MFDNN-13091.

The purpose is to align the intermediate data types in the patterns with those used in ATen SDPA and GPU ukernel SDPA.

  1. The matmul, softmax, and binary ops are extended to support mixed data types for inputs and outputs.
  2. The two SDPA examples are changed to use the f32 intermediate data type, rather than bf16 or f16.
  3. The backend is changed to dispatch to GPU ukernel SDPA only when the f32 intermediate data type is required.
  4. Test cases are added for f16/bf16 SDPA with the f32 intermediate data type.

With these changes, a typical f16 SDPA that can be dispatched to the GPU ukernel SDPA looks like the following:

(image: f16 SDPA pattern graph)
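To make the intermediate-precision choice concrete, here is a minimal NumPy sketch (not code from this PR; the function name is hypothetical) of an f16 SDPA in which both matmuls and the softmax keep their results in f32, with a single down-conversion to f16 at the end:

```python
import numpy as np

def sdpa_f16_with_f32_intermediates(q, k, v):
    """Hypothetical sketch: f16 Q/K/V in, f16 out, all intermediates in f32.

    Mirrors the pattern described above: matmul and softmax results stay
    in f32 rather than being rounded back to f16 at every step.
    """
    assert q.dtype == k.dtype == v.dtype == np.float16
    d = q.shape[-1]
    # First matmul (QK^T) and scaling computed in f32.
    scores = q.astype(np.float32) @ k.astype(np.float32).T / np.sqrt(d)
    # Numerically stable softmax, still in f32.
    scores -= scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Second matmul in f32, then a final TypeCast-style down-conversion.
    out = probs @ v.astype(np.float32)
    return out.astype(np.float16)
```

Keeping the softmax input/output in f32 avoids the extra rounding error that an f16 intermediate would introduce between the two matmuls.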

@TaoLv requested review from a team as code owners on March 17, 2025.
Labels: documentation (A request to change/fix/improve the documentation; Codeowner: @oneapi-src/onednn-doc), component:graph-api (Codeowner: @oneapi-src/onednn-graph), component:tests (Codeowner: @oneapi-src/onednn-arch), component:examples.
@TaoLv force-pushed the lvtao/main/fix-sdpa-intermediates branch several times between March 17 and March 21, 2025, most recently from 5b05972 to 76f2cf9.
@@ -52,7 +52,6 @@ Graph operations support bf16 and f16 data types.

 A TypeCast operation performing down conversion should be inserted clearly to
 indicate the use of low numeric precision. oneDNN Graph implementation fully
-honors the API-specified numeric precision and only performs the computation
-using the API-specified or higher numeric precision.
+honors the API-specified numeric precision.
Contributor
Just to make sure we are aligned: this still allows using f32 values to store f16/bf16 data, as long as we respect rounding to f16/bf16 accuracy, right?

Contributor Author

Yes, in my understanding it's still allowed for backend implementations. From this perspective, it seems I need to keep the original statement. My intention here was to align the implementations, as the original statement sounds like different backends (e.g., DNNL & GC, CPU & GPU) can have different numerical behaviors.
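The behavior being discussed can be illustrated with a small sketch (hypothetical, not oneDNN code): an implementation may store and compute with f32 values as long as each value is first rounded to f16 accuracy.

```python
import numpy as np

# Sketch: emulate "f32 storage, f16 accuracy" by rounding through f16.
x = np.float32(1.0) / np.float32(3.0)     # full-precision f32 value
x_f16_acc = np.float32(np.float16(x))     # rounded to f16 accuracy, but
                                          # stored (and computed) in f32

# Rounding to f16 accuracy is idempotent: rounding again changes nothing.
assert x_f16_acc == np.float32(np.float16(x_f16_acc))
# But it is not a no-op: some f32 precision was genuinely discarded.
assert x_f16_acc != x
```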

.set_type_constraints(
"T2", {data_type::f32, data_type::bf16, data_type::f16})
.set_type_constraints(
"T3", {data_type::f32, data_type::bf16, data_type::f16})
Contributor

This requires some documentation about type promotion, as users might wonder what happens, for example, with f16 <- f16 + bf16.

Contributor Author

In fact, we don't allow f16 + bf16. It's mentioned in the "Supported data types" section of the op document: when src0 and src1 have different data types, one of them should be f32, and the other one (f16 or bf16) will be promoted to f32 for the computation.
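The rule described above can be sketched as a small validation helper (hypothetical; this is not the actual oneDNN implementation): same-type inputs compute in that type, mixed types are valid only when one side is f32, and f16 + bf16 is rejected.

```python
# Hypothetical sketch of the mixed-type rule described above; not oneDNN code.
VALID = {"f32", "bf16", "f16"}

def promoted_compute_type(src0: str, src1: str) -> str:
    """Return the computation data type for a binary op's inputs,
    or raise if the combination is disallowed (e.g. f16 + bf16)."""
    if src0 not in VALID or src1 not in VALID:
        raise ValueError(f"unsupported data type: {src0}, {src1}")
    if src0 == src1:
        return src0
    # Mixed types: one input must be f32; the f16/bf16 input is
    # promoted to f32 for the computation.
    if "f32" in (src0, src1):
        return "f32"
    raise ValueError(f"mixed {src0} + {src1} is not allowed")
```

For example, `promoted_compute_type("f32", "bf16")` yields `"f32"`, while `promoted_compute_type("f16", "bf16")` raises an error, matching the constraint stated in the op document.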

@dzarukin (Contributor) left a comment

LGTM. Moving approval responsibility to @mgouicem. :)
