diff --git a/doc/graph/fusion_patterns/sdpa.md b/doc/graph/fusion_patterns/sdpa.md index 8be979adc6b..299a4f077b6 100644 --- a/doc/graph/fusion_patterns/sdpa.md +++ b/doc/graph/fusion_patterns/sdpa.md @@ -128,9 +128,9 @@ platforms follow the general description in @ref dev_guide_data_types. 4. GPU - Optimized implementation is available for 4D Q/K/V tensors with shape defined as (N, H, S, D). - - Optimized implementation is available for floating-point SDPA with `f16` - data type and `D <= 256` on Intel Graphics Products with Intel(R) Xe Matrix - Extensions (Intel(R) XMX) support. + - Optimized implementation is available for `f16` or `bf16` SDPA with `f32` + intermediate data type and `D <= 256` on Intel Graphics Products with + Intel(R) Xe Matrix Extensions (Intel(R) XMX) support. ## Example diff --git a/doc/graph/operations/Add.md b/doc/graph/operations/Add.md index 6f5342b382c..5fef1ab7d7e 100644 --- a/doc/graph/operations/Add.md +++ b/doc/graph/operations/Add.md @@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is Add operation supports the following data type combinations. -| Src_0 / Src_1 | Dst | -|:--------------|:-----| -| f32 | f32 | -| bf16 | bf16 | -| f16 | f16 | +| Src_0 | Src_1 | Dst | +|:----------|:----------|:-----| +| f32 | f32 | f32 | +| bf16 | bf16 | bf16 | +| f16 | f16 | f16 | +| f32 | bf16, f16 | f32 | +| bf16, f16 | f32 | f32 | diff --git a/doc/graph/operations/Divide.md b/doc/graph/operations/Divide.md index 8c4ab535544..11689c9b7eb 100644 --- a/doc/graph/operations/Divide.md +++ b/doc/graph/operations/Divide.md @@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is Divide operation supports the following data type combinations. -| Src_0 / Src_1 | Dst | -|:--------------|:-----| -| f32 | f32 | -| bf16 | bf16 | -| f16 | f16 | +| Src_0 | Src_1 | Dst | +|:----------|:----------|:-----| +| f32 | f32 | f32 | +| bf16 | bf16 | bf16 | +| f16 | f16 | f16 | +| f32 | bf16, f16 | f32 | +| bf16, f16 | f32 | f32 | diff --git a/doc/graph/operations/MatMul.md b/doc/graph/operations/MatMul.md index d2b4cc89b0f..7879393969a 100644 --- a/doc/graph/operations/MatMul.md +++ b/doc/graph/operations/MatMul.md @@ -61,8 +61,8 @@ constructing an operation. MatMul operation supports the following data type combinations. -| Src | Weights | Bias | Dst | -|:-----|:--------|:-----|:-----| -| f32 | f32 | f32 | f32 | -| bf16 | bf16 | bf16 | bf16 | -| f16 | f16 | f16 | f16 | +| Src | Weights | Bias | Dst | +|:-----|:--------|:-----|:----------| +| f32 | f32 | f32 | f32 | +| bf16 | bf16 | bf16 | f32, bf16 | +| f16 | f16 | f16 | f32, f16 | diff --git a/doc/graph/operations/Multiply.md b/doc/graph/operations/Multiply.md index 625bfea10d2..24e09881e10 100644 --- a/doc/graph/operations/Multiply.md +++ b/doc/graph/operations/Multiply.md @@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is Multiply operation supports the following data type combinations. -| Src_0 / Src_1 | Dst | -|:--------------|:-----| -| f32 | f32 | -| bf16 | bf16 | -| f16 | f16 | +| Src_0 | Src_1 | Dst | +|:----------|:----------|:-----| +| f32 | f32 | f32 | +| bf16 | bf16 | bf16 | +| f16 | f16 | f16 | +| f32 | bf16, f16 | f32 | +| bf16, f16 | f32 | f32 | diff --git a/doc/graph/operations/Softmax.md b/doc/graph/operations/Softmax.md index 6655eb218d6..467634b1d05 100644 --- a/doc/graph/operations/Softmax.md +++ b/doc/graph/operations/Softmax.md @@ -36,8 +36,8 @@ constructing an operation. 
SoftMax operation supports the following data type combinations. -| Src | Dst | -|:-----|:-----| -| f32 | f32 | -| bf16 | bf16 | -| f16 | f16 | +| Src | Dst | +|:-----|:----------------| +| f32 | f32, bf16, f16 | +| bf16 | bf16 | +| f16 | f16 | diff --git a/doc/graph/operations/Subtract.md b/doc/graph/operations/Subtract.md index 28138271a5a..bca45816cc8 100644 --- a/doc/graph/operations/Subtract.md +++ b/doc/graph/operations/Subtract.md @@ -44,8 +44,10 @@ different and auto-broadcasting is allowed if `auto_broadcast` attributes is Subtract operation supports the following data type combinations. -| Src_0 / Src_1 | Dst | -|:--------------|:-----| -| f32 | f32 | -| bf16 | bf16 | -| f16 | f16 | +| Src_0 | Src_1 | Dst | +|:----------|:----------|:-----| +| f32 | f32 | f32 | +| bf16 | bf16 | bf16 | +| f16 | f16 | f16 | +| f32 | bf16, f16 | f32 | +| bf16, f16 | f32 | f32 | diff --git a/doc/graph/programming_model/low_precision.md b/doc/graph/programming_model/low_precision.md index 35118771b96..83b7744ba25 100644 --- a/doc/graph/programming_model/low_precision.md +++ b/doc/graph/programming_model/low_precision.md @@ -52,7 +52,6 @@ Graph operations support bf16 and f16 data types. A TypeCast operation performing down conversion should be inserted clearly to indicate the use of low numeric precision. oneDNN Graph implementation fully -honors the API-specified numeric precision and only performs the computation -using the API-specified or higher numeric precision. +honors the API-specified numeric precision. @img{bf16_programming.jpg,Figure 2: Overview of bf16 programming model.,80%,} diff --git a/examples/graph/sdpa.cpp b/examples/graph/sdpa.cpp index f395f3172f9..99f3dc2c64b 100644 --- a/examples/graph/sdpa.cpp +++ b/examples/graph/sdpa.cpp @@ -96,6 +96,9 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt, // Create dnnl::stream. dnnl::stream strm(eng); + // Intermediate data type + const memory::data_type dt_inter = memory::data_type::f32; + // Prepare input and output shapes to construct the sdpa graph. const memory::dims q_sz = {p.mb, p.head_num, p.query_num, p.head_size}; const memory::dims k_sz = {p.mb, p.head_num, p.head_size, p.seq_len}; @@ -110,9 +113,10 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt, // All combined in a single matmul primitive. 
auto query_md = memory::desc(q_sz, dt, memory::format_tag::abcd); auto key_md = memory::desc(k_sz, dt, memory::format_tag::abdc); - auto score_md = memory::desc(score_sz, dt, memory::format_tag::abcd); + auto score_md = memory::desc(score_sz, dt_inter, memory::format_tag::abcd); auto scale_md = memory::desc(scale_sz, dt, memory::format_tag::abcd); auto mask_md = memory::desc(mask_sz, dt, memory::format_tag::abcd); + auto probs_md = memory::desc(score_sz, dt, memory::format_tag::abcd); primitive_attr bmm1_attr; bmm1_attr.set_scratchpad_mode(scratchpad_mode::user); @@ -130,7 +134,7 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt, softmax_attr.set_scratchpad_mode(scratchpad_mode::user); auto softmax_pd = softmax_forward::primitive_desc(eng, prop_kind::forward_inference, algorithm::softmax_accurate, score_md, - score_md, /* axis = */ score_md.get_ndims() - 1, softmax_attr); + probs_md, /* axis = */ score_md.get_ndims() - 1, softmax_attr); auto softmax_prim = softmax_forward(softmax_pd); // attention_output = attention_probs x value @@ -139,7 +143,7 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt, primitive_attr bmm2_attr; bmm2_attr.set_scratchpad_mode(scratchpad_mode::user); auto bmm2_pd = matmul::primitive_desc( - eng, score_md, value_md, output_md, bmm2_attr); + eng, probs_md, value_md, output_md, bmm2_attr); auto bmm2_prim = matmul(bmm2_pd); // Create memory objects @@ -183,6 +187,7 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt, // allocate intermediate memory auto m_score = memory(score_md, eng); + auto m_probs = memory(probs_md, eng); auto m_scratchpad = memory(scratchpad_md, eng); const auto loop = [&]() { @@ -197,11 +202,11 @@ void bench_sdpa_primitives(engine::kind ekind, memory::data_type dt, {DNNL_ARG_SCRATCHPAD, m_scratchpad}}); softmax_prim.execute(strm, - {{DNNL_ARG_SRC, m_score}, {DNNL_ARG_DST, m_score}, + {{DNNL_ARG_SRC, m_score}, {DNNL_ARG_DST, m_probs}, {DNNL_ARG_SCRATCHPAD, m_scratchpad}}); bmm2_prim.execute(strm, - {{DNNL_ARG_SRC, m_score}, {DNNL_ARG_WEIGHTS, m_value}, + {{DNNL_ARG_SRC, m_probs}, {DNNL_ARG_WEIGHTS, m_value}, {DNNL_ARG_DST, m_output}, {DNNL_ARG_SCRATCHPAD, m_scratchpad}}); }; @@ -282,10 +287,13 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt, // Incremental IDs used to create logical tensors and operations. 
    size_t id = 0;
 
