graph: fix the intermediate data types in SDPA patterns #2894
base: main
Conversation
```diff
@@ -52,7 +52,6 @@ Graph operations support bf16 and f16 data types.
 
 A TypeCast operation performing down conversion should be inserted clearly to
 indicate the use of low numeric precision. oneDNN Graph implementation fully
-honors the API-specified numeric precision and only performs the computation
-using the API-specified or higher numeric precision.
+honors the API-specified numeric precision.
```
Just to make sure we are aligned: this still allows using f32 values to store f16/bf16 data, as long as we respect rounding to f16/bf16 accuracy, right?
Yes, in my understanding that is still allowed for backend implementations. From this perspective, it seems I need to keep the original statement. My intention here was to align the implementations, as the original statement sounds as if different backends (e.g. DNNL & GC, CPU & GPU) could have different numerical behaviors.
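As a minimal, hypothetical sketch of that reading (plain C++, not oneDNN code): an implementation may hold the intermediate in an f32 buffer as long as every stored value is rounded back to the API-specified (here bf16) accuracy.

```cpp
// Conceptual illustration only: "storing bf16 data in f32" while still
// respecting bf16 rounding. NaN/Inf handling is omitted for brevity.
#include <cstdint>
#include <cstdio>
#include <cstring>

// Round an f32 value to bf16 accuracy (round-to-nearest-even on the
// discarded 16 mantissa bits) and return it widened back to f32.
static float round_to_bf16_accuracy(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits += 0x7FFFu + ((bits >> 16) & 1u); // round to nearest even
    bits &= 0xFFFF0000u;                   // keep the bf16-representable part
    float y;
    std::memcpy(&y, &bits, sizeof(y));
    return y;
}

int main() {
    // The arithmetic and storage use f32, but each stored intermediate is
    // snapped back to bf16 accuracy, which is the "rounding" discussed above.
    float acc = round_to_bf16_accuracy(1.0f / 3.0f);
    acc = round_to_bf16_accuracy(acc + round_to_bf16_accuracy(0.1f));
    std::printf("bf16-accurate intermediate held in f32: %.8f\n", acc);
    return 0;
}
```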
```cpp
.set_type_constraints(
        "T2", {data_type::f32, data_type::bf16, data_type::f16})
.set_type_constraints(
        "T3", {data_type::f32, data_type::bf16, data_type::f16})
```
This requires some documentation about type promotion, as users might wonder what happens, for example, with f16 <- f16 + bf16.
In fact, we don't allow f16 + bf16. It's mentioned in the "supported data types" section in the op document. When src0 and src1 have different data types, one of them should be f32 and the other one (f16 or bf16) will be promoted to f32 for calculation.
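A small, hypothetical sketch of that promotion rule (illustrative names only, not the oneDNN implementation):

```cpp
// Hypothetical helper mirroring the rule described above: same types pass
// through, f32 + f16/bf16 promotes to f32, and f16 + bf16 is rejected.
#include <cstdio>
#include <optional>

enum class data_type { f32, bf16, f16 };

std::optional<data_type> promoted_type(data_type src0, data_type src1) {
    if (src0 == src1) return src0; // same type: no promotion needed
    if (src0 == data_type::f32 || src1 == data_type::f32)
        return data_type::f32;     // the f16/bf16 side is promoted to f32
    return std::nullopt;           // mixed f16 + bf16: not allowed
}

int main() {
    std::printf("f32 + f16  -> %s\n",
            promoted_type(data_type::f32, data_type::f16) ? "computed in f32"
                                                          : "unsupported");
    std::printf("f16 + bf16 -> %s\n",
            promoted_type(data_type::f16, data_type::bf16) ? "ok"
                                                           : "unsupported");
    return 0;
}
```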
LGTM. Moving approval responsibility to @mgouicem. :)
Address MFDNN-13091
The purpose is to align the intermediate data types in the pattern with those in ATen SDPA and GPU ukernel SDPA.
With these changes, a typical f16 SDPA that can be dispatched to the GPU ukernel SDPA now looks like the following:
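(The pattern diagram itself is not reproduced here. As a rough, hypothetical illustration only, and not the exact pattern from this PR, an ATen-style mixed-precision f16 SDPA keeps the score/softmax intermediates in f32 and casts back to f16 before the second MatMul:)

```cpp
// Hypothetical op chain with illustrative names; the intermediate data types
// follow the ATen-style convention described above, not this PR's diagram.
#include <cstdio>
#include <string>
#include <vector>

struct node { std::string op; std::string out_dt; };

int main() {
    // Q, K, V inputs are f16; score and softmax intermediates stay in f32.
    const std::vector<node> sdpa = {
            {"MatMul (Q x K^T)", "f32"},
            {"Scale", "f32"},
            {"Add (attention mask)", "f32"},
            {"SoftMax", "f32"},
            {"TypeCast (down-convert)", "f16"},
            {"MatMul (x V)", "f16"},
    };
    for (const auto &n : sdpa)
        std::printf("%-26s -> %s\n", n.op.c_str(), n.out_dt.c_str());
    return 0;
}
```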