graph: introduce internal dnnl_sdpa op #2930

Open · wants to merge 5 commits into base: main
26 changes: 26 additions & 0 deletions src/graph/backend/dnnl/dnnl_op_def.hpp
@@ -1134,6 +1134,32 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_mask, 1,
.SET_EXECUTABLE_CREATOR(executable_creator<memory_reparser_t>)
.SET_ARG_INDICES_GETTER(memory_reparser_t))

// The data types of query/key/value/mask/output must be consistent, and only
// f16/bf16 are supported. The scale data type must either match the other
// input and output data types or be f32.
DNNL_GRAPH_OP_SCHEMA(dnnl_sdpa, 1,
op_schema_t()
.set_inputs_option(op_schema_t::param_num_option::variadic)
.set_num_inputs(std::set<size_t>({3, 32}))
.set_num_outputs(2)
.set_input(0, "query")
.set_input(1, "key")
.set_input(2, "value")
.set_input(3, "scale") // optional
.set_input(4, "mask") // optional
.set_output(0, "output")
.set_output(1, "scratchpad")
.set_attr(op_attr::with_scale, true, attribute_kind::b)
.set_attr(op_attr::is_invert_scale, false, attribute_kind::b,
false)
.set_attr(op_attr::with_mask, true, attribute_kind::b)
// the with_causal attribute supports the top-left mask type only
.set_attr(op_attr::with_causal, true, attribute_kind::b)
.set_shape_inference_function(infer_dnnl_sdpa_output_shape)
.SET_LAYOUT_PROPAGATOR(layout_propagator_for_sdpa)
.SET_EXECUTABLE_CREATOR(executable_creator<sdpa_executable_t>)
.SET_ARG_INDICES_GETTER(sdpa_executable_t))

} // namespace dnnl_impl
} // namespace graph
} // namespace impl
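For orientation, here is a minimal sketch of how a lowering pass could instantiate the internal op defined above; op_ptr, op_t, and set_attr follow the existing dnnl backend conventions, but the snippet itself is illustrative and not part of this diff:

    // Hypothetical construction of dnnl_sdpa inside a fusion/lowering pass
    // (op_ptr is the backend's std::shared_ptr<op_t> alias).
    op_ptr sdpa_op = std::make_shared<op_t>(op_kind::dnnl_sdpa);
    sdpa_op->set_attr<bool>(op_attr::with_scale, true); // scale input is present
    sdpa_op->set_attr<bool>(op_attr::is_invert_scale, false); // multiply by scale, not divide
    sdpa_op->set_attr<bool>(op_attr::with_mask, true); // explicit mask input
    sdpa_op->set_attr<bool>(op_attr::with_causal, false); // no implicit (top-left) causal mask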
1 change: 1 addition & 0 deletions src/graph/backend/dnnl/dnnl_opset.hpp
@@ -97,6 +97,7 @@ class dnnl_opset_t {
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_layernorm, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_reorder, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_groupnorm, 1)>());
fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_sdpa, 1)>());
}
};

57 changes: 57 additions & 0 deletions src/graph/backend/dnnl/dnnl_shape_infer.cpp
@@ -545,6 +545,63 @@ status_t infer_dnnl_binary_output_shape(op_t *n,
}
}

status_t infer_dnnl_sdpa_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs) {
// [batch_size, num_heads_q, seq_len_q, head_size_qk]
auto query = logical_tensor_wrapper_t(inputs[0]);
// [batch_size, num_heads_q, head_size_qk, seq_len_kv]
auto key = logical_tensor_wrapper_t(inputs[1]);
// [batch_size, num_heads_v, seq_len_kv, head_size_v]
auto value = logical_tensor_wrapper_t(inputs[2]);
// [batch_size, num_heads_q, seq_len_q, head_size_v]
auto out0 = logical_tensor_wrapper_t(outputs[0]);

dims query_dims = query.vdims();
dims key_dims = key.vdims();
dims value_dims = value.vdims();

VCHECK_INVALID_SHAPE((query_dims.size() == key_dims.size()
&& key_dims.size() == value_dims.size()),
"%s, all input dims should match each other. input0 dims: %s, "
"input1 dims: %s, input2 dims: %s ",
op_t::kind2str(n->get_kind()).c_str(), dims2str(query_dims).c_str(),
dims2str(key_dims).c_str(), dims2str(value_dims).c_str());

VCHECK_INVALID_SHAPE((query_dims.size() == 4),
"%s, only support 4D input for all q/k/v. input0 dimension: %s, "
"input1 dimension: %s, input2 dimension: %s ",
op_t::kind2str(n->get_kind()).c_str(),
std::to_string(query_dims.size()).c_str(),
std::to_string(key_dims.size()).c_str(),
std::to_string(value_dims.size()).c_str());

VCHECK_INVALID_SHAPE((query_dims[3] == key_dims[2]),
"%s, query head size should be match with key head size. query "
"dims: %s, Key dims: %s",
op_t::kind2str(n->get_kind()).c_str(), dims2str(query_dims).c_str(),
dims2str(key_dims).c_str());

VCHECK_INVALID_SHAPE((key_dims[3] == value_dims[2]),
"%s, key sequence length should be match with value sequence "
"length. key dims: %s, value dims: %s ",
op_t::kind2str(n->get_kind()).c_str(), dims2str(key_dims).c_str(),
dims2str(value_dims).c_str());

dims inferred_output_shape;
inferred_output_shape
= {query_dims[0], query_dims[1], query_dims[2], value_dims[3]};

if (out0.ndims() != -1) {
VCHECK_INVALID_SHAPE(validate(inferred_output_shape, out0.vdims()),
"%s, inferred out shape and output shape are not compatible",
op_t::kind2str(n->get_kind()).c_str());
}

set_shape_and_strides(*outputs[0], inferred_output_shape);
return status::success;
}

} // namespace dnnl_impl
} // namespace graph
} // namespace impl
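To make the inference rule concrete, a small standalone restatement with a worked example; the shapes below are illustrative and not taken from the PR:

    #include <cstdint>
    #include <vector>

    // Assumes the dimension checks above already passed:
    // q = [B, H_q, Sq, Dqk], k = [B, H_k, Dqk, Skv], v = [B, H_v, Skv, Dv].
    std::vector<int64_t> infer_sdpa_out_shape(
            const std::vector<int64_t> &q, const std::vector<int64_t> &v) {
        return {q[0], q[1], q[2], v[3]}; // [B, H_q, Sq, Dv]
    }

    // Example: q = {2, 16, 384, 64}, k = {2, 16, 64, 512}, v = {2, 16, 512, 128}
    // -> checks: q[3] == k[2] (64), k[3] == v[2] (512)
    // -> infer_sdpa_out_shape(q, v) == {2, 16, 384, 128}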
4 changes: 4 additions & 0 deletions src/graph/backend/dnnl/dnnl_shape_infer.hpp
@@ -107,6 +107,10 @@ status_t infer_binary_select_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs);

status_t infer_dnnl_sdpa_output_shape(op_t *n,
std::vector<logical_tensor_t *> &inputs,
std::vector<logical_tensor_t *> &outputs);

} // namespace dnnl_impl
} // namespace graph
} // namespace impl
8 changes: 8 additions & 0 deletions src/graph/backend/dnnl/internal_attrs.hpp
@@ -45,6 +45,10 @@ const op_attr_t with_runtime_dst_zps = 0x1000c;
const op_attr_t is_bias_add = 0x1000d;
const op_attr_t with_sum = 0x1000e;
const op_attr_t keep_dst_layout = 0x1000f;
const op_attr_t with_scale = 0x10010;
const op_attr_t is_invert_scale = 0x10011;
const op_attr_t with_causal = 0x10012;
const op_attr_t with_mask = 0x10013;

// int64_t
const op_attr_t alg_kind = 0x10100;
@@ -86,6 +90,10 @@ static inline std::string internal_attr2str(op_attr_t attr) {
CASE(is_bias_add);
CASE(with_sum);
CASE(keep_dst_layout);
CASE(with_scale);
CASE(is_invert_scale);
CASE(with_causal);
CASE(with_mask);
CASE(alg_kind);
CASE(fusion_info_key);
CASE(axis_row);
3 changes: 2 additions & 1 deletion src/graph/backend/dnnl/internal_ops.hpp
@@ -79,7 +79,8 @@ namespace op_kind {
X(dnnl_convtranspose_bwd_weights, Dnnl_convtranspose_bwd_weights) \
X(dnnl_groupnorm, Dnnl_groupnorm) \
X(dnnl_gen_index, Dnnl_gen_index) \
X(dnnl_mask, Dnnl_mask)
X(dnnl_mask, Dnnl_mask) \
X(dnnl_sdpa, Dnnl_sdpa)

enum kind_t {
kDNNL_INTERNAL_OP_STARTER = 0x1234,
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/kernels/large_partition.cpp
@@ -142,7 +142,7 @@ void larger_partition_kernel_t::setup_pipeline_stage2(pass_pipeline_t &pipeline,
}
BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_src_transpose_to_matmul);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
BACKEND_DNNL_ADD_PASS(pipeline, common_reorder_elimination);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_adjacent_reorders);
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/kernels/matmul.cpp
@@ -110,7 +110,7 @@ status_t matmul_t<quantized>::compile_impl(const dnnl_partition_impl_t *part,
}

BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);

BACKEND_DNNL_ADD_PASS(pipeline, fuse_adjacent_reorders);
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/kernels/mqa_decomp.cpp
@@ -87,7 +87,7 @@ status_t mqa_decomp_kernel_t<quantized, dt>::compile_impl(
BACKEND_DNNL_ADD_PASS(pipeline, remove_quant_data_with_no_effect);
}
pipeline.reset_visualize_arg(true, false);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);

// Run the added passes
13 changes: 11 additions & 2 deletions src/graph/backend/dnnl/kernels/sdp.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
* Copyright 2024-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@
#include "graph/backend/dnnl/kernels/large_partition.hpp"
#include "graph/backend/dnnl/kernels/sdp_decomp.hpp"
#include "graph/backend/dnnl/kernels/sdp_primitive.hpp"
#include "graph/backend/dnnl/kernels/sdp_primitive_v1.hpp"

#include "graph/backend/dnnl/dnnl_partition_impl.hpp"

@@ -65,7 +66,15 @@ struct sdp_base_t : public kernel_base_t {

status_t ret = status::unimplemented;

if (enable_ukernel) {
// SDPA ukernel v1 with the fused internal sdpa solution. Supports floating-point
// sdpa only.
// TODO(GX): Support quantized sdpa and merge with sdp_primitive_kernel_t.
if (enable_ukernel && !quantized) {
kernel = std::make_shared<sdp_primitive_v1_kernel_t>();
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
}

if (ret != status::success && enable_ukernel) {
kernel = std::make_shared<sdp_primitive_kernel_t<quantized>>();
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
}
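The net effect of the hunk above is a priority chain in the compile path of sdp_base_t: the fused v1 kernel is tried first, and any status other than success falls through to the existing paths. A condensed view (the first two entries come from this hunk; the later fallbacks are assumed from the headers already included in this file and are not shown here):

    // Dispatch order in sdp_base_t's compile path (sketch, not verbatim):
    //   1. sdp_primitive_v1_kernel_t          if enable_ukernel && !quantized
    //   2. sdp_primitive_kernel_t<quantized>  if enable_ukernel and (1) did not succeed
    //   3. pre-existing fallbacks (sdp_decomp / large-partition kernels), unchanged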
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/kernels/sdp_decomp.cpp
@@ -86,7 +86,7 @@ status_t sdp_decomp_kernel_t<quantized, dt>::compile_impl(
BACKEND_DNNL_ADD_PASS(pipeline, remove_quant_data_with_no_effect);
}
pipeline.reset_visualize_arg(true, false);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);

// Run the added passes
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/kernels/sdp_primitive.cpp
@@ -92,7 +92,7 @@ status_t sdp_primitive_kernel_t<quantized>::compile_impl(

pipeline.reset_visualize_arg(true, false);
BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);

// bind the memory for each op
20 changes: 19 additions & 1 deletion src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
@@ -166,11 +166,29 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,

status_t sdp_primitive_config_t::initial_check(
const std::shared_ptr<subgraph_t> &sg,
const std::vector<logical_tensor_t> &inputs) {
const std::vector<logical_tensor_t> &inputs, bool v1_kernel) {
// At least 3 inputs: Q, K, V
VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
"At least 3 inputs are required");

// The ukernel doesn't support the f32 data type yet
VCHECK_SDP_PRIMITIVE(inputs[0].data_type != dnnl_data_type_t::dnnl_f32,
status::invalid_arguments,
"SDPA ukernel doesn't support the f32 data type yet");

// Note: the sdpa_primitive_v1 kernel currently doesn't support the legacy GQA pattern.
if (v1_kernel) {
for (auto &cur_op : sg->get_ops()) {
if (cur_op->get_kind() == graph::op_kind::StaticReshape) {
auto in = cur_op->get_input_value(0)->get_logical_tensor();
auto out = cur_op->get_output_value(0)->get_logical_tensor();
if (ltw(in).ndims() == 5 || ltw(out).ndims() == 5) {
return status::unimplemented;
}
}
}
}

// step1 (pattern check): sdpa variants with select as mask are not supported.
// We already have a pattern matcher to ensure that the sdpa patterns
// dispatched here are known ones, and we have a quant check in the sdpa base
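The new v1_kernel argument is presumably passed from the v1 kernel's compile path; a hypothetical call site (not part of the shown hunks, and the cfg_ member name is an assumption):

    // Hypothetical: the v1 kernel opts into the stricter check, while the existing
    // sdp_primitive_kernel_t keeps the default (v1_kernel = false).
    CHECK(cfg_.initial_check(subgraph, inputs, /*v1_kernel=*/true));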
3 changes: 2 additions & 1 deletion src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp
@@ -82,7 +82,8 @@ struct sdp_primitive_config_t {
// 2. only supports the fp16 data type
// 3. only supports 4D tensors
status_t initial_check(const std::shared_ptr<subgraph_t> &sg,
const std::vector<logical_tensor_t> &inputs);
const std::vector<logical_tensor_t> &inputs,
bool v1_kernel = false);

// Initialize parameters and primitive.
status_t init(std::shared_ptr<subgraph_t> &sg, const dnnl::engine &p_engine,