+    // Intermediate data type
+    const logical_tensor::data_type dt_inter = logical_tensor::data_type::f32;
+
     // score = query x key.T
     auto query = logical_tensor(id++, dt, qv_sz, layout_type::strided);
     auto key = logical_tensor(id++, dt, k_sz, layout_type::strided);
-    auto score = logical_tensor(id++, dt, score_sz, layout_type::strided);
+    auto score = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
     auto bmm1 = op(id++, op::kind::MatMul, "bmm1");
     bmm1.set_attr<bool>(op::attr::transpose_b, true);
     bmm1.add_inputs({query, key});
@@ -294,7 +302,7 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
     // scaled_score = score / scale
     auto scale = logical_tensor(id++, dt, scale_sz, layout_type::strided);
     auto scaled_score
-            = logical_tensor(id++, dt, score_sz, layout_type::strided);
+            = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
     auto scale_div = op(id++, op::kind::Divide, "scale_div");
     scale_div.add_inputs({score, scale});
     scale_div.add_outputs({scaled_score});
@@ -302,7 +310,7 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
     // masked_score = scaled_score + mask
     auto mask = logical_tensor(id++, dt, mask_sz, layout_type::strided);
     auto masked_score
-            = logical_tensor(id++, dt, score_sz, layout_type::strided);
+            = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
     auto mask_add = op(id++, op::kind::Add, "mask_add");
     mask_add.add_inputs({scaled_score, mask});
     mask_add.add_outputs({masked_score});
