
graph: fix the intermediate data types in SDPA patterns #2894

Open
wants to merge 9 commits into main
27 changes: 16 additions & 11 deletions src/graph/interface/op_def.hpp
@@ -54,13 +54,17 @@ DNNL_GRAPH_OP_SCHEMA(Add, 1,
.set_num_inputs(2)
.set_num_outputs(1)
.set_commutative_inputs()
-                .set_input(0, "src_0", "T")
-                .set_input(1, "src_1", "T")
-                .set_output(0, "dst", "T")
+                .set_input(0, "src_0", "T1")
+                .set_input(1, "src_1", "T2")
+                .set_output(0, "dst", "T3")
.set_attr(op_attr::auto_broadcast, false, attribute_kind::s,
"numpy", {"none", "numpy"})
                 .set_type_constraints(
-                        "T", {data_type::f32, data_type::bf16, data_type::f16})
+                        "T1", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T2", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T3", {data_type::f32, data_type::bf16, data_type::f16})
.set_shape_inference_function(
infer_elemwise_arithmetic_output_shape))

@@ -791,9 +795,6 @@ DNNL_GRAPH_OP_SCHEMA(MishBackward, 1,
"T", {data_type::f32, data_type::bf16, data_type::f16})
.set_shape_inference_function(infer_identity_output_shape))

-// TODO(Yixin): for Multiply. input and output needs to have the same dtypes
-// But in current pytorch bridge's type promotion system, there's no
-// such constraints. So this feature is postponed.
DNNL_GRAPH_OP_SCHEMA(Multiply, 1,
op_schema_t()
.set_num_inputs(2)
@@ -1127,13 +1128,17 @@ DNNL_GRAPH_OP_SCHEMA(Subtract,
op_schema_t()
.set_num_inputs(2)
.set_num_outputs(1)
-                .set_input(0, "src_0", "T")
-                .set_input(1, "src_1", "T")
-                .set_output(0, "dst", "T")
+                .set_input(0, "src_0", "T1")
+                .set_input(1, "src_1", "T2")
+                .set_output(0, "dst", "T3")
.set_attr(op_attr::auto_broadcast, false, attribute_kind::s,
"numpy", {"none", "numpy"})
                 .set_type_constraints(
-                        "T", {data_type::f32, data_type::bf16, data_type::f16})
+                        "T1", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T2", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T3", {data_type::f32, data_type::bf16, data_type::f16})
Contributor:
This requires some documentation about type promotion, as users might wonder what happens, for example, with f16 <- f16 + bf16.

Contributor Author:

In fact, we don't allow f16 + bf16; this is covered in the "supported data types" section of the op documentation. When src_0 and src_1 have different data types, one of them must be f32, and the other one (f16 or bf16) will be promoted to f32 for the computation.
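The rule described above can be sketched as a small validity check. This is an illustrative standalone snippet, not the actual oneDNN Graph implementation; the `dtype` enum and function names are hypothetical:

```cpp
#include <cassert>

// Illustrative sketch of the mixed-dtype rule: inputs may either share a
// dtype, or differ only when one side is f32; the non-f32 side (f16 or bf16)
// is then promoted to f32 for the computation. Names are hypothetical.
enum class dtype { f32, bf16, f16 };

// True if the (src_0, src_1) dtype pair is allowed by the rule above.
bool mixed_dtype_allowed(dtype a, dtype b) {
    if (a == b) return true; // identical dtypes are always fine
    // Different dtypes: exactly one of them must be f32.
    return a == dtype::f32 || b == dtype::f32;
}

// Dtype used for the arithmetic after promotion.
dtype compute_dtype(dtype a, dtype b) {
    if (a == b) return a;
    return dtype::f32; // the f16/bf16 input is promoted to f32
}
```

Under this sketch, `f16 + bf16` is rejected outright, while `f32 + bf16` computes in f32.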

.set_shape_inference_function(
infer_elemwise_arithmetic_output_shape))
