From a978f755e1e9c25108bcd9fe70b1e411a75046d1 Mon Sep 17 00:00:00 2001 From: "Guo, Xiang1" Date: Fri, 21 Mar 2025 04:10:01 -0700 Subject: [PATCH 1/2] graph: dnnl: add internal sdpa op --- src/graph/backend/dnnl/dnnl_op_def.hpp | 26 +++ src/graph/backend/dnnl/dnnl_opset.hpp | 1 + src/graph/backend/dnnl/dnnl_shape_infer.cpp | 57 ++++++ src/graph/backend/dnnl/dnnl_shape_infer.hpp | 4 + src/graph/backend/dnnl/internal_attrs.hpp | 8 + src/graph/backend/dnnl/internal_ops.hpp | 3 +- src/graph/backend/dnnl/layout_propagator.cpp | 30 +++ src/graph/backend/dnnl/layout_propagator.hpp | 1 + src/graph/backend/dnnl/op_executable.cpp | 23 +++ src/graph/backend/dnnl/op_executable.hpp | 181 +++++++++++++++++++ 10 files changed, 333 insertions(+), 1 deletion(-) diff --git a/src/graph/backend/dnnl/dnnl_op_def.hpp b/src/graph/backend/dnnl/dnnl_op_def.hpp index 66b6fefaa4b..be9118fee5d 100644 --- a/src/graph/backend/dnnl/dnnl_op_def.hpp +++ b/src/graph/backend/dnnl/dnnl_op_def.hpp @@ -1134,6 +1134,32 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_mask, 1, .SET_EXECUTABLE_CREATOR(executable_creator) .SET_ARG_INDICES_GETTER(memory_reparser_t)) +// The data types of query/key/value/mask/output must be consistent, and only +// f16/bf16 are supported. The data type of scale must be consistent with other +// input and output data types or fp32. +DNNL_GRAPH_OP_SCHEMA(dnnl_sdpa, 1, + op_schema_t() + .set_inputs_option(op_schema_t::param_num_option::variadic) + .set_num_inputs(std::set({3, 32})) + .set_num_outputs(2) + .set_input(0, "query") + .set_input(1, "key") + .set_input(2, "value") + .set_input(3, "scale") // optional + .set_input(4, "mask") // optional + .set_output(0, "output") + .set_output(1, "scratchpad") + .set_attr(op_attr::with_scale, true, attribute_kind::b) + .set_attr(op_attr::is_invert_scale, false, attribute_kind::b, + false) + .set_attr(op_attr::with_mask, true, attribute_kind::b) + // with_causal attribute support top-left mask type only + .set_attr(op_attr::with_causal, true, attribute_kind::b) + .set_shape_inference_function(infer_dnnl_sdpa_output_shape) + .SET_LAYOUT_PROPAGATOR(layout_propagator_for_sdpa) + .SET_EXECUTABLE_CREATOR(executable_creator) + .SET_ARG_INDICES_GETTER(sdpa_executable_t)) + } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git a/src/graph/backend/dnnl/dnnl_opset.hpp b/src/graph/backend/dnnl/dnnl_opset.hpp index cbfa412e657..0a6091ba616 100644 --- a/src/graph/backend/dnnl/dnnl_opset.hpp +++ b/src/graph/backend/dnnl/dnnl_opset.hpp @@ -97,6 +97,7 @@ class dnnl_opset_t { fn(get_op_schema()); fn(get_op_schema()); fn(get_op_schema()); + fn(get_op_schema()); } }; diff --git a/src/graph/backend/dnnl/dnnl_shape_infer.cpp b/src/graph/backend/dnnl/dnnl_shape_infer.cpp index b94ab7a87aa..e1cb7fc7d15 100644 --- a/src/graph/backend/dnnl/dnnl_shape_infer.cpp +++ b/src/graph/backend/dnnl/dnnl_shape_infer.cpp @@ -545,6 +545,63 @@ status_t infer_dnnl_binary_output_shape(op_t *n, } } +status_t infer_dnnl_sdpa_output_shape(op_t *n, + std::vector &inputs, + std::vector &outputs) { + // [batch_size, num_heads_q, seq_len_q, head_size_qk] + auto query = logical_tensor_wrapper_t(inputs[0]); + // [batch_size, num_heads_q, head_size_qk, seq_len_kv,] + auto key = logical_tensor_wrapper_t(inputs[1]); + // [batch_size, num_heads_v, seq_len_kv, head_size_v] + auto value = logical_tensor_wrapper_t(inputs[2]); + // [batch_size, num_heads_q, seq_len_q, head_size_v] + auto out0 = logical_tensor_wrapper_t(outputs[0]); + + dims query_dims = query.vdims(); + dims key_dims = key.vdims(); + dims value_dims = value.vdims(); + + VCHECK_INVALID_SHAPE((query_dims.size() == key_dims.size() + && key_dims.size() == value_dims.size()), + "%s, all input dims should match each other. input0 dims: %s, " + "input1 dims: %s, input2 dims: %s ", + op_t::kind2str(n->get_kind()).c_str(), dims2str(query_dims).c_str(), + dims2str(key_dims).c_str(), dims2str(value_dims).c_str()); + + VCHECK_INVALID_SHAPE((query_dims.size() == 4), + "%s, only support 4D input for all q/k/v. input0 dimension: %s, " + "input1 dimension: %s, input2 dimension: %s ", + op_t::kind2str(n->get_kind()).c_str(), + std::to_string(query_dims.size()).c_str(), + std::to_string(key_dims.size()).c_str(), + std::to_string(value_dims.size()).c_str()); + + VCHECK_INVALID_SHAPE((query_dims[3] == key_dims[2]), + "%s, query head size should be match with key head size. query " + "dims: %s, Key dims: %s", + op_t::kind2str(n->get_kind()).c_str(), dims2str(query_dims).c_str(), + dims2str(key_dims).c_str()); + + VCHECK_INVALID_SHAPE((key_dims[3] == value_dims[2]), + "%s, key sequence length should be match with value sequence " + "length. key dims: %s, value dims: %s ", + op_t::kind2str(n->get_kind()).c_str(), dims2str(key_dims).c_str(), + dims2str(value_dims).c_str()); + + dims inferred_output_shape; + inferred_output_shape + = {query_dims[0], query_dims[1], query_dims[2], value_dims[3]}; + + if (out0.ndims() != -1) { + VCHECK_INVALID_SHAPE(validate(inferred_output_shape, out0.vdims()), + "%s, inferred out shape and output shape are not compatible", + op_t::kind2str(n->get_kind()).c_str()); + } + + set_shape_and_strides(*outputs[0], inferred_output_shape); + return status::success; +} + } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git a/src/graph/backend/dnnl/dnnl_shape_infer.hpp b/src/graph/backend/dnnl/dnnl_shape_infer.hpp index 78368597062..0877dc26c11 100644 --- a/src/graph/backend/dnnl/dnnl_shape_infer.hpp +++ b/src/graph/backend/dnnl/dnnl_shape_infer.hpp @@ -107,6 +107,10 @@ status_t infer_binary_select_output_shape(op_t *n, std::vector &inputs, std::vector &outputs); +status_t infer_dnnl_sdpa_output_shape(op_t *n, + std::vector &inputs, + std::vector &outputs); + } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git a/src/graph/backend/dnnl/internal_attrs.hpp b/src/graph/backend/dnnl/internal_attrs.hpp index 164c73b385c..2312d0ec7a5 100644 --- a/src/graph/backend/dnnl/internal_attrs.hpp +++ b/src/graph/backend/dnnl/internal_attrs.hpp @@ -45,6 +45,10 @@ const op_attr_t with_runtime_dst_zps = 0x1000c; const op_attr_t is_bias_add = 0x1000d; const op_attr_t with_sum = 0x1000e; const op_attr_t keep_dst_layout = 0x1000f; +const op_attr_t with_scale = 0x10010; +const op_attr_t is_invert_scale = 0x10011; +const op_attr_t with_causal = 0x10012; +const op_attr_t with_mask = 0x10013; // int64_t const op_attr_t alg_kind = 0x10100; @@ -86,6 +90,10 @@ static inline std::string internal_attr2str(op_attr_t attr) { CASE(is_bias_add); CASE(with_sum); CASE(keep_dst_layout); + CASE(with_scale); + CASE(is_invert_scale); + CASE(with_causal); + CASE(with_mask); CASE(alg_kind); CASE(fusion_info_key); CASE(axis_row); diff --git a/src/graph/backend/dnnl/internal_ops.hpp b/src/graph/backend/dnnl/internal_ops.hpp index f3bf2462c69..db9fe1a5c70 100644 --- a/src/graph/backend/dnnl/internal_ops.hpp +++ b/src/graph/backend/dnnl/internal_ops.hpp @@ -79,7 +79,8 @@ namespace op_kind { X(dnnl_convtranspose_bwd_weights, Dnnl_convtranspose_bwd_weights) \ X(dnnl_groupnorm, Dnnl_groupnorm) \ X(dnnl_gen_index, Dnnl_gen_index) \ - X(dnnl_mask, Dnnl_mask) + X(dnnl_mask, Dnnl_mask) \ + X(dnnl_sdpa, Dnnl_sdpa) enum kind_t { kDNNL_INTERNAL_OP_STARTER = 0x1234, diff --git a/src/graph/backend/dnnl/layout_propagator.cpp b/src/graph/backend/dnnl/layout_propagator.cpp index 57e24d6decf..c7a266e7aff 100644 --- a/src/graph/backend/dnnl/layout_propagator.cpp +++ b/src/graph/backend/dnnl/layout_propagator.cpp @@ -1568,6 +1568,36 @@ status_t layout_propagator_for_mask(std::shared_ptr &op, return status; } +status_t layout_propagator_for_sdpa(std::shared_ptr &op, + const dnnl::engine &p_engine, fusion_info_mgr_t &mgr, + pd_cache_t &pd_cache, subgraph_rewriter_t &rewriter) { + UNUSED(p_engine); + UNUSED(mgr); + UNUSED(pd_cache); + UNUSED(rewriter); + + value_ptr dst_val = op->get_output_value(0); + const logical_tensor_t &out_lt = dst_val->get_logical_tensor(); + + dnnl::memory::desc expected_md; + // Set default output layout format for sdpa as acbd if user doesn't specify + // the layout since no reorder will required after sdpa. + if (ltw(out_lt).is_any()) { + expected_md = {ltw(out_lt).vdims(), + static_cast(ltw(out_lt).data_type()), + dnnl::memory::format_tag::acbd}; + } else { + expected_md = make_dnnl_memory_desc(out_lt); + } + status_t status = fill_layout_info(dst_val, expected_md); + + // fill scratchpads dimensions and data type to scratchpad value_t + value_ptr scratchpad_val = op->get_output_value(1); + const memory::desc scratchpad_desc; + status = fill_layout_info(scratchpad_val, scratchpad_desc); + return status; +} + } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git a/src/graph/backend/dnnl/layout_propagator.hpp b/src/graph/backend/dnnl/layout_propagator.hpp index db6c8b13218..40b593ea551 100644 --- a/src/graph/backend/dnnl/layout_propagator.hpp +++ b/src/graph/backend/dnnl/layout_propagator.hpp @@ -93,6 +93,7 @@ DECLARE_LAYOUT_PROPAGATOR(add_zps); DECLARE_LAYOUT_PROPAGATOR(groupnorm); DECLARE_LAYOUT_PROPAGATOR(gen_index); DECLARE_LAYOUT_PROPAGATOR(mask); +DECLARE_LAYOUT_PROPAGATOR(sdpa); #undef DECLARE_LAYOUT_PROPAGATOR diff --git a/src/graph/backend/dnnl/op_executable.cpp b/src/graph/backend/dnnl/op_executable.cpp index a7899754d76..349551f7063 100644 --- a/src/graph/backend/dnnl/op_executable.cpp +++ b/src/graph/backend/dnnl/op_executable.cpp @@ -2405,6 +2405,29 @@ arg_indices_t genindex_executable_t::get_arg_indices( return arg_indices; } +arg_indices_t sdpa_executable_t::get_arg_indices( + const op_t *op, fusion_info_mgr_t &mgr) { + UNUSED(mgr); + + arg_indices_t arg_indices; + // add input args + size_t index = 0; + arg_indices.insert({DNNL_ARG_QUERIES, indices_t {input, index++}}); + arg_indices.insert({DNNL_ARG_KEYS, indices_t {input, index++}}); + arg_indices.insert({DNNL_ARG_VALUES, indices_t {input, index++}}); + if (op->get_attr(dnnl::impl::graph::dnnl_impl::op_attr::with_scale)) { + arg_indices.insert({DNNL_ARG_SCALE, indices_t {input, index++}}); + } + if (op->get_attr(dnnl::impl::graph::dnnl_impl::op_attr::with_mask)) { + arg_indices.insert({DNNL_ARG_ATTN_MASK, indices_t {input, index++}}); + } + + // add output args + arg_indices.insert({DNNL_ARG_DST, indices_t {output, 0}}); + arg_indices.insert({DNNL_ARG_SCRATCHPAD, indices_t {output, 1}}); + return arg_indices; +} + } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git a/src/graph/backend/dnnl/op_executable.hpp b/src/graph/backend/dnnl/op_executable.hpp index a227efdff61..203509094f2 100644 --- a/src/graph/backend/dnnl/op_executable.hpp +++ b/src/graph/backend/dnnl/op_executable.hpp @@ -24,6 +24,9 @@ #include #include +#include "common/primitive.hpp" +#include "common/sdpa_utils.hpp" + #include "oneapi/dnnl/dnnl.hpp" #ifdef DNNL_WITH_SYCL #include "oneapi/dnnl/dnnl_sycl.hpp" @@ -2637,6 +2640,184 @@ struct genindex_executable_t : public op_executable_t { #endif }; +struct sdpa_executable_t : public op_executable_t { + DECLARE_ARG_INDICES_GETTER; + + sdpa_executable_t(std::shared_ptr &op, const dnnl::engine &p_engine, + fusion_info_mgr_t &mgr, pd_cache_t &pd_cache) + : with_scale_(op->get_attr(op_attr::with_scale)) + , with_mask_(op->get_attr(op_attr::with_mask)) + , is_causal_mask_(op->get_attr(op_attr::with_causal)) { + + auto md_q = make_dnnl_memory_desc( + op->get_input_value(0)->get_logical_tensor()); + auto md_k = make_dnnl_memory_desc( + op->get_input_value(1)->get_logical_tensor()); + auto md_v = make_dnnl_memory_desc( + op->get_input_value(2)->get_logical_tensor()); + auto md_dst = make_dnnl_memory_desc( + op->get_output_value(0)->get_logical_tensor()); + + auto scale_dt = impl::data_type::undef; + size_t idx = 3; + if (with_scale_) + scale_dt = op->get_input_value(idx++) + ->get_logical_tensor() + .data_type; + + dnnl::memory::desc md_mask; + if (with_mask_) + md_mask = make_dnnl_memory_desc( + op->get_input_value(idx++)->get_logical_tensor()); + + dnnl::primitive_attr attr; + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + attr.set_fpmath_mode( + static_cast(mgr.get_fpmath_mode().mode_)); + if (op->has_attr(op_attr::is_invert_scale)) + is_invert_scale_ = op->get_attr(op_attr::is_invert_scale); + + dim_t kv_head_number + = op->get_input_value(1)->get_logical_tensor().dims[1]; + status_t s = create_sdpa_pd(sdpa_pd_, p_engine.get(), md_q.get(), + md_k.get(), md_v.get(), md_dst.get(), md_mask.get(), scale_dt, + is_invert_scale_, kv_head_number, is_causal_mask_, attr.get()); + if (s != dnnl::impl::status::success) { + is_initialized_ = false; + } else { + status_t s = sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get()); + is_initialized_ = s == status::success ? true : false; + } + } + + bool is_initialized() const { return is_initialized_; } + + void execute(const stream &stream, + const std::unordered_map &args) const override { + exec_args_t exec_args; + memory_arg_t mem_arg_q = {(args.at(DNNL_ARG_QUERIES)).get(), true}; + memory_arg_t mem_arg_k = {(args.at(DNNL_ARG_KEYS)).get(), true}; + memory_arg_t mem_arg_v = {(args.at(DNNL_ARG_VALUES)).get(), true}; + memory_arg_t mem_arg_dst = {(args.at(DNNL_ARG_DST)).get(), false}; + memory_arg_t mem_arg_scale = { + with_scale_ ? (args.at(DNNL_ARG_SCALE)).get() : nullptr, true}; + memory_arg_t mem_arg_mask + = {with_mask_ ? (args.at(DNNL_ARG_ATTN_MASK)).get() : nullptr, + true}; + + exec_args[DNNL_ARG_QUERIES] = mem_arg_q; + exec_args[DNNL_ARG_KEYS] = mem_arg_k; + exec_args[DNNL_ARG_VALUES] = mem_arg_v; + exec_args[DNNL_ARG_DST] = mem_arg_dst; + exec_args[DNNL_ARG_SCALE] = mem_arg_scale; + exec_args[DNNL_ARG_ATTN_MASK] = mem_arg_mask; + + exec_ctx_t ctx(stream.get(), std::move(exec_args)); + sdpa_prim_->execute(ctx); + } + +#ifdef DNNL_WITH_SYCL + ::sycl::event execute_sycl(const stream &stream, + const std::unordered_map &args, + const std::vector<::sycl::event> &deps) const override { + +#if DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL + return status::unimplemented; +#endif + + exec_args_t exec_args; + memory_arg_t mem_arg_q = {(args.at(DNNL_ARG_QUERIES)).get(), true}; + memory_arg_t mem_arg_k = {(args.at(DNNL_ARG_KEYS)).get(), true}; + memory_arg_t mem_arg_v = {(args.at(DNNL_ARG_VALUES)).get(), true}; + memory_arg_t mem_arg_dst = {(args.at(DNNL_ARG_DST)).get(), false}; + memory_arg_t mem_arg_scale = { + with_scale_ ? (args.at(DNNL_ARG_SCALE)).get() : nullptr, true}; + memory_arg_t mem_arg_mask + = {with_mask_ ? (args.at(DNNL_ARG_ATTN_MASK)).get() : nullptr, + true}; + + exec_args[DNNL_ARG_QUERIES] = mem_arg_q; + exec_args[DNNL_ARG_KEYS] = mem_arg_k; + exec_args[DNNL_ARG_VALUES] = mem_arg_v; + exec_args[DNNL_ARG_DST] = mem_arg_dst; + exec_args[DNNL_ARG_SCALE] = mem_arg_scale; + exec_args[DNNL_ARG_ATTN_MASK] = mem_arg_mask; + + exec_ctx_t ctx(stream.get(), std::move(exec_args)); + auto *sycl_stream + = dnnl::impl::utils::downcast(stream.get()); + sycl_stream->before_exec_hook(); + + if (!deps.empty()) sycl_stream->sycl_ctx().set_deps(deps); + + sdpa_prim_->execute(ctx); + + auto return_event = sycl_stream->get_output_event(); + sycl_stream->after_exec_hook(); + return return_event; + } +#endif + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL + cl_event execute_ocl(const stream &stream, + const std::unordered_map &args, + const std::vector &deps) const override { + exec_args_t exec_args; + memory_arg_t mem_arg_q = {(args.at(DNNL_ARG_QUERIES)).get(), true}; + memory_arg_t mem_arg_k = {(args.at(DNNL_ARG_KEYS)).get(), true}; + memory_arg_t mem_arg_v = {(args.at(DNNL_ARG_VALUES)).get(), true}; + memory_arg_t mem_arg_dst = {(args.at(DNNL_ARG_DST)).get(), false}; + memory_arg_t mem_arg_scale = { + with_scale_ ? (args.at(DNNL_ARG_SCALE)).get() : nullptr, true}; + memory_arg_t mem_arg_mask + = {with_mask_ ? (args.at(DNNL_ARG_ATTN_MASK)).get() : nullptr, + true}; + + exec_args[DNNL_ARG_QUERIES] = mem_arg_q; + exec_args[DNNL_ARG_KEYS] = mem_arg_k; + exec_args[DNNL_ARG_VALUES] = mem_arg_v; + exec_args[DNNL_ARG_DST] = mem_arg_dst; + exec_args[DNNL_ARG_SCALE] = mem_arg_scale; + exec_args[DNNL_ARG_ATTN_MASK] = mem_arg_mask; + + exec_ctx_t ctx(stream.get(), std::move(exec_args)); + + auto *ocl_stream + = dnnl::impl::utils::downcast( + stream.get()); + + ocl_stream->before_exec_hook(); + + if (!deps.empty()) { + std::vector> events(deps.size()); + for (size_t i = 0; i < deps.size(); i++) + events[i] = xpu::ocl::wrapper_t(deps[i], true); + ocl_stream->ocl_ctx().set_deps(events); + } + + sdpa_prim_->execute(ctx); + + cl_event return_event = nullptr; + if ((ocl_stream->flags() & stream_flags::in_order) == 0) { + auto last = ocl_stream->get_output_event(); + return_event = last.release(); + } + + ocl_stream->after_exec_hook(); + return return_event; + } +#endif + +private: + std::shared_ptr sdpa_pd_; + std::shared_ptr sdpa_prim_; + bool with_scale_; + bool with_mask_; + bool is_invert_scale_; + bool is_causal_mask_; + bool is_initialized_; +}; + } // namespace dnnl_impl } // namespace graph } // namespace impl From 71905d2b9ffa119dcc9621eeefcd90e0a358761e Mon Sep 17 00:00:00 2001 From: "Guo, Xiang1" Date: Fri, 21 Mar 2025 04:24:45 -0700 Subject: [PATCH 2/2] graph: dnnl: add sdpa primitive ukernel v1 --- .../backend/dnnl/kernels/large_partition.cpp | 2 +- src/graph/backend/dnnl/kernels/matmul.cpp | 2 +- src/graph/backend/dnnl/kernels/mqa_decomp.cpp | 2 +- src/graph/backend/dnnl/kernels/sdp.hpp | 13 +- src/graph/backend/dnnl/kernels/sdp_decomp.cpp | 2 +- .../backend/dnnl/kernels/sdp_primitive.cpp | 2 +- .../dnnl/kernels/sdp_primitive_config.cpp | 20 +- .../dnnl/kernels/sdp_primitive_config.hpp | 3 +- .../backend/dnnl/kernels/sdp_primitive_v1.cpp | 228 ++++++++++++++++++ .../backend/dnnl/kernels/sdp_primitive_v1.hpp | 103 ++++++++ src/graph/backend/dnnl/passes/compile_ops.cpp | 11 +- src/graph/backend/dnnl/passes/transform.cpp | 144 ++++++++++- src/graph/backend/dnnl/passes/transform.hpp | 10 +- 13 files changed, 522 insertions(+), 20 deletions(-) create mode 100644 src/graph/backend/dnnl/kernels/sdp_primitive_v1.cpp create mode 100644 src/graph/backend/dnnl/kernels/sdp_primitive_v1.hpp diff --git a/src/graph/backend/dnnl/kernels/large_partition.cpp b/src/graph/backend/dnnl/kernels/large_partition.cpp index 340b575cf40..d14b4b4ae4d 100644 --- a/src/graph/backend/dnnl/kernels/large_partition.cpp +++ b/src/graph/backend/dnnl/kernels/large_partition.cpp @@ -142,7 +142,7 @@ void larger_partition_kernel_t::setup_pipeline_stage2(pass_pipeline_t &pipeline, } BACKEND_DNNL_ADD_PASS(pipeline, infer_shape); BACKEND_DNNL_ADD_PASS(pipeline, fuse_src_transpose_to_matmul); - BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor); BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation); BACKEND_DNNL_ADD_PASS(pipeline, common_reorder_elimination); BACKEND_DNNL_ADD_PASS(pipeline, fuse_adjacent_reorders); diff --git a/src/graph/backend/dnnl/kernels/matmul.cpp b/src/graph/backend/dnnl/kernels/matmul.cpp index 17005554cba..b8fffd81202 100644 --- a/src/graph/backend/dnnl/kernels/matmul.cpp +++ b/src/graph/backend/dnnl/kernels/matmul.cpp @@ -110,7 +110,7 @@ status_t matmul_t::compile_impl(const dnnl_partition_impl_t *part, } BACKEND_DNNL_ADD_PASS(pipeline, infer_shape); - BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor); BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation); BACKEND_DNNL_ADD_PASS(pipeline, fuse_adjacent_reorders); diff --git a/src/graph/backend/dnnl/kernels/mqa_decomp.cpp b/src/graph/backend/dnnl/kernels/mqa_decomp.cpp index 9ec0bafe82c..aa76e48ec2b 100644 --- a/src/graph/backend/dnnl/kernels/mqa_decomp.cpp +++ b/src/graph/backend/dnnl/kernels/mqa_decomp.cpp @@ -87,7 +87,7 @@ status_t mqa_decomp_kernel_t::compile_impl( BACKEND_DNNL_ADD_PASS(pipeline, remove_quant_data_with_no_effect); } pipeline.reset_visualize_arg(true, false); - BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor); BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation); // Run the added passes diff --git a/src/graph/backend/dnnl/kernels/sdp.hpp b/src/graph/backend/dnnl/kernels/sdp.hpp index 9df511dfdaa..15549e12d48 100644 --- a/src/graph/backend/dnnl/kernels/sdp.hpp +++ b/src/graph/backend/dnnl/kernels/sdp.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2024 Intel Corporation +* Copyright 2024-2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ #include "graph/backend/dnnl/kernels/large_partition.hpp" #include "graph/backend/dnnl/kernels/sdp_decomp.hpp" #include "graph/backend/dnnl/kernels/sdp_primitive.hpp" +#include "graph/backend/dnnl/kernels/sdp_primitive_v1.hpp" #include "graph/backend/dnnl/dnnl_partition_impl.hpp" @@ -65,7 +66,15 @@ struct sdp_base_t : public kernel_base_t { status_t ret = status::unimplemented; - if (enable_ukernel) { + // SDPA Ukernel v1 with fused internal sdpa solution. Support fload sdpa + // only. + // TODO(GX): Support quantized sdpa and merge with sdp_primitive_kernel_t. + if (enable_ukernel && !quantized) { + kernel = std::make_shared(); + ret = kernel->compile_impl(part, g_engine, inputs, outputs); + } + + if (ret != status::success && enable_ukernel) { kernel = std::make_shared>(); ret = kernel->compile_impl(part, g_engine, inputs, outputs); } diff --git a/src/graph/backend/dnnl/kernels/sdp_decomp.cpp b/src/graph/backend/dnnl/kernels/sdp_decomp.cpp index 9e1361d7add..cf87009c1d7 100644 --- a/src/graph/backend/dnnl/kernels/sdp_decomp.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_decomp.cpp @@ -86,7 +86,7 @@ status_t sdp_decomp_kernel_t::compile_impl( BACKEND_DNNL_ADD_PASS(pipeline, remove_quant_data_with_no_effect); } pipeline.reset_visualize_arg(true, false); - BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor); BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation); // Run the added passes diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive.cpp b/src/graph/backend/dnnl/kernels/sdp_primitive.cpp index 3bf0c2a9536..d7c244b4a98 100644 --- a/src/graph/backend/dnnl/kernels/sdp_primitive.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_primitive.cpp @@ -92,7 +92,7 @@ status_t sdp_primitive_kernel_t::compile_impl( pipeline.reset_visualize_arg(true, false); BACKEND_DNNL_ADD_PASS(pipeline, infer_shape); - BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor); BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation); // bind the memory for each op diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp index b05ddf4004d..db1d34e9fc6 100644 --- a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp @@ -166,11 +166,29 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr &sg, status_t sdp_primitive_config_t::initial_check( const std::shared_ptr &sg, - const std::vector &inputs) { + const std::vector &inputs, bool v1_kernel) { // At least 3 inputs: Q, K, V VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments, "At least 3 inputs are required"); + // Ukernel doesn't support f32 datatype now + VCHECK_SDP_PRIMITIVE(inputs[0].data_type != dnnl_data_type_t::dnnl_f32, + status::invalid_arguments, + "SDPA ukernel doesn't support f32 datatype now"); + + // Note: sdpa_primitive_v1 kernel currently don't support legacy GQA pattern. + if (v1_kernel) { + for (auto &cur_op : sg->get_ops()) { + if (cur_op->get_kind() == graph::op_kind::StaticReshape) { + auto in = cur_op->get_input_value(0)->get_logical_tensor(); + auto out = cur_op->get_output_value(0)->get_logical_tensor(); + if (ltw(in).ndims() == 5 || ltw(out).ndims() == 5) { + return status::unimplemented; + } + } + } + } + // step1(pattern check): Not support sdpa variants with select as mask // We already have a pattern matcher to ensure that the sdpa patterns // dispatch to here are knows ones, and we have quant check in sdpa base diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp b/src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp index abea994ab01..e1f77232abd 100644 --- a/src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp +++ b/src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp @@ -82,7 +82,8 @@ struct sdp_primitive_config_t { // 2. only support fp16 data type // 3. only support 4-dims tensor status_t initial_check(const std::shared_ptr &sg, - const std::vector &inputs); + const std::vector &inputs, + bool v1_kernel = false); // Initialize parameters and primitive. status_t init(std::shared_ptr &sg, const dnnl::engine &p_engine, diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive_v1.cpp b/src/graph/backend/dnnl/kernels/sdp_primitive_v1.cpp new file mode 100644 index 00000000000..f2544ee0f86 --- /dev/null +++ b/src/graph/backend/dnnl/kernels/sdp_primitive_v1.cpp @@ -0,0 +1,228 @@ +/******************************************************************************* +* Copyright 2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "graph/backend/dnnl/kernels/sdp_primitive_v1.hpp" + +#include "common/sdpa_pd.hpp" + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL +#include "gpu/intel/ocl/stream.hpp" +#elif DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL +#include "gpu/intel/sycl/stream.hpp" +#endif + +#include "graph/backend/dnnl/passes/compile_ops.hpp" +#include "graph/backend/dnnl/passes/constant_propagation.hpp" +#include "graph/backend/dnnl/passes/insert_ops.hpp" +#include "graph/backend/dnnl/passes/layout_propagation.hpp" +#include "graph/backend/dnnl/passes/lower.hpp" +#include "graph/backend/dnnl/passes/memory_planning.hpp" +#include "graph/backend/dnnl/passes/transform.hpp" +#include "graph/backend/dnnl/passes/utils.hpp" + +#include "graph/backend/dnnl/op_executable.hpp" + +namespace dnnl { +namespace impl { +namespace graph { +namespace dnnl_impl { + +status_t sdp_primitive_v1_kernel_t::compile_impl( + const dnnl_partition_impl_t *part, const engine_t *g_engine, + const std::vector &inputs, + const std::vector &outputs) { +// sdp_primitive_v1_kernel_t only supports Intel GPU. +#if defined(DNNL_WITH_SYCL) && DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL + return status::unimplemented; +#endif + + p_engine_ = make_dnnl_engine(*g_engine); + g_alloc_ + = reinterpret_cast(g_engine->get_allocator()); + + // First, dry run on a deep copy + subgraph_ + = std::make_shared(graph_t::deep_copy(part->get_ops()), + p_engine_, part->get_fpmath_mode(), false, true); + CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs)); + + CHECK(cfg_.initial_check(subgraph_, inputs, true)); + + subgraph_visualizer_t vis(part->id(), [this](const value_t *val) { + return this->memory_planner_.get_memory_info(val); + }); + pass_pipeline_t pipeline = pass_pipeline_t(vis); + + BACKEND_DNNL_ADD_PASS(pipeline, lower_down); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_implicit_causal_mask); + BACKEND_DNNL_ADD_PASS(pipeline, binary_canonicalization); + BACKEND_DNNL_ADD_PASS(pipeline, insert_permute_for_matmul); + + pipeline.reset_visualize_arg(true, false); + BACKEND_DNNL_ADD_PASS(pipeline, infer_shape); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_src_transpose_to_matmul); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_sdpa); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor); + BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation); + + // bind the memory for each op` + auto memory_plan = [&](std::shared_ptr &sg) { + return memory_planner_.run(sg); + }; + pipeline.reset_visualize_arg(true, true); + BACKEND_DNNL_ADD_PASS(pipeline, memory_plan); + BACKEND_DNNL_ADD_PASS(pipeline, compile_ops); + + // Run the added passes + BACKEND_DNNL_CHECK(pipeline.run(subgraph_)); + + // fill information for inputs logical tensors + for (size_t i = 0; i < inputs.size(); i++) { + auto &in = const_cast(inputs[i]); + in = subgraph_->ins_[i]; + } + + // fill information for outputs logical tensors + for (size_t i = 0; i < outputs.size(); i++) { + auto &out = const_cast(outputs[i]); + out = subgraph_->outs_[i]; + } + + resource_ctor_ = [this]() { + return this->memory_planner_.get_exec_args_set().clone(); + }; + + return status::success; +} + +void sdp_primitive_v1_kernel_t::prepare_args_set( + const execution_args_set_t *res, const std::vector &inputs, + const std::vector &outputs, const scratchpad_t &scratchpad) { + // update the data of partition in/outputs args + for (const auto &mem_idx : res->get_mems_use_external_inputs()) { + mem_idx.first.set_data_handle(inputs[mem_idx.second].get_data_handle()); + } + for (const auto &mem_idx : res->get_mems_use_external_outputs()) { + mem_idx.first.set_data_handle( + outputs[mem_idx.second].get_data_handle()); + } + + grantor_t var_grantor = memory_planner_.internal_temporary_grantor( + scratchpad.get_buffer()); + + for (auto &mem_offkey : res->get_mems_use_internal_temporary()) { + mem_offkey.first.set_data_handle(var_grantor.get(mem_offkey.second)); + } +} + +status_t sdp_primitive_v1_kernel_t::execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs) { + dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream); + + thread_local_cache_t res_cache; + execution_args_set_t *res = res_cache.get_or_add( + reinterpret_cast(this), resource_ctor_); + + temporary_scratchpad_t scratchpad( + memory_planner_.total_internal_temporary_size(), p_engine_, + *g_alloc_); + prepare_args_set(res, inputs, outputs, scratchpad); + + for (size_t i = 0; i < subgraph_->execs_.size(); i++) { + subgraph_->execs_[i]->execute(p_stream, res->get_exec_args()[i]); + } + + return status::success; +} + +#ifdef DNNL_WITH_SYCL +status_t sdp_primitive_v1_kernel_t::sycl_execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector<::sycl::event> &sycl_deps, + ::sycl::event *sycl_event) { +// sdp_primitive_v1_kernel_t only supports Intel GPU. +#if DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL + return status::unimplemented; +#endif + auto deps = sycl_deps; + ::sycl::event returned_event; + + dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream); + + thread_local_cache_t res_cache; + execution_args_set_t *res = res_cache.get_or_add( + reinterpret_cast(this), resource_ctor_); + + temporary_scratchpad_t scratchpad( + memory_planner_.total_internal_temporary_size(), p_engine_, + *g_alloc_); + prepare_args_set(res, inputs, outputs, scratchpad); + + for (size_t i = 0; i < subgraph_->execs_.size(); i++) { + if (subgraph_->is_constant_[i]) continue; + returned_event = subgraph_->execs_[i]->execute_sycl( + p_stream, res->get_exec_args()[i], deps); + deps = {returned_event}; + } + + scratchpad.set_deps(returned_event); + if (sycl_event) *sycl_event = returned_event; + + return status::success; +} +#endif + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL +status_t sdp_primitive_v1_kernel_t::ocl_execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &cl_deps, cl_event *ret_event) { + auto deps = cl_deps; + cl_event returned_event {}; + + dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream); + + thread_local_cache_t res_cache; + execution_args_set_t *res = res_cache.get_or_add( + reinterpret_cast(this), resource_ctor_); + + temporary_scratchpad_t scratchpad( + memory_planner_.total_internal_temporary_size(), p_engine_, + *g_alloc_); + prepare_args_set(res, inputs, outputs, scratchpad); + + for (size_t i = 0; i < subgraph_->execs_.size(); i++) { + if (subgraph_->is_constant_[i]) continue; + returned_event = subgraph_->execs_[i]->execute_ocl( + p_stream, res->get_exec_args()[i], deps); + deps = {returned_event}; + } + + scratchpad.set_deps(returned_event); + if (ret_event) *ret_event = returned_event; + + return status::success; +} +#endif + +struct sdp_primitive_v1_kernel_t; + +} // namespace dnnl_impl +} // namespace graph +} // namespace impl +} // namespace dnnl diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive_v1.hpp b/src/graph/backend/dnnl/kernels/sdp_primitive_v1.hpp new file mode 100644 index 00000000000..1d54d0ec36f --- /dev/null +++ b/src/graph/backend/dnnl/kernels/sdp_primitive_v1.hpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GRAPH_BACKEND_DNNL_KERNELS_SDP_PRIMITIVE_V1_HPP +#define GRAPH_BACKEND_DNNL_KERNELS_SDP_PRIMITIVE_V1_HPP + +#include +#include +#include +#include +#include + +#include "graph/backend/dnnl/kernels/sdp_primitive_config.hpp" + +#include "graph/backend/dnnl/common.hpp" +#include "graph/backend/dnnl/dnnl_constant_tensor_cache.hpp" +#include "graph/backend/dnnl/dnnl_partition_impl.hpp" +#include "graph/backend/dnnl/op_executable.hpp" +#include "graph/backend/dnnl/scratchpad.hpp" +#include "graph/backend/dnnl/thread_local_cache.hpp" +#include "graph/backend/dnnl/utils.hpp" + +#include "graph/backend/dnnl/passes/memory_planning.hpp" + +namespace dnnl { +namespace impl { +namespace graph { +namespace dnnl_impl { + +struct sdp_primitive_v1_kernel_t : public kernel_base_t { +private: + allocator_t *g_alloc_ = nullptr; + + std::shared_ptr subgraph_; + memory_planner_t memory_planner_; + std::function()> resource_ctor_; + + sdp_primitive_config_t cfg_; + +public: + sdp_primitive_v1_kernel_t() { + thread_local_cache_t res_cache; + res_cache.retain(); + } + + ~sdp_primitive_v1_kernel_t() override { + thread_local_cache_t res_cache; + res_cache.remove_if_exist(reinterpret_cast(this)); + res_cache.release(); + } + + status_t compile_impl(const dnnl_partition_impl_t *part, + const engine_t *g_engine, + const std::vector &inputs, + const std::vector &outputs) override; + + void prepare_args_set(const execution_args_set_t *res, + const std::vector &inputs, + const std::vector &outputs, + const scratchpad_t &scratchpad); + + status_t execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs) override; + +#ifdef DNNL_WITH_SYCL + status_t sycl_execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector<::sycl::event> &sycl_deps, + ::sycl::event *sycl_event) override; +#endif + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL + status_t ocl_execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &cl_deps, cl_event *ret_event) override; +#endif + + DEF_KERNEL_METHOD_STR(sdp_primitive_v1_kernel_t) + DNNL_DISALLOW_COPY_AND_ASSIGN(sdp_primitive_v1_kernel_t) +}; + +} // namespace dnnl_impl +} // namespace graph +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/graph/backend/dnnl/passes/compile_ops.cpp b/src/graph/backend/dnnl/passes/compile_ops.cpp index d3ac4c34b0f..314859fb915 100644 --- a/src/graph/backend/dnnl/passes/compile_ops.cpp +++ b/src/graph/backend/dnnl/passes/compile_ops.cpp @@ -59,14 +59,21 @@ status_t compile_ops(std::shared_ptr &sg) { auto cur_op = op->shared_from_this(); auto creator = opm->get_additional_item( "executable_creator"); + std::shared_ptr exec = creator(cur_op, p_engine, mgr, pd_cache); - VCHECK_COMPILE_OPS(exec != nullptr, status::invalid_graph_op, "unimplemented op, can't compile op %s", op->get_name().c_str()); - + if (cur_op->get_kind() == op_kind::dnnl_sdpa) { + auto sdpa_exec = std::dynamic_pointer_cast(exec); + VCHECK_COMPILE_OPS(sdpa_exec->is_initialized(), + status::unimplemented, + "failed to create executable for op %s", + op->get_name().c_str()); + } sg->execs_.emplace_back(exec); + sg->is_constant_.push_back(op->has_attr(op_attr::is_constant) && op->get_attr(op_attr::is_constant)); return status::success; diff --git a/src/graph/backend/dnnl/passes/transform.cpp b/src/graph/backend/dnnl/passes/transform.cpp index 50bc177392c..edfd8ee0084 100644 --- a/src/graph/backend/dnnl/passes/transform.cpp +++ b/src/graph/backend/dnnl/passes/transform.cpp @@ -3835,13 +3835,16 @@ impl::status_t fuse_src_transpose_to_matmul(std::shared_ptr &sg) { return impl::status::success; } -impl::status_t fuse_dst_transpose_to_matmul(std::shared_ptr &sg) { +impl::status_t fuse_dst_transpose_to_predecessor( + std::shared_ptr &sg) { std::vector transpose_ops; for (auto &cur_op : sg->get_ops()) { if (cur_op->get_kind() == op_kind::dnnl_transpose && cur_op->get_input_value(0)->has_producer() - && cur_op->get_input_value(0)->get_producer().get_kind() - == op_kind::dnnl_matmul + && (cur_op->get_input_value(0)->get_producer().get_kind() + == op_kind::dnnl_matmul + || cur_op->get_input_value(0)->get_producer().get_kind() + == op_kind::dnnl_sdpa) && !cur_op->get_output_value(0)->get_consumers().empty() && (cur_op->get_output_value(0) ->get_consumers()[0] @@ -3894,13 +3897,17 @@ impl::status_t fuse_dst_transpose_to_matmul(std::shared_ptr &sg) { dnnl::memory::desc expected_out_md = out_md.permute_axes(axes); // Special check to avoid low matmul performance with adbc layout. // TODO: remove this once the performance is improved. - if (get_format_tag(expected_out_md) == dnnl::memory::format_tag::adbc) { + if (in_val->get_producer().get_kind() == op_kind::dnnl_matmul + && get_format_tag(expected_out_md) + == dnnl::memory::format_tag::adbc) { break; } const auto &strides = expected_out_md.get_strides(); in_val->set_strides(strides); - auto &matmul = transpose_op->get_input_value(0)->get_producer(); - matmul.set_attr(op_attr::keep_dst_layout, true); + if (in_val->get_producer().get_kind() == op_kind::dnnl_matmul) { + auto &matmul = in_val->get_producer(); + matmul.set_attr(op_attr::keep_dst_layout, true); + } } rewriter.run(); return impl::status::success; @@ -4168,6 +4175,131 @@ impl::status_t replace_select_values(std::shared_ptr &sg) { return infer_shape(sg); } +status_t fuse_sdpa(std::shared_ptr &sg) { + std::vector candidates; + for (auto &cur_op : sg->get_ops()) { + std::vector pattern_ops; + if (cur_op->get_kind() != op_kind::dnnl_matmul) continue; + op_ptr walker = cur_op; + bool valid_pattern = true; + bool has_scale = false, has_mask = false, has_softmax = false; + bool finished = false; + while (walker && !finished) { + pattern_ops.push_back(walker); + switch (walker->get_kind()) { + case op_kind::dnnl_matmul: { + if (pattern_ops.size() == 1) { + } + // Finish pattern match process after second matmul + else { + valid_pattern = (pattern_ops.size() >= 3); + finished = true; + } + break; + } + case op_kind::dnnl_binary: { + auto alg = static_cast( + walker->get_attr(op_attr::alg_kind)); + if (alg == dnnl::algorithm::binary_mul + || alg == dnnl::algorithm::binary_div) { + if (has_scale) valid_pattern = false; + has_scale = true; + } else if (alg == dnnl::algorithm::binary_add) { + if (has_mask) valid_pattern = false; + has_mask = true; + } + break; + } + case op_kind::dnnl_mask: { + if (has_mask) valid_pattern = false; + has_mask = true; + break; + } + case op_kind::dnnl_softmax: { + if (has_softmax) valid_pattern = false; + has_softmax = true; + break; + } + default: valid_pattern = false; + } + + if (!valid_pattern) break; + + auto out_val = walker->get_output_value(0); + if (out_val->get_consumers().size() != 1) break; + walker = out_val->get_consumers()[0].get_op().shared_from_this(); + } + + if (valid_pattern && finished) { + candidates = pattern_ops; + break; + } + } + + if (candidates.empty()) return status::success; + + subgraph_rewriter_t rewriter(sg); + op_ptr sdpa_op = std::make_shared(op_kind::dnnl_sdpa); + sdpa_op->set_attr(op_attr::with_scale, false); + sdpa_op->set_attr(op_attr::with_mask, false); + sdpa_op->set_attr(op_attr::with_causal, false); + + auto query_val = candidates[0]->get_input_value(0); + query_val->remove_consumer(*candidates[0], 0); + sdpa_op->connect_input(0, query_val); + + auto key_val = candidates[0]->get_input_value(1); + key_val->remove_consumer(*candidates[0], 1); + sdpa_op->connect_input(1, key_val); + + auto value_val = candidates.back()->get_input_value(1); + value_val->remove_consumer(*candidates.back(), 1); + sdpa_op->connect_input(2, value_val); + + size_t input_idx = 3; + for (size_t i = 1; i < candidates.size(); ++i) { + auto op = candidates[i]; + if (op->get_kind() == op_kind::dnnl_binary) { + auto alg = static_cast( + op->get_attr(op_attr::alg_kind)); + // handle scale + if (alg == dnnl::algorithm::binary_mul + || alg == dnnl::algorithm::binary_div) { + auto scale_val = op->get_input_value(1); + scale_val->remove_consumer(*op, 1); + sdpa_op->connect_input(input_idx++, scale_val); + sdpa_op->set_attr(op_attr::with_scale, true); + sdpa_op->set_attr(op_attr::is_invert_scale, + (alg == dnnl::algorithm::binary_div)); + } + // handle explicit mask + else if (alg == dnnl::algorithm::binary_add) { + auto mask_val = op->get_input_value(1); + mask_val->remove_consumer(*op, 1); + sdpa_op->connect_input(input_idx++, mask_val); + sdpa_op->set_attr(op_attr::with_mask, true); + } + } + // handle implicit dnnl_mask + else if (op->get_kind() == op_kind::dnnl_mask) { + sdpa_op->set_attr(op_attr::with_causal, true); + } + } + + auto final_output = candidates.back()->get_output_value(0); + final_output->set_producer(*sdpa_op); + sdpa_op->add_output(final_output); + + insert_empty_scratchpad(sdpa_op); + + for (auto &op : candidates) { + rewriter.to_remove(op); + } + rewriter.to_insert(sdpa_op); + rewriter.run(); + return status::success; +} + } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git a/src/graph/backend/dnnl/passes/transform.hpp b/src/graph/backend/dnnl/passes/transform.hpp index e0b8793b385..7b2cc197222 100644 --- a/src/graph/backend/dnnl/passes/transform.hpp +++ b/src/graph/backend/dnnl/passes/transform.hpp @@ -204,9 +204,10 @@ impl::status_t lift_up_weight_reshape_for_depthwiseconv( // This pass will compute matmul with the src layout of transpose before matmul impl::status_t fuse_src_transpose_to_matmul(std::shared_ptr &sg); -// This pass will compute matmul with the dst layout of following transpose if -// the operator after transpose need a dense layout -impl::status_t fuse_dst_transpose_to_matmul(std::shared_ptr &sg); +// This pass will compute matmul/sdpa with the dst layout of following transpose +// if the operator after transpose need a dense layout +impl::status_t fuse_dst_transpose_to_predecessor( + std::shared_ptr &sg); // This pass will fuse all the reshape to its lead op for GQA. impl::status_t fuse_reshape_for_gqa(std::shared_ptr &sg); @@ -283,6 +284,9 @@ impl::status_t replace_select_values(std::shared_ptr &sg); /// | status_t fuse_implicit_causal_mask(std::shared_ptr &sg); +/// This pass will transform the sdpa subgraph into a dnnl_sdpa op. +status_t fuse_sdpa(std::shared_ptr &sg); + } // namespace dnnl_impl } // namespace graph } // namespace impl