diff --git a/examples/graph/sdpa_stacked_qkv.cpp b/examples/graph/sdpa_stacked_qkv.cpp
index 16bdffe3f52..29920224192 100644
--- a/examples/graph/sdpa_stacked_qkv.cpp
+++ b/examples/graph/sdpa_stacked_qkv.cpp
@@ -142,6 +142,9 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
     // Incremental IDs used to create logical tensors and operations.
     size_t id = 0;
 
+    // Intermediate data type
+    const logical_tensor::data_type dt_inter = logical_tensor::data_type::f32;
+
     // This logical tensor is not part of the graph but is used to generate the
     // big chunk of device memory which should be already there in real user
     // application or framework.
@@ -152,7 +155,7 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
     auto key = logical_tensor(id++, dt, qkv_sz, qkv_strides);
     // Though query and key are non-contiguous above, the output score is still
     // contiguous.
-    auto score = logical_tensor(id++, dt, score_sz, layout_type::strided);
+    auto score = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
     auto bmm1 = op(id++, op::kind::MatMul, "bmm1");
     bmm1.set_attr<bool>(op::attr::transpose_b, true);
     bmm1.add_inputs({query, key});
@@ -161,7 +164,7 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
     // scaled_score = score / scale
     auto scale = logical_tensor(id++, dt, scale_sz, layout_type::strided);
     auto scaled_score
