graph: fix the intermediate data types in SDPA patterns #2894

Open · wants to merge 9 commits into main
6 changes: 3 additions & 3 deletions doc/graph/fusion_patterns/sdpa.md
@@ -128,9 +128,9 @@ platforms follow the general description in @ref dev_guide_data_types.
4. GPU
- Optimized implementation is available for 4D Q/K/V tensors with shape
defined as (N, H, S, D).
- Optimized implementation is available for floating-point SDPA with `f16`
data type and `D <= 256` on Intel Graphics Products with Intel(R) Xe Matrix
Extensions (Intel(R) XMX) support.
- Optimized implementation is available for `f16` or `bf16` SDPA with `f32`
intermediate data type and `D <= 256` on Intel Graphics Products with
Intel(R) Xe Matrix Extensions (Intel(R) XMX) support.

## Example

12 changes: 7 additions & 5 deletions doc/graph/operations/Add.md
@@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is

Add operation supports the following data type combinations.

| Src_0 / Src_1 | Dst |
|:--------------|:-----|
| f32 | f32 |
| bf16 | bf16 |
| f16 | f16 |
| Src_0 | Src_1 | Dst |
|:----------|:----------|:-----|
| f32 | f32 | f32 |
| bf16 | bf16 | bf16 |
| f16 | f16 | f16 |
| f32 | bf16, f16 | f32 |
| bf16, f16 | f32 | f32 |
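
For context, a minimal graph-API sketch of the newly documented mixed-precision combination; the tensor IDs, shapes, and engine kind below are illustrative assumptions, not taken from this PR. The f16 `src_1` is promoted to f32 for computation, so `dst` stays f32; the same rule applies to the other binary ops updated in this PR (Divide, Multiply, Subtract).

```cpp
// Sketch only: Add with one f32 and one f16 input producing an f32 output.
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    const logical_tensor::dims shape = {4, 16, 32, 32};

    // src_0 is f32, src_1 is f16; per the table above the result is f32.
    auto src0 = logical_tensor(0, logical_tensor::data_type::f32, shape,
            logical_tensor::layout_type::strided);
    auto src1 = logical_tensor(1, logical_tensor::data_type::f16, shape,
            logical_tensor::layout_type::strided);
    auto dst = logical_tensor(2, logical_tensor::data_type::f32, shape,
            logical_tensor::layout_type::strided);

    auto add = op(3, op::kind::Add, "add");
    add.add_inputs({src0, src1});
    add.add_outputs({dst});

    // Build a one-op graph just to exercise the data type combination.
    dnnl::graph::graph g(dnnl::engine::kind::cpu);
    g.add_op(add);
    g.finalize();
    return 0;
}
```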
12 changes: 7 additions & 5 deletions doc/graph/operations/Divide.md
@@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is

Divide operation supports the following data type combinations.

| Src_0 / Src_1 | Dst |
|:--------------|:-----|
| f32 | f32 |
| bf16 | bf16 |
| f16 | f16 |
| Src_0 | Src_1 | Dst |
|:----------|:----------|:-----|
| f32 | f32 | f32 |
| bf16 | bf16 | bf16 |
| f16 | f16 | f16 |
| f32 | bf16, f16 | f32 |
| bf16, f16 | f32 | f32 |
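
As an illustration only (shapes, IDs, and engine kind are assumed, not from this PR), the same promotion rule applied to Divide, mirroring the SDPA scale step changed later in this PR: an f32 score divided by an f16 scale remains f32.

```cpp
// Sketch only: Divide with an f32 dividend and an f16 divisor.
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    const logical_tensor::dims score_sz = {1, 16, 384, 384};
    const logical_tensor::dims scale_sz = {1, 1, 1, 1};

    // The f16 scale is promoted to f32; the scaled score stays f32.
    auto score = logical_tensor(0, logical_tensor::data_type::f32, score_sz,
            logical_tensor::layout_type::strided);
    auto scale = logical_tensor(1, logical_tensor::data_type::f16, scale_sz,
            logical_tensor::layout_type::strided);
    auto scaled_score = logical_tensor(2, logical_tensor::data_type::f32,
            score_sz, logical_tensor::layout_type::strided);

    auto scale_div = op(3, op::kind::Divide, "scale_div");
    scale_div.add_inputs({score, scale});
    scale_div.add_outputs({scaled_score});

    dnnl::graph::graph g(dnnl::engine::kind::cpu);
    g.add_op(scale_div);
    g.finalize();
    return 0;
}
```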
10 changes: 5 additions & 5 deletions doc/graph/operations/MatMul.md
@@ -61,8 +61,8 @@ constructing an operation.

MatMul operation supports the following data type combinations.

| Src | Weights | Bias | Dst |
|:-----|:--------|:-----|:-----|
| f32 | f32 | f32 | f32 |
| bf16 | bf16 | bf16 | bf16 |
| f16 | f16 | f16 | f16 |
| Src | Weights | Bias | Dst |
|:-----|:--------|:-----|:----------|
| f32 | f32 | f32 | f32 |
| bf16 | bf16 | bf16 | f32, bf16 |
| f16 | f16 | f16 | f32, f16 |
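
A hedged sketch of the new output-type flexibility (IDs, shapes, and engine kind are illustrative assumptions): a MatMul with bf16 inputs that writes an f32 destination, as the first matmul of an SDPA graph does when the intermediate score is kept in f32.

```cpp
// Sketch only: bf16 x bf16 MatMul producing an f32 score tensor.
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    const logical_tensor::dims q_sz = {1, 16, 384, 64};
    const logical_tensor::dims k_sz = {1, 16, 384, 64};
    const logical_tensor::dims score_sz = {1, 16, 384, 384};

    auto query = logical_tensor(0, logical_tensor::data_type::bf16, q_sz,
            logical_tensor::layout_type::strided);
    auto key = logical_tensor(1, logical_tensor::data_type::bf16, k_sz,
            logical_tensor::layout_type::strided);
    auto score = logical_tensor(2, logical_tensor::data_type::f32, score_sz,
            logical_tensor::layout_type::strided);

    // score = query x key^T, accumulated and stored in f32.
    auto bmm1 = op(3, op::kind::MatMul, "bmm1");
    bmm1.set_attr<bool>(op::attr::transpose_b, true);
    bmm1.add_inputs({query, key});
    bmm1.add_outputs({score});

    dnnl::graph::graph g(dnnl::engine::kind::cpu);
    g.add_op(bmm1);
    g.finalize();
    return 0;
}
```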
12 changes: 7 additions & 5 deletions doc/graph/operations/Multiply.md
@@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is

Multiply operation supports the following data type combinations.

| Src_0 / Src_1 | Dst |
|:--------------|:-----|
| f32 | f32 |
| bf16 | bf16 |
| f16 | f16 |
| Src_0 | Src_1 | Dst |
|:----------|:----------|:-----|
| f32 | f32 | f32 |
| bf16 | bf16 | bf16 |
| f16 | f16 | f16 |
| f32 | bf16, f16 | f32 |
| bf16, f16 | f32 | f32 |
10 changes: 5 additions & 5 deletions doc/graph/operations/Softmax.md
@@ -36,8 +36,8 @@ constructing an operation.

SoftMax operation supports the following data type combinations.

| Src | Dst |
|:-----|:-----|
| f32 | f32 |
| bf16 | bf16 |
| f16 | f16 |
| Src | Dst |
|:-----|:----------------|
| f32 | f32, bf16, f16 |
| bf16 | bf16 |
| f16 | f16 |
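
A minimal sketch (IDs, shapes, and engine kind are assumptions, not from this PR) of the new down-conversion path: SoftMax consumes an f32 masked score and emits bf16 probabilities for the second matmul of an SDPA graph.

```cpp
// Sketch only: SoftMax with f32 input and bf16 output.
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    const logical_tensor::dims score_sz = {1, 16, 384, 384};

    auto masked_score = logical_tensor(0, logical_tensor::data_type::f32,
            score_sz, logical_tensor::layout_type::strided);
    auto probs = logical_tensor(1, logical_tensor::data_type::bf16, score_sz,
            logical_tensor::layout_type::strided);

    // Softmax over the last dimension; the output is down-converted to bf16.
    auto softmax = op(2, op::kind::SoftMax, "softmax");
    softmax.set_attr<int64_t>(op::attr::axis, -1);
    softmax.add_inputs({masked_score});
    softmax.add_outputs({probs});

    dnnl::graph::graph g(dnnl::engine::kind::cpu);
    g.add_op(softmax);
    g.finalize();
    return 0;
}
```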
12 changes: 7 additions & 5 deletions doc/graph/operations/Subtract.md
@@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is

Subtract operation supports the following data type combinations.

| Src_0 / Src_1 | Dst |
|:--------------|:-----|
| f32 | f32 |
| bf16 | bf16 |
| f16 | f16 |
| Src_0 | Src_1 | Dst |
|:----------|:----------|:-----|
| f32 | f32 | f32 |
| bf16 | bf16 | bf16 |
| f16 | f16 | f16 |
| f32 | bf16, f16 | f32 |
| bf16, f16 | f32 | f32 |
3 changes: 1 addition & 2 deletions doc/graph/programming_model/low_precision.md
@@ -52,7 +52,6 @@ Graph operations support bf16 and f16 data types.

A TypeCast operation performing down conversion should be inserted clearly to
indicate the use of low numeric precision. oneDNN Graph implementation fully
honors the API-specified numeric precision and only performs the computation
using the API-specified or higher numeric precision.
honors the API-specified numeric precision.
Review comment (Contributor):
Just to make sure we are aligned: this still allows using f32 values to store f16/bf16 data, as long as we respect rounding to f16/bf16 accuracy, right?

Reply (Author):

Yes, in my understanding it's still allowed for backend implementations, so from that perspective it seems I need to keep the original statement. My intention here was to align the implementations, as the original wording sounds like different backends (e.g., DNNL & GC, CPU & GPU) could have different numerical behaviors.


@img{bf16_programming.jpg,Figure 2: Overview of bf16 programming model.,80%,}
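
To make the statement above concrete, here is a minimal sketch, not part of this PR, of inserting an explicit TypeCast so that the f32-to-bf16 down conversion is visible in the graph rather than implicit in a backend; the IDs and shapes are illustrative.

```cpp
// Sketch only: an explicit TypeCast marks the f32 -> bf16 down conversion.
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    const logical_tensor::dims sz = {1, 16, 384, 384};

    auto src_f32 = logical_tensor(0, logical_tensor::data_type::f32, sz,
            logical_tensor::layout_type::strided);
    auto dst_bf16 = logical_tensor(1, logical_tensor::data_type::bf16, sz,
            logical_tensor::layout_type::strided);

    auto cast = op(2, op::kind::TypeCast, "typecast_down");
    cast.add_inputs({src_f32});
    cast.add_outputs({dst_bf16});

    dnnl::graph::graph g(dnnl::engine::kind::cpu);
    g.add_op(cast);
    g.finalize();
    return 0;
}
```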
24 changes: 16 additions & 8 deletions examples/graph/sdpa.cpp
@@ -96,6 +96,9 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt,
// Create dnnl::stream.
dnnl::stream strm(eng);

// Intermediate data type
const memory::data_type dt_inter = memory::data_type::f32;

// Prepare input and output shapes to construct the sdpa graph.
const memory::dims q_sz = {p.mb, p.head_num, p.query_num, p.head_size};
const memory::dims k_sz = {p.mb, p.head_num, p.head_size, p.seq_len};
@@ -110,9 +113,10 @@
// All combined in a single matmul primitive.
auto query_md = memory::desc(q_sz, dt, memory::format_tag::abcd);
auto key_md = memory::desc(k_sz, dt, memory::format_tag::abdc);
auto score_md = memory::desc(score_sz, dt, memory::format_tag::abcd);
auto score_md = memory::desc(score_sz, dt_inter, memory::format_tag::abcd);
auto scale_md = memory::desc(scale_sz, dt, memory::format_tag::abcd);
auto mask_md = memory::desc(mask_sz, dt, memory::format_tag::abcd);
auto probs_md = memory::desc(score_sz, dt, memory::format_tag::abcd);

primitive_attr bmm1_attr;
bmm1_attr.set_scratchpad_mode(scratchpad_mode::user);
@@ -130,7 +134,7 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt,
softmax_attr.set_scratchpad_mode(scratchpad_mode::user);
auto softmax_pd = softmax_forward::primitive_desc(eng,
prop_kind::forward_inference, algorithm::softmax_accurate, score_md,
score_md, /* axis = */ score_md.get_ndims() - 1, softmax_attr);
probs_md, /* axis = */ score_md.get_ndims() - 1, softmax_attr);
auto softmax_prim = softmax_forward(softmax_pd);

// attention_output = attention_probs x value
@@ -139,7 +143,7 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt,
primitive_attr bmm2_attr;
bmm2_attr.set_scratchpad_mode(scratchpad_mode::user);
auto bmm2_pd = matmul::primitive_desc(
eng, score_md, value_md, output_md, bmm2_attr);
eng, probs_md, value_md, output_md, bmm2_attr);
auto bmm2_prim = matmul(bmm2_pd);

// Create memory objects
Expand Down Expand Up @@ -183,6 +187,7 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt,

// allocate intermediate memory
auto m_score = memory(score_md, eng);
auto m_probs = memory(probs_md, eng);
auto m_scratchpad = memory(scratchpad_md, eng);

const auto loop = [&]() {
@@ -197,11 +202,11 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt,
{DNNL_ARG_SCRATCHPAD, m_scratchpad}});

softmax_prim.execute(strm,
{{DNNL_ARG_SRC, m_score}, {DNNL_ARG_DST, m_score},
{{DNNL_ARG_SRC, m_score}, {DNNL_ARG_DST, m_probs},
{DNNL_ARG_SCRATCHPAD, m_scratchpad}});

bmm2_prim.execute(strm,
{{DNNL_ARG_SRC, m_score}, {DNNL_ARG_WEIGHTS, m_value},
{{DNNL_ARG_SRC, m_probs}, {DNNL_ARG_WEIGHTS, m_value},
{DNNL_ARG_DST, m_output},
{DNNL_ARG_SCRATCHPAD, m_scratchpad}});
};
@@ -282,10 +287,13 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
// Incremental IDs used to create logical tensors and operations.
size_t id = 0;

// Intermediate data type
const logical_tensor::data_type dt_inter = logical_tensor::data_type::f32;

// score = query x key.T
auto query = logical_tensor(id++, dt, qv_sz, layout_type::strided);
auto key = logical_tensor(id++, dt, k_sz, layout_type::strided);
auto score = logical_tensor(id++, dt, score_sz, layout_type::strided);
auto score = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
auto bmm1 = op(id++, op::kind::MatMul, "bmm1");
bmm1.set_attr<bool>(op::attr::transpose_b, true);
bmm1.add_inputs({query, key});
@@ -294,15 +302,15 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
// scaled_score = score / scale
auto scale = logical_tensor(id++, dt, scale_sz, layout_type::strided);
auto scaled_score
= logical_tensor(id++, dt, score_sz, layout_type::strided);
= logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
auto scale_div = op(id++, op::kind::Divide, "scale_div");
scale_div.add_inputs({score, scale});
scale_div.add_outputs({scaled_score});

// masked_score = scaled_score + mask
auto mask = logical_tensor(id++, dt, mask_sz, layout_type::strided);
auto masked_score
= logical_tensor(id++, dt, score_sz, layout_type::strided);
= logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
auto mask_add = op(id++, op::kind::Add, "mask_add");
mask_add.add_inputs({scaled_score, mask});
mask_add.add_outputs({masked_score});
9 changes: 6 additions & 3 deletions examples/graph/sdpa_stacked_qkv.cpp
@@ -142,6 +142,9 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
// Incremental IDs used to create logical tensors and operations.
size_t id = 0;

// Intermediate data type
const logical_tensor::data_type dt_inter = logical_tensor::data_type::f32;

// This logical tensor is not part of the graph but is used to generate the
// big chunk of device memory which should be already there in real user
// application or framework.
@@ -152,7 +155,7 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
auto key = logical_tensor(id++, dt, qkv_sz, qkv_strides);
// Though query and key are non-contiguous above, the output score is still
// contiguous.
auto score = logical_tensor(id++, dt, score_sz, layout_type::strided);
auto score = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
auto bmm1 = op(id++, op::kind::MatMul, "bmm1");
bmm1.set_attr<bool>(op::attr::transpose_b, true);
bmm1.add_inputs({query, key});
@@ -161,15 +164,15 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
// scaled_score = score / scale
auto scale = logical_tensor(id++, dt, scale_sz, layout_type::strided);
auto scaled_score
= logical_tensor(id++, dt, score_sz, layout_type::strided);
= logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
auto scale_div = op(id++, op::kind::Divide, "scale_div");
scale_div.add_inputs({score, scale});
scale_div.add_outputs({scaled_score});

// masked_score = scaled_score + mask
auto mask = logical_tensor(id++, dt, mask_sz, layout_type::strided);
auto masked_score
= logical_tensor(id++, dt, score_sz, layout_type::strided);
= logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
auto mask_add = op(id++, op::kind::Add, "mask_add");
mask_add.add_inputs({scaled_score, mask});
mask_add.add_outputs({masked_score});
17 changes: 17 additions & 0 deletions src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
@@ -180,6 +180,7 @@ status_t sdp_primitive_config_t::initial_check(
graph::op_kind::Add, graph::op_kind::Select,
graph::op_kind::SoftMax};
op_ptr mm1 = nullptr, mm2 = nullptr, scale = nullptr;
bool f32_inter = true;
for (const auto &cur_op : sg->get_ops()) {
const auto &op_kind = cur_op->get_kind();
if (op_kind == graph::op_kind::DynamicDequantize
@@ -213,6 +214,10 @@
auto post_op = get_post_op(cur_op);
if (post_op && mm1_post_op_kind.count(post_op->get_kind())) {
mm1 = cur_op;
const auto &lt_score
= mm1->get_output_value(0)->get_logical_tensor();
f32_inter = f32_inter
&& (ltw(lt_score).data_type() == data_type::f32);
// Not support select between mm1 and scale(optional)
// GPT-J:[mm1] --> [select] --> [scale]* --> [mask]* --> ...
VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
@@ -224,11 +229,20 @@
// Scale exists, update post_op and traverse to next op
scale = post_op;
post_op = get_post_op(post_op);
const auto &lt_ss
= scale->get_output_value(0)->get_logical_tensor();
f32_inter = f32_inter
&& (ltw(lt_ss).data_type() == data_type::f32);
}
// mask
if (post_op) {
if (post_op->get_kind() == graph::op_kind::Add) {
// Mask exists, update post_op and traverse to next op
const auto mask = post_op;
const auto &lt_ms
= mask->get_output_value(0)->get_logical_tensor();
f32_inter = f32_inter
&& (ltw(lt_ms).data_type() == data_type::f32);
post_op = get_post_op(post_op);
}
// Not support select after scale(optional) and mask(optional)
@@ -245,6 +259,9 @@
}
}

VCHECK_SDP_PRIMITIVE(f32_inter, status::invalid_graph,
"only supports f32 intermediates.");

auto find_graph_inport = [&inputs](const std::shared_ptr<value_t> &val) {
auto tmp_val = val;
while (tmp_val->has_producer()) {
4 changes: 2 additions & 2 deletions src/graph/backend/dnnl/patterns/sdp.cpp
@@ -142,8 +142,8 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_fusion_gpu)
.set_attr<FCreatePattern>("FCreatePattern",
[](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
auto matmul_qk = pgraph->append_op(graph::op_kind::MatMul);
auto optional_scale_and_mask = optional_scale_and_masks(
pgraph, matmul_qk, /*check_xf16*/ true);
auto optional_scale_and_mask
= optional_scale_and_masks(pgraph, matmul_qk);
auto softmax = pgraph->append_op(graph::op_kind::SoftMax,
{in_edge(0, optional_scale_and_mask, 0)});
auto matmul_v = pgraph->append_op(
43 changes: 27 additions & 16 deletions src/graph/interface/op_def.hpp
@@ -54,13 +54,17 @@ DNNL_GRAPH_OP_SCHEMA(Add, 1,
.set_num_inputs(2)
.set_num_outputs(1)
.set_commutative_inputs()
.set_input(0, "src_0", "T")
.set_input(1, "src_1", "T")
.set_output(0, "dst", "T")
.set_input(0, "src_0", "T1")
.set_input(1, "src_1", "T2")
.set_output(0, "dst", "T3")
.set_attr(op_attr::auto_broadcast, false, attribute_kind::s,
"numpy", {"none", "numpy"})
.set_type_constraints(
"T", {data_type::f32, data_type::bf16, data_type::f16})
"T1", {data_type::f32, data_type::bf16, data_type::f16})
.set_type_constraints(
"T2", {data_type::f32, data_type::bf16, data_type::f16})
.set_type_constraints(
"T3", {data_type::f32, data_type::bf16, data_type::f16})
.set_shape_inference_function(
infer_elemwise_arithmetic_output_shape))

@@ -684,10 +688,13 @@ DNNL_GRAPH_OP_SCHEMA(MatMul, 1,
.set_input(0, "src", "T")
.set_input(1, "weights", "T")
.set_input(2, "bias", "T")
.set_output(0, "dst", "T")
.set_output(0, "dst", "T1")
.set_type_constraints(
"T", {data_type::f32, data_type::bf16, data_type::f16})
.set_type_constraints(
"T1", {data_type::f32, data_type::bf16, data_type::f16})
.set_shape_inference_function(infer_matmul_output_shape)
.set_op_def_constraint_function(check_matmul_dtype)
.SET_MATMUL_COMMON_ATTRS)

DNNL_GRAPH_OP_SCHEMA(Maximum, 1,
@@ -788,9 +795,6 @@ DNNL_GRAPH_OP_SCHEMA(MishBackward, 1,
"T", {data_type::f32, data_type::bf16, data_type::f16})
.set_shape_inference_function(infer_identity_output_shape))

// TODO(Yixin): for Multiply. input and output needs to have the same dtypes
// But in current pytorch bridge's type promotion system, there's no
// such constraints. So this feature is postponed.
DNNL_GRAPH_OP_SCHEMA(Multiply, 1,
op_schema_t()
.set_num_inputs(2)
@@ -1029,12 +1033,15 @@ DNNL_GRAPH_OP_SCHEMA(SoftMax, 1,
op_schema_t()
.set_num_inputs(1)
.set_num_outputs(1)
.set_input(0, "src", "T")
.set_output(0, "dst", "T")
.set_input(0, "src", "T1")
.set_output(0, "dst", "T2")
.set_attr(op_attr::axis, false, attribute_kind::i, (int64_t)1)
.set_type_constraints(
"T", {data_type::f32, data_type::bf16, data_type::f16})
.set_shape_inference_function(infer_identity_output_shape))
"T1", {data_type::f32, data_type::bf16, data_type::f16})
.set_type_constraints(
"T2", {data_type::f32, data_type::bf16, data_type::f16})
.set_shape_inference_function(infer_identity_output_shape)
.set_op_def_constraint_function(check_softmax_dtype))

DNNL_GRAPH_OP_SCHEMA(SoftMaxBackward, 1,
op_schema_t()
@@ -1121,13 +1128,17 @@ DNNL_GRAPH_OP_SCHEMA(Subtract, 1,
op_schema_t()
.set_num_inputs(2)
.set_num_outputs(1)
.set_input(0, "src_0", "T")
.set_input(1, "src_1", "T")
.set_output(0, "dst", "T")
.set_input(0, "src_0", "T1")
.set_input(1, "src_1", "T2")
.set_output(0, "dst", "T3")
.set_attr(op_attr::auto_broadcast, false, attribute_kind::s,
"numpy", {"none", "numpy"})
.set_type_constraints(
"T", {data_type::f32, data_type::bf16, data_type::f16})
"T1", {data_type::f32, data_type::bf16, data_type::f16})
.set_type_constraints(
"T2", {data_type::f32, data_type::bf16, data_type::f16})
.set_type_constraints(
"T3", {data_type::f32, data_type::bf16, data_type::f16})
Review comment (Contributor):
This requires some documentation about type promotion, as users might wonder what happens, for example, with f16 <- f16 + bf16.

Reply (Author):
In fact, we don't allow f16 + bf16. It's mentioned in the "supported data types" section in the op document. When src0 and src1 have different data types, one of them should be f32 and the other one (f16 or bf16) will be promoted to f32 for calculation.

.set_shape_inference_function(
infer_elemwise_arithmetic_output_shape))
