
Commit b5d090f

graph: dnnl: add internal sdpa op

1 parent dac23cd · commit b5d090f

10 files changed: +331 −1 lines changed

src/graph/backend/dnnl/dnnl_op_def.hpp (+25)

@@ -1134,6 +1134,31 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_mask, 1,
         .SET_EXECUTABLE_CREATOR(executable_creator<memory_reparser_t>)
         .SET_ARG_INDICES_GETTER(memory_reparser_t))
 
+DNNL_GRAPH_OP_SCHEMA(dnnl_sdpa, 1,
+        op_schema_t()
+                .set_inputs_option(op_schema_t::param_num_option::variadic)
+                .set_num_inputs(std::set<size_t>({3, 32}))
+                .set_num_outputs(2)
+                .set_input(0, "query")
+                .set_input(1, "key")
+                .set_input(2, "value")
+                .set_input(3, "scale") // optional
+                .set_input(4, "mask") // optional
+                .set_output(0, "output")
+                .set_output(1, "scratchpad")
+                .set_attr(op_attr::with_scale, true, attribute_kind::b)
+                .set_attr(op_attr::is_invert_scale, false, attribute_kind::b,
+                        false)
+                .set_attr(op_attr::with_mask, true, attribute_kind::b)
+                // the with_causal attribute supports the top-left mask type only
+                .set_attr(op_attr::with_causal, true, attribute_kind::b)
+                .set_attr(op_attr::fusion_info_key, false, attribute_kind::i,
+                        (int64_t)-1)
+                .set_shape_inference_function(infer_dnnl_sdpa_output_shape)
+                .SET_LAYOUT_PROPAGATOR(layout_propagator_for_sdpa)
+                .SET_EXECUTABLE_CREATOR(executable_creator<sdpa_executable_t>)
+                .SET_ARG_INDICES_GETTER(sdpa_executable_t))
+
 } // namespace dnnl_impl
 } // namespace graph
 } // namespace impl
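For orientation: the schema encodes plain scaled dot-product attention, out = softmax(scale * Q x K + mask) * V, where the key is supplied pre-transposed (see the shape comments in dnnl_shape_infer.cpp below). The sketch that follows is a minimal plain-C++ reference of those semantics, written for this page only; the function, its dense row-major layout assumption, and the requirement that out be pre-sized to mb*heads*sq*dv are illustrative, not backend code (the explicit mask input is omitted for brevity):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <limits>
    #include <vector>

    // Q: [mb, heads, sq, dk], K: [mb, heads, dk, skv] (pre-transposed, as in
    // the schema comments), V: [mb, heads, skv, dv] -> out: [mb, heads, sq, dv]
    void sdpa_reference(const std::vector<float> &Q, const std::vector<float> &K,
            const std::vector<float> &V, std::vector<float> &out, int64_t mb,
            int64_t heads, int64_t sq, int64_t skv, int64_t dk, int64_t dv,
            float scale, bool is_invert_scale, bool with_causal) {
        const float s = is_invert_scale ? 1.f / scale : scale;
        for (int64_t b = 0; b < mb * heads; ++b) {
            const float *q = &Q[b * sq * dk];
            const float *k = &K[b * dk * skv];
            const float *v = &V[b * skv * dv];
            float *o = &out[b * sq * dv];
            for (int64_t i = 0; i < sq; ++i) {
                std::vector<float> row(skv);
                for (int64_t j = 0; j < skv; ++j) {
                    float acc = 0.f; // scores = Q x K; K is stored transposed
                    for (int64_t c = 0; c < dk; ++c)
                        acc += q[i * dk + c] * k[c * skv + j];
                    // top-left causal mask: query i attends to keys j <= i only
                    row[j] = (with_causal && j > i)
                            ? -std::numeric_limits<float>::infinity()
                            : acc * s;
                }
                float mx = *std::max_element(row.begin(), row.end());
                float sum = 0.f; // numerically stable softmax over the row
                for (float &x : row) { x = std::exp(x - mx); sum += x; }
                for (int64_t c = 0; c < dv; ++c) {
                    float acc = 0.f;
                    for (int64_t j = 0; j < skv; ++j)
                        acc += (row[j] / sum) * v[j * dv + c];
                    o[i * dv + c] = acc;
                }
            }
        }
    }

Note how is_invert_scale selects between multiplying by scale and by 1/scale, matching the boolean attribute pair declared in the schema above.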

src/graph/backend/dnnl/dnnl_opset.hpp (+1)

@@ -97,6 +97,7 @@ class dnnl_opset_t {
         fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_layernorm, 1)>());
         fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_reorder, 1)>());
         fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_groupnorm, 1)>());
+        fn(get_op_schema<DNNL_GRAPH_OP_SCHEMA_CLASS_NAME(dnnl_sdpa, 1)>());
     }
 };

src/graph/backend/dnnl/dnnl_shape_infer.cpp (+57)

@@ -545,6 +545,63 @@ status_t infer_dnnl_binary_output_shape(op_t *n,
     }
 }
 
+status_t infer_dnnl_sdpa_output_shape(op_t *n,
+        std::vector<logical_tensor_t *> &inputs,
+        std::vector<logical_tensor_t *> &outputs) {
+    // [batch_size, num_heads_q, seq_len_q, head_size_qk]
+    auto query = logical_tensor_wrapper_t(inputs[0]);
+    // [batch_size, num_heads_q, head_size_qk, seq_len_kv]
+    auto key = logical_tensor_wrapper_t(inputs[1]);
+    // [batch_size, num_heads_v, seq_len_kv, head_size_v]
+    auto value = logical_tensor_wrapper_t(inputs[2]);
+    // [batch_size, num_heads_q, seq_len_q, head_size_v]
+    auto out0 = logical_tensor_wrapper_t(outputs[0]);
+
+    dims query_dims = query.vdims();
+    dims key_dims = key.vdims();
+    dims value_dims = value.vdims();
+
+    VCHECK_INVALID_SHAPE((query_dims.size() == key_dims.size()
+                                 && key_dims.size() == value_dims.size()),
+            "%s, the ranks of all inputs should match. input0 dims: %s, "
+            "input1 dims: %s, input2 dims: %s",
+            op_t::kind2str(n->get_kind()).c_str(), dims2str(query_dims).c_str(),
+            dims2str(key_dims).c_str(), dims2str(value_dims).c_str());
+
+    VCHECK_INVALID_SHAPE((query_dims.size() == 4),
+            "%s, only 4D q/k/v inputs are supported. input0 ndims: %s, "
+            "input1 ndims: %s, input2 ndims: %s",
+            op_t::kind2str(n->get_kind()).c_str(),
+            std::to_string(query_dims.size()).c_str(),
+            std::to_string(key_dims.size()).c_str(),
+            std::to_string(value_dims.size()).c_str());
+
+    VCHECK_INVALID_SHAPE((query_dims[3] == key_dims[2]),
+            "%s, query head size should match key head size. query dims: %s, "
+            "key dims: %s",
+            op_t::kind2str(n->get_kind()).c_str(), dims2str(query_dims).c_str(),
+            dims2str(key_dims).c_str());
+
+    VCHECK_INVALID_SHAPE((key_dims[3] == value_dims[2]),
+            "%s, key sequence length should match value sequence length. "
+            "key dims: %s, value dims: %s",
+            op_t::kind2str(n->get_kind()).c_str(), dims2str(key_dims).c_str(),
+            dims2str(value_dims).c_str());
+
+    dims inferred_output_shape;
+    inferred_output_shape
+            = {query_dims[0], query_dims[1], query_dims[2], value_dims[3]};
+
+    if (out0.ndims() != -1) {
+        VCHECK_INVALID_SHAPE(validate(inferred_output_shape, out0.vdims()),
+                "%s, inferred output shape and given output shape are not "
+                "compatible",
+                op_t::kind2str(n->get_kind()).c_str());
+    }
+
+    set_shape_and_strides(*outputs[0], inferred_output_shape);
+    return status::success;
+}
+
 } // namespace dnnl_impl
 } // namespace graph
 } // namespace impl
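A concrete instance of the rule above, with shapes invented for illustration:

    query [2, 16, 384, 64]  // [batch, num_heads_q, seq_len_q, head_size_qk]
    key   [2, 16, 64, 512]  // [batch, num_heads_q, head_size_qk, seq_len_kv]
    value [2, 16, 512, 64]  // [batch, num_heads_v, seq_len_kv, head_size_v]

Here query_dims[3] == key_dims[2] == 64 and key_dims[3] == value_dims[2] == 512, so all checks pass, and the inferred output shape is {query_dims[0], query_dims[1], query_dims[2], value_dims[3]} = [2, 16, 384, 64].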

src/graph/backend/dnnl/dnnl_shape_infer.hpp (+4)

@@ -107,6 +107,10 @@ status_t infer_binary_select_output_shape(op_t *n,
         std::vector<logical_tensor_t *> &inputs,
         std::vector<logical_tensor_t *> &outputs);
 
+status_t infer_dnnl_sdpa_output_shape(op_t *n,
+        std::vector<logical_tensor_t *> &inputs,
+        std::vector<logical_tensor_t *> &outputs);
+
 } // namespace dnnl_impl
 } // namespace graph
 } // namespace impl

src/graph/backend/dnnl/internal_attrs.hpp (+8)

@@ -45,6 +45,10 @@ const op_attr_t with_runtime_dst_zps = 0x1000c;
 const op_attr_t is_bias_add = 0x1000d;
 const op_attr_t with_sum = 0x1000e;
 const op_attr_t keep_dst_layout = 0x1000f;
+const op_attr_t with_scale = 0x10010;
+const op_attr_t is_invert_scale = 0x10011;
+const op_attr_t with_causal = 0x10012;
+const op_attr_t with_mask = 0x10013;
 
 // int64_t
 const op_attr_t alg_kind = 0x10100;

@@ -86,6 +90,10 @@ static inline std::string internal_attr2str(op_attr_t attr) {
         CASE(is_bias_add);
         CASE(with_sum);
         CASE(keep_dst_layout);
+        CASE(with_scale);
+        CASE(is_invert_scale);
+        CASE(with_causal);
+        CASE(with_mask);
         CASE(alg_kind);
         CASE(fusion_info_key);
         CASE(axis_row);

src/graph/backend/dnnl/internal_ops.hpp (+2 −1)

@@ -79,7 +79,8 @@ namespace op_kind {
     X(dnnl_convtranspose_bwd_weights, Dnnl_convtranspose_bwd_weights) \
     X(dnnl_groupnorm, Dnnl_groupnorm) \
     X(dnnl_gen_index, Dnnl_gen_index) \
-    X(dnnl_mask, Dnnl_mask)
+    X(dnnl_mask, Dnnl_mask) \
+    X(dnnl_sdpa, Dnnl_sdpa)
 
 enum kind_t {
     kDNNL_INTERNAL_OP_STARTER = 0x1234,
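The list above is an X-macro: every internal op appears exactly once, and the kind_t enum below it (plus the kind-to-string mapping elsewhere in the backend) is generated from the same list. A standalone sketch of the idiom, with simplified names (MY_INTERNAL_OPS, my_kind_t, and my_kind2str are illustrative, not the backend's actual macros):

    #include <string>

    // one list, expanded twice with different definitions of X
    #define MY_INTERNAL_OPS \
        X(dnnl_mask, Dnnl_mask) \
        X(dnnl_sdpa, Dnnl_sdpa)

    enum my_kind_t {
        kSTARTER = 0x1234,
    #define X(op_name, enum_name) k##enum_name,
        MY_INTERNAL_OPS
    #undef X
    };

    inline std::string my_kind2str(my_kind_t kind) {
        switch (kind) {
    #define X(op_name, enum_name) \
        case k##enum_name: return #op_name;
            MY_INTERNAL_OPS
    #undef X
            default: return "unknown";
        }
    }

Adding dnnl_sdpa to the one list is therefore enough to extend both the enum and the string table consistently.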

src/graph/backend/dnnl/layout_propagator.cpp (+29)

@@ -1568,6 +1568,35 @@ status_t layout_propagator_for_mask(std::shared_ptr<op_t> &op,
     return status;
 }
 
+status_t layout_propagator_for_sdpa(std::shared_ptr<op_t> &op,
+        const dnnl::engine &p_engine, fusion_info_mgr_t &mgr,
+        pd_cache_t &pd_cache, subgraph_rewriter_t &rewriter) {
+    UNUSED(p_engine);
+    UNUSED(mgr);
+    UNUSED(pd_cache);
+    UNUSED(rewriter);
+
+    value_ptr dst_val = op->get_output_value(0);
+    const logical_tensor_t &out_lt = dst_val->get_logical_tensor();
+
+    dnnl::memory::desc expected_md;
+    // set the default output layout format for sdpa to acbd
+    if (ltw(out_lt).is_any()) {
+        expected_md = {ltw(out_lt).vdims(),
+                static_cast<dnnl::memory::data_type>(ltw(out_lt).data_type()),
+                dnnl::memory::format_tag::acbd};
+    } else {
+        expected_md = make_dnnl_memory_desc(out_lt);
+    }
+    status_t status = fill_layout_info(dst_val, expected_md);
+
+    // fill the scratchpad dimensions and data type into the scratchpad value_t
+    value_ptr scratchpad_val = op->get_output_value(1);
+    const memory::desc scratchpad_desc;
+    status = fill_layout_info(scratchpad_val, scratchpad_desc);
+    return status;
+}
+
 } // namespace dnnl_impl
 } // namespace graph
 } // namespace impl
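About the acbd default: for the 4D sdpa output, dimension a is batch, b is heads, c is sequence length, and d is head size, so acbd keeps all heads of one token adjacent in memory (batch, seq, heads, head_size physical order). A small probe of the resulting strides, assuming the oneDNN v3.x C++ API (memory::desc::get_strides()) and a made-up output shape:

    #include <iostream>
    #include "dnnl.hpp"

    int main() {
        using namespace dnnl;
        // hypothetical sdpa output: batch=2, heads=16, seq_len=384, head_size=64
        memory::desc md({2, 16, 384, 64}, memory::data_type::f32,
                memory::format_tag::acbd);
        for (auto s : md.get_strides())
            std::cout << s << ' '; // prints "393216 64 1024 1"
        // stride(b) = 64 and stride(c) = 1024: heads are interleaved per token
        return 0;
    }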

src/graph/backend/dnnl/layout_propagator.hpp (+1)

@@ -93,6 +93,7 @@ DECLARE_LAYOUT_PROPAGATOR(add_zps);
 DECLARE_LAYOUT_PROPAGATOR(groupnorm);
 DECLARE_LAYOUT_PROPAGATOR(gen_index);
 DECLARE_LAYOUT_PROPAGATOR(mask);
+DECLARE_LAYOUT_PROPAGATOR(sdpa);
 
 #undef DECLARE_LAYOUT_PROPAGATOR
src/graph/backend/dnnl/op_executable.cpp (+23)

@@ -2405,6 +2405,29 @@ arg_indices_t genindex_executable_t::get_arg_indices(
     return arg_indices;
 }
 
+arg_indices_t sdpa_executable_t::get_arg_indices(
+        const op_t *op, fusion_info_mgr_t &mgr) {
+    UNUSED(mgr);
+
+    arg_indices_t arg_indices;
+    // add input args
+    size_t index = 0;
+    arg_indices.insert({DNNL_ARG_QUERIES, indices_t {input, index++}});
+    arg_indices.insert({DNNL_ARG_KEYS, indices_t {input, index++}});
+    arg_indices.insert({DNNL_ARG_VALUES, indices_t {input, index++}});
+    if (op->get_attr<bool>(dnnl::impl::graph::dnnl_impl::op_attr::with_scale)) {
+        arg_indices.insert({DNNL_ARG_SCALE, indices_t {input, index++}});
+    }
+    if (op->get_attr<bool>(dnnl::impl::graph::dnnl_impl::op_attr::with_mask)) {
+        arg_indices.insert({DNNL_ARG_ATTN_MASK, indices_t {input, index++}});
+    }
+
+    // add output args
+    arg_indices.insert({DNNL_ARG_DST, indices_t {output, 0}});
+    arg_indices.insert({DNNL_ARG_SCRATCHPAD, indices_t {output, 1}});
+    return arg_indices;
+}
+
 } // namespace dnnl_impl
 } // namespace graph
 } // namespace impl
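Because the optional inputs are packed densely after q/k/v, the runtime argument mapping depends on the attribute values. For a dnnl_sdpa op with with_scale = true and with_mask = true:

    input 0 -> DNNL_ARG_QUERIES      input 3 -> DNNL_ARG_SCALE
    input 1 -> DNNL_ARG_KEYS         input 4 -> DNNL_ARG_ATTN_MASK
    input 2 -> DNNL_ARG_VALUES
    output 0 -> DNNL_ARG_DST         output 1 -> DNNL_ARG_SCRATCHPAD

With with_scale = false, the mask (if present) maps to input 3 instead, since index++ only advances for inputs that actually exist.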
