graph: fix the intermediate data types in SDPA patterns #2894
base: main
@@ -54,13 +54,17 @@ DNNL_GRAPH_OP_SCHEMA(Add, 1,
                 .set_num_inputs(2)
                 .set_num_outputs(1)
                 .set_commutative_inputs()
-                .set_input(0, "src_0", "T")
-                .set_input(1, "src_1", "T")
-                .set_output(0, "dst", "T")
+                .set_input(0, "src_0", "T1")
+                .set_input(1, "src_1", "T2")
+                .set_output(0, "dst", "T3")
                 .set_attr(op_attr::auto_broadcast, false, attribute_kind::s,
                         "numpy", {"none", "numpy"})
                 .set_type_constraints(
-                        "T", {data_type::f32, data_type::bf16, data_type::f16})
+                        "T1", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T2", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T3", {data_type::f32, data_type::bf16, data_type::f16})
                 .set_shape_inference_function(
                         infer_elemwise_arithmetic_output_shape))

@@ -684,10 +688,13 @@ DNNL_GRAPH_OP_SCHEMA(MatMul, 1,
                 .set_input(0, "src", "T")
                 .set_input(1, "weights", "T")
                 .set_input(2, "bias", "T")
-                .set_output(0, "dst", "T")
+                .set_output(0, "dst", "T1")
                 .set_type_constraints(
                         "T", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T1", {data_type::f32, data_type::bf16, data_type::f16})
                 .set_shape_inference_function(infer_matmul_output_shape)
+                .set_op_def_constraint_function(check_matmul_dtype)
                 .SET_MATMUL_COMMON_ATTRS)

 DNNL_GRAPH_OP_SCHEMA(Maximum, 1,
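For context on what the MatMul change enables at the API level: with the destination bound to a separate type variable (T1), a bf16 MatMul may produce an f32 destination, which is how an f32 intermediate in an SDPA pattern is expressed; whether a particular combination is accepted is further governed by check_matmul_dtype. Below is a minimal sketch using the oneDNN Graph C++ API, not code from this PR; the tensor IDs, shapes, and names are made up for illustration.

    #include "oneapi/dnnl/dnnl_graph.hpp"

    using namespace dnnl::graph;

    int main() {
        // Query (bf16) x Key^T (bf16) -> score (f32): mixed src/dst data types
        // of the kind the updated MatMul schema is meant to accept.
        logical_tensor query {0, logical_tensor::data_type::bf16,
                {1, 16, 384, 64}, logical_tensor::layout_type::strided};
        logical_tensor key {1, logical_tensor::data_type::bf16,
                {1, 16, 64, 384}, logical_tensor::layout_type::strided};
        logical_tensor score {2, logical_tensor::data_type::f32,
                {1, 16, 384, 384}, logical_tensor::layout_type::strided};

        op qk_matmul(0, op::kind::MatMul, {query, key}, {score}, "qk_matmul");

        graph g(dnnl::engine::kind::cpu);
        g.add_op(qk_matmul); // the op is checked against its schema here
        g.finalize();
        auto partitions = g.get_partitions();
        return partitions.empty() ? 1 : 0;
    }
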
@@ -788,9 +795,6 @@ DNNL_GRAPH_OP_SCHEMA(MishBackward, 1,
                         "T", {data_type::f32, data_type::bf16, data_type::f16})
                 .set_shape_inference_function(infer_identity_output_shape))

-// TODO(Yixin): for Multiply. input and output needs to have the same dtypes
-// But in current pytorch bridge's type promotion system, there's no
-// such constraints. So this feature is postponed.
 DNNL_GRAPH_OP_SCHEMA(Multiply, 1,
         op_schema_t()
                 .set_num_inputs(2)
@@ -1029,12 +1033,15 @@ DNNL_GRAPH_OP_SCHEMA(SoftMax, 1,
         op_schema_t()
                 .set_num_inputs(1)
                 .set_num_outputs(1)
-                .set_input(0, "src", "T")
-                .set_output(0, "dst", "T")
+                .set_input(0, "src", "T1")
+                .set_output(0, "dst", "T2")
                 .set_attr(op_attr::axis, false, attribute_kind::i, (int64_t)1)
                 .set_type_constraints(
-                        "T", {data_type::f32, data_type::bf16, data_type::f16})
-                .set_shape_inference_function(infer_identity_output_shape))
+                        "T1", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T2", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_shape_inference_function(infer_identity_output_shape)
+                .set_op_def_constraint_function(check_softmax_dtype))

 DNNL_GRAPH_OP_SCHEMA(SoftMaxBackward, 1,
         op_schema_t()
@@ -1121,13 +1128,17 @@ DNNL_GRAPH_OP_SCHEMA(Subtract, 1,
         op_schema_t()
                 .set_num_inputs(2)
                 .set_num_outputs(1)
-                .set_input(0, "src_0", "T")
-                .set_input(1, "src_1", "T")
-                .set_output(0, "dst", "T")
+                .set_input(0, "src_0", "T1")
+                .set_input(1, "src_1", "T2")
+                .set_output(0, "dst", "T3")
                 .set_attr(op_attr::auto_broadcast, false, attribute_kind::s,
                         "numpy", {"none", "numpy"})
                 .set_type_constraints(
-                        "T", {data_type::f32, data_type::bf16, data_type::f16})
+                        "T1", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T2", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T3", {data_type::f32, data_type::bf16, data_type::f16})
                 .set_shape_inference_function(
                         infer_elemwise_arithmetic_output_shape))

Review thread on the new type constraints:

Reviewer comment: This requires some documentation about type promotion, as users might wonder what happens, for example, with f16 <- f16 + bf16.

Reply: In fact, we don't allow f16 + bf16. It's mentioned in the "supported data types" section of the op document. When src_0 and src_1 have different data types, one of them should be f32, and the other one (f16 or bf16) will be promoted to f32 for the calculation.
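To make the rule from that reply concrete, here is a hedged sketch, for illustration only and not the actual check_* constraint functions added by the PR: identical source data types are always accepted, mixed source data types are accepted only when one of them is f32, and f16 + bf16 is rejected.

    #include <cassert>

    // Hypothetical helper mirroring the rule from the review discussion.
    // Same dtypes are always fine; otherwise one side must be f32, and the
    // narrower side (f16 or bf16) is promoted to f32 for the computation.
    enum class data_type { f32, bf16, f16 };

    bool binary_src_dtypes_compatible(data_type src_0, data_type src_1) {
        if (src_0 == src_1) return true;
        return src_0 == data_type::f32 || src_1 == data_type::f32;
    }

    int main() {
        assert(binary_src_dtypes_compatible(data_type::bf16, data_type::bf16)); // same dtype
        assert(binary_src_dtypes_compatible(data_type::f32, data_type::f16));   // f16 promoted to f32
        assert(!binary_src_dtypes_compatible(data_type::f16, data_type::bf16)); // rejected
        return 0;
    }
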
Reviewer comment: Just to make sure we are aligned: this still allows using f32 values to store f16/bf16 data, as long as the roundings respect f16/bf16 accuracy, right?

Reply: Yes, in my understanding that is still allowed for backend implementations. From this perspective, it seems I need to keep the original statement. My intention here was to align the implementations, since the original statement sounds as if different backends (e.g., DNNL & GC, CPU & GPU) could have different numerical behaviors.
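As an aside, the "f32 storage with bf16 accuracy" idea from this exchange can be made concrete: round the f32 value to the nearest representable bf16 value but keep it in an f32 container. A rough sketch (not code from the PR; round-to-nearest-even, with NaN/Inf and overflow handling omitted for brevity):

    #include <cstdint>
    #include <cstring>

    // Round an f32 value to bf16 precision while keeping f32 storage.
    // bf16 keeps the upper 16 bits of the IEEE-754 binary32 encoding, so
    // rounding amounts to round-to-nearest-even at bit 16.
    float round_f32_to_bf16_accuracy(float x) {
        std::uint32_t bits;
        std::memcpy(&bits, &x, sizeof(bits));
        const std::uint32_t lsb = (bits >> 16) & 1u; // LSB of the retained mantissa
        bits += 0x7FFFu + lsb;                       // round to nearest, ties to even
        bits &= 0xFFFF0000u;                         // drop the 16 discarded bits
        float rounded;
        std::memcpy(&rounded, &bits, sizeof(rounded));
        return rounded;
    }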