-            = logical_tensor(id++, dt, score_sz, layout_type::strided);
+            = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
     auto scale_div = op(id++, op::kind::Divide, "scale_div");
     scale_div.add_inputs({score, scale});
     scale_div.add_outputs({scaled_score});
@@ -169,7 +172,7 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
     // masked_score = scaled_score + mask
     auto mask = logical_tensor(id++, dt, mask_sz, layout_type::strided);
     auto masked_score
-            = logical_tensor(id++, dt, score_sz, layout_type::strided);
+            = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
     auto mask_add = op(id++, op::kind::Add, "mask_add");
     mask_add.add_inputs({scaled_score, mask});
     mask_add.add_outputs({masked_score});
diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
index b05ddf4004d..25a1a3388d2 100644
--- a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
+++ b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
@@ -180,6 +180,7 @@ status_t sdp_primitive_config_t::initial_check(
             graph::op_kind::Add, graph::op_kind::Select, graph::op_kind::SoftMax};
     op_ptr mm1 = nullptr, mm2 = nullptr, scale = nullptr;
+    bool f32_inter = true;
     for (const auto &cur_op : sg->get_ops()) {
         const auto &op_kind = cur_op->get_kind();
         if (op_kind == graph::op_kind::DynamicDequantize
@@ -213,6 +214,10 @@ status_t sdp_primitive_config_t::initial_check(
         auto post_op = get_post_op(cur_op);
         if (post_op && mm1_post_op_kind.count(post_op->get_kind())) {
             mm1 = cur_op;
+            const auto &lt_score
+                    = mm1->get_output_value(0)->get_logical_tensor();
+            f32_inter = f32_inter
+                    && (ltw(lt_score).data_type() == data_type::f32);
             // Not support select between mm1 and scale(optional)
             // GPT-J:[mm1] --> [select] --> [scale]* --> [mask]* --> ...
             VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
@@ -224,11 +229,20 @@ status_t sdp_primitive_config_t::initial_check(
                 // Scale exists, update post_op and traverse to next op
                 scale = post_op;
                 post_op = get_post_op(post_op);
+                const auto &lt_ss
+                        = scale->get_output_value(0)->get_logical_tensor();
+                f32_inter = f32_inter
+                        && (ltw(lt_ss).data_type() == data_type::f32);
             }
             // mask
             if (post_op) {
                 if (post_op->get_kind() == graph::op_kind::Add) {
                     // Mask exists, update post_op and traverse to next op
+                    const auto mask = post_op;
+                    const auto &lt_ms
+                            = mask->get_output_value(0)->get_logical_tensor();
+                    f32_inter = f32_inter
+                            && (ltw(lt_ms).data_type() == data_type::f32);
                     post_op = get_post_op(post_op);
                 }
                 // Not support select after scale(optional) and mask(optional)
@@ -245,6 +259,9 @@ status_t sdp_primitive_config_t::initial_check(
         }
     }
 
+    VCHECK_SDP_PRIMITIVE(f32_inter, status::invalid_graph,
+            "only supports f32 intermediates.");
+
     auto find_graph_inport = [&inputs](const std::shared_ptr<value_t> &val) {
         auto tmp_val = val;
         while (tmp_val->has_producer()) {
diff --git a/src/graph/backend/dnnl/patterns/sdp.cpp b/src/graph/backend/dnnl/patterns/sdp.cpp
index bf176870c0f..928f68ce8ec 100644
--- a/src/graph/backend/dnnl/patterns/sdp.cpp
+++ b/src/graph/backend/dnnl/patterns/sdp.cpp
@@ -142,8 +142,8 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_fusion_gpu)
         .set_attr<FCreatePattern>("FCreatePattern",
                 [](const std::shared_ptr<pb_graph_t> &pgraph) -> void {
                     auto matmul_qk = pgraph->append_op(graph::op_kind::MatMul);
-                    auto optional_scale_and_mask = optional_scale_and_masks(
-                            pgraph, matmul_qk, /*check_xf16*/ true);
+                    auto optional_scale_and_mask
+                            = optional_scale_and_masks(pgraph, matmul_qk);
                     auto softmax = pgraph->append_op(graph::op_kind::SoftMax,
                             {in_edge(0, optional_scale_and_mask, 0)});
                     auto matmul_v = pgraph->append_op(
diff --git a/src/graph/interface/op_def.hpp b/src/graph/interface/op_def.hpp
index 7d3c03aabae..760197239bb 100644
--- a/src/graph/interface/op_def.hpp
+++ b/src/graph/interface/op_def.hpp
@@ -54,13 +54,17 @@ DNNL_GRAPH_OP_SCHEMA(Add, 1,
                 .set_num_inputs(2)
                 .set_num_outputs(1)
                 .set_commutative_inputs()
-                .set_input(0, "src_0", "T")
-                .set_input(1, "src_1", "T")
-                .set_output(0, "dst", "T")
+                .set_input(0, "src_0", "T1")
+                .set_input(1, "src_1", "T2")
+                .set_output(0, "dst", "T3")
                 .set_attr(op_attr::auto_broadcast, false, attribute_kind::s,
                         "numpy", {"none", "numpy"})
                 .set_type_constraints(
-                        "T", {data_type::f32, data_type::bf16, data_type::f16})
+                        "T1", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T2", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T3", {data_type::f32, data_type::bf16, data_type::f16})
                 .set_shape_inference_function(
                         infer_elemwise_arithmetic_output_shape))
 
@@ -684,10 +688,13 @@ DNNL_GRAPH_OP_SCHEMA(MatMul, 1,
                 .set_input(0, "src", "T")
                 .set_input(1, "weights", "T")
                 .set_input(2, "bias", "T")
-                .set_output(0, "dst", "T")
+                .set_output(0, "dst", "T1")
                 .set_type_constraints(
                         "T", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T1", {data_type::f32, data_type::bf16, data_type::f16})
                 .set_shape_inference_function(infer_matmul_output_shape)
+                .set_op_def_constraint_function(check_matmul_dtype)
                 .SET_MATMUL_COMMON_ATTRS)
 
 DNNL_GRAPH_OP_SCHEMA(Maximum, 1,
@@ -788,9 +795,6 @@ DNNL_GRAPH_OP_SCHEMA(MishBackward, 1,
                         "T", {data_type::f32, data_type::bf16, data_type::f16})
                 .set_shape_inference_function(infer_identity_output_shape))
 
-// TODO(Yixin): for Multiply. input and output needs to have the same dtypes
-// But in current pytorch bridge's type promotion system, there's no
-// such constraints. So this feature is postponed.
 DNNL_GRAPH_OP_SCHEMA(Multiply, 1,
         op_schema_t()
                 .set_num_inputs(2)
@@ -1029,12 +1033,15 @@ DNNL_GRAPH_OP_SCHEMA(SoftMax, 1,
         op_schema_t()
                 .set_num_inputs(1)
                 .set_num_outputs(1)
-                .set_input(0, "src", "T")
-                .set_output(0, "dst", "T")
+                .set_input(0, "src", "T1")
+                .set_output(0, "dst", "T2")
                 .set_attr(op_attr::axis, false, attribute_kind::i, (int64_t)1)
                 .set_type_constraints(
-                        "T", {data_type::f32, data_type::bf16, data_type::f16})
-                .set_shape_inference_function(infer_identity_output_shape))
+                        "T1", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T2", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_shape_inference_function(infer_identity_output_shape)
+                .set_op_def_constraint_function(check_softmax_dtype))
 
 DNNL_GRAPH_OP_SCHEMA(SoftMaxBackward, 1,
         op_schema_t()
@@ -1121,13 +1128,17 @@ DNNL_GRAPH_OP_SCHEMA(Subtract, 1,
         op_schema_t()
                 .set_num_inputs(2)
                 .set_num_outputs(1)
-                .set_input(0, "src_0", "T")
-                .set_input(1, "src_1", "T")
-                .set_output(0, "dst", "T")
+                .set_input(0, "src_0", "T1")
+                .set_input(1, "src_1", "T2")
+                .set_output(0, "dst", "T3")
                 .set_attr(op_attr::auto_broadcast, false, attribute_kind::s,
                         "numpy", {"none", "numpy"})
                 .set_type_constraints(
-                        "T", {data_type::f32, data_type::bf16, data_type::f16})
+                        "T1", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T2", {data_type::f32, data_type::bf16, data_type::f16})
+                .set_type_constraints(
+                        "T3", {data_type::f32, data_type::bf16, data_type::f16})
                 .set_shape_inference_function(
                         infer_elemwise_arithmetic_output_shape))
 
diff --git a/src/graph/interface/op_def_constraint.cpp b/src/graph/interface/op_def_constraint.cpp
index d2f22072d3b..0fa1af04405 100644
--- a/src/graph/interface/op_def_constraint.cpp
+++ b/src/graph/interface/op_def_constraint.cpp
@@ -85,6 +85,45 @@ bool check_bn_data_type(const op_t *n) {
     return true;
 }
 
+// MatMul requires src and weights (wei) to share a data type. When src/wei is
+// xf16, dst can be f32 or the same xf16 type. This check can be relaxed to
+// allow f32f32xf16 (f32 src/wei with xf16 dst) if such a request comes up.
+bool check_matmul_dtype(const op_t *mm) {
+    const auto &inputs = mm->get_input_values();
+    const auto &outputs = mm->get_output_values();
+
+    const logical_tensor_t &src = inputs[0]->get_logical_tensor();
+    const logical_tensor_t &dst = outputs[0]->get_logical_tensor();
+    if (src.data_type != dst.data_type) {
+        if (dst.data_type != data_type::f32) {
+            VCHECK_SHAPE_INFER(false, "%s, %s src + %s dst is not supported",
+                    op_t::kind2str(mm->get_kind()).c_str(),
+                    dnnl_dt2str(src.data_type), dnnl_dt2str(dst.data_type));
+        }
+    }
+
+    return true;
+}
+
+// For SoftMax, dst can be xf16 when src is f32. Otherwise, src and dst must
+// have the same data type.
+bool check_softmax_dtype(const op_t *n) {
+    const auto &inputs = n->get_input_values();
+    const auto &outputs = n->get_output_values();
+
+    const logical_tensor_t &src = inputs[0]->get_logical_tensor();
+    const logical_tensor_t &dst = outputs[0]->get_logical_tensor();
+    if (src.data_type != dst.data_type) {
+        if (src.data_type != data_type::f32) {
+            VCHECK_SHAPE_INFER(false, "%s, %s src + %s dst is not supported",
+                    op_t::kind2str(n->get_kind()).c_str(),
+                    dnnl_dt2str(src.data_type), dnnl_dt2str(dst.data_type));
+        }
+    }
+
+    return true;
+}
+
 // check function for data_type of LayerNorm and GroupNorm.
// only when data is bf16, gamma/beta/mean/var can be bf16. // If data is bf16, gamma/beta/mean/var can be f32 or bf16. diff --git a/src/graph/interface/op_def_constraint.hpp b/src/graph/interface/op_def_constraint.hpp index 3d519383fce..d9e19457935 100644 --- a/src/graph/interface/op_def_constraint.hpp +++ b/src/graph/interface/op_def_constraint.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2022-2024 Intel Corporation +* Copyright 2022-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,10 @@ bool check_pads(const op_t *n); bool check_bn_data_type(const op_t *n); +bool check_matmul_dtype(const op_t *n); + +bool check_softmax_dtype(const op_t *n); + bool check_ln_gn_data_type(const op_t *n); bool check_typecast_data_type(const op_t *n); diff --git a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all index c4d0e7da803..b8562d7bc47 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all +++ b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all @@ -15,6 +15,13 @@ --reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json --reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json --reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json +# f16 inputs + f32 intermediates + f16 outputs +--reset --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json +--reset --dt=1:f16+2:f16+3:f16+4:f16+6:f16+104:f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json + +# bf16 inputs + f32 intermediates + bf16 outputs +--reset --dt=1:bf16+2:bf16+3:bf16+4:bf16+5:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json +--reset --dt=1:bf16+2:bf16+3:bf16+4:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json # int8 graphs --reset --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json diff --git a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci index 39780ff3a8a..84daddaeb4c 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci +++ b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci @@ -13,6 +13,11 @@ --reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json --reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json --reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json +# f16 inputs + f32 intermediates + f16 outputs +--reset --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json +# bf16 inputs + f32 intermediates + bf16 outputs +--reset --dt=1:bf16+2:bf16+3:bf16+4:bf16+5:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json + # int8 graphs --reset --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json b/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json index b16217b31ca..6d73a3cd6fc 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json +++ b/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json @@ -1,17 +1,17 @@ { - 
"version": "3.7.0", + "version": "3.8.0", "engine_kind": "cpu", "fpmath_mode": "strict", "fpmath_mode_apply_to_int": "false", "input_ports": [ - 0, 1, - 3, - 8, - 11 + 2, + 4, + 5, + 3 ], "output_ports": [ - 12 + 6 ], "graph": [ { @@ -30,7 +30,7 @@ }, "inputs": [ { - "id": 0, + "id": 1, "dtype": "f32", "shape": [ 1, @@ -48,7 +48,7 @@ "property_type": "undef" }, { - "id": 1, + "id": 2, "dtype": "f32", "shape": [ 1, @@ -68,7 +68,7 @@ ], "outputs": [ { - "id": 2, + "id": 101, "dtype": "f32", "shape": [ 1, @@ -99,7 +99,7 @@ }, "inputs": [ { - "id": 2, + "id": 101, "dtype": "f32", "shape": [ 1, @@ -117,7 +117,7 @@ "property_type": "undef" }, { - "id": 3, + "id": 4, "dtype": "f32", "shape": [ 1 @@ -131,7 +131,7 @@ ], "outputs": [ { - "id": 4, + "id": 102, "dtype": "f32", "shape": [ 1, @@ -151,7 +151,7 @@ ] }, { - "id": 2, + "id": 40, "name": "genindex_row", "kind": "GenIndex", "attrs": { @@ -162,7 +162,7 @@ }, "inputs": [ { - "id": 4, + "id": 102, "dtype": "f32", "shape": [ 1, @@ -182,7 +182,7 @@ ], "outputs": [ { - "id": 5, + "id": 1021, "dtype": "s32", "shape": [ 1, @@ -202,7 +202,7 @@ ] }, { - "id": 3, + "id": 41, "name": "genindex_col", "kind": "GenIndex", "attrs": { @@ -213,7 +213,7 @@ }, "inputs": [ { - "id": 4, + "id": 102, "dtype": "f32", "shape": [ 1, @@ -233,7 +233,7 @@ ], "outputs": [ { - "id": 6, + "id": 1022, "dtype": "s32", "shape": [ 1, @@ -253,7 +253,7 @@ ] }, { - "id": 4, + "id": 42, "name": "mask_greater_equal", "kind": "GreaterEqual", "attrs": { @@ -264,7 +264,7 @@ }, "inputs": [ { - "id": 5, + "id": 1021, "dtype": "s32", "shape": [ 1, @@ -282,7 +282,7 @@ "property_type": "undef" }, { - "id": 6, + "id": 1022, "dtype": "s32", "shape": [ 1, @@ -302,7 +302,7 @@ ], "outputs": [ { - "id": 7, + "id": 1023, "dtype": "boolean", "shape": [ 1, @@ -322,7 +322,7 @@ ] }, { - "id": 5, + "id": 2, "name": "Select", "kind": "Select", "attrs": { @@ -333,7 +333,7 @@ }, "inputs": [ { - "id": 7, + "id": 1023, "dtype": "boolean", "shape": [ 1, @@ -351,7 +351,7 @@ "property_type": "undef" }, { - "id": 4, + "id": 102, "dtype": "f32", "shape": [ 1, @@ -369,7 +369,7 @@ "property_type": "undef" }, { - "id": 8, + "id": 5, "dtype": "f32", "shape": [ 1 @@ -383,7 +383,7 @@ ], "outputs": [ { - "id": 9, + "id": 103, "dtype": "f32", "shape": [ 1, @@ -403,7 +403,7 @@ ] }, { - "id": 6, + "id": 3, "name": "softmax", "kind": "SoftMax", "attrs": { @@ -414,7 +414,7 @@ }, "inputs": [ { - "id": 9, + "id": 103, "dtype": "f32", "shape": [ 1, @@ -434,7 +434,7 @@ ], "outputs": [ { - "id": 10, + "id": 104, "dtype": "f32", "shape": [ 1, @@ -454,7 +454,7 @@ ] }, { - "id": 7, + "id": 4, "name": "matmul_v", "kind": "MatMul", "attrs": { @@ -469,7 +469,7 @@ }, "inputs": [ { - "id": 10, + "id": 104, "dtype": "f32", "shape": [ 1, @@ -487,7 +487,7 @@ "property_type": "undef" }, { - "id": 11, + "id": 3, "dtype": "f32", "shape": [ 1, @@ -507,7 +507,7 @@ ], "outputs": [ { - "id": 12, + "id": 6, "dtype": "f32", "shape": [ 1, @@ -527,4 +527,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-simplified-f16-f32.json b/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-simplified-f16-f32.json new file mode 100644 index 00000000000..f8d6014e0da --- /dev/null +++ b/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-simplified-f16-f32.json @@ -0,0 +1,347 @@ +{ + "version": "3.8.0", + "engine_kind": "cpu", + "fpmath_mode": "strict", + "fpmath_mode_apply_to_int": "false", + "input_ports": [ + 1, + 2, + 4, + 5, + 3 + ], + "output_ports": [ + 6 + ], + "graph": [ + { + 
"id": 0, + "name": "matmul_qk", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "transpose_b": { + "type": "bool", + "value": 1 + } + }, + "inputs": [ + { + "id": 1, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 2, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 101, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 1, + "name": "scale_div", + "kind": "Divide", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 101, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 4, + "dtype": "f16", + "shape": [ + 1 + ], + "stride": [ + 1 + ], + "layout_type": "strided", + "property_type": "constant" + } + ], + "outputs": [ + { + "id": 102, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 2, + "name": "mask_add", + "kind": "Add", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 102, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 5, + "dtype": "f16", + "shape": [ + 1, + 1, + 384, + 384 + ], + "stride": [ + 147456, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 103, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 3, + "name": "softmax", + "kind": "SoftMax", + "attrs": { + "axis": { + "type": "s64", + "value": -1 + } + }, + "inputs": [ + { + "id": 103, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 104, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 4, + "name": "matmul_v", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "transpose_b": { + "type": "bool", + "value": 0 + } + }, + "inputs": [ + { + "id": 104, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 3, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 6, + "dtype": "f16", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + } + ] +}