Commit e2cd4db

graph: dnnl: code polish and address left todo and skip legacy GQA
1 parent 35098b4 commit e2cd4db

13 files changed: +129 -36 lines

src/graph/backend/dnnl/dnnl_shape_infer.cpp

+37-3
@@ -545,13 +545,18 @@ status_t infer_dnnl_binary_output_shape(op_t *n,
         }
     }
 
-//TODO(GX): revisit this function to correct logic, check if shape is given
 status_t infer_dnnl_sdpa_output_shape(op_t *n,
         std::vector<logical_tensor_t *> &inputs,
         std::vector<logical_tensor_t *> &outputs) {
+    // [batch_size, num_heads_q, seq_len_q, head_size_qk]
     auto query = logical_tensor_wrapper_t(inputs[0]);
+    // [batch_size, num_heads_q, head_size_qk, seq_len_kv,]
     auto key = logical_tensor_wrapper_t(inputs[1]);
-    auto value = logical_tensor_wrapper_t(inputs[1]);
+    // [batch_size, num_heads_v, seq_len_kv, head_size_v]
+    auto value = logical_tensor_wrapper_t(inputs[2]);
+    // [batch_size, num_heads_q, seq_len_q, head_size_v]
+    auto out0 = logical_tensor_wrapper_t(outputs[0]);
+
     dims query_dims = query.vdims();
     dims key_dims = key.vdims();
     dims value_dims = value.vdims();
@@ -563,7 +568,36 @@ status_t infer_dnnl_sdpa_output_shape(op_t *n,
             op_t::kind2str(n->get_kind()).c_str(), dims2str(query_dims).c_str(),
             dims2str(key_dims).c_str(), dims2str(value_dims).c_str());
 
-    dims inferred_output_shape = query_dims;
+    VCHECK_INVALID_SHAPE((query_dims.size() == 4),
+            "%s, only support 4D input for all Q/K/V. input0 dimension: %s, "
+            "input1 dimension: %s, input2 dimension: %s ",
+            op_t::kind2str(n->get_kind()).c_str(),
+            std::to_string(query_dims.size()).c_str(),
+            std::to_string(key_dims.size()).c_str(),
+            std::to_string(value_dims.size()).c_str());
+
+    VCHECK_INVALID_SHAPE((query_dims[3] == key_dims[2]),
+            "%s, query head size should be match with key head size. query "
+            "dims: %s, Key dims: %s",
+            op_t::kind2str(n->get_kind()).c_str(), dims2str(query_dims).c_str(),
+            dims2str(key_dims).c_str());
+
+    VCHECK_INVALID_SHAPE((key_dims[3] == value_dims[2]),
+            "%s, key sequence length should be match with value sequence "
+            "length. key dims: %s, value dims: %s ",
+            op_t::kind2str(n->get_kind()).c_str(), dims2str(key_dims).c_str(),
+            dims2str(value_dims).c_str());
+
+    dims inferred_output_shape;
+    inferred_output_shape
+            = {query_dims[0], query_dims[1], query_dims[2], value_dims[3]};
+
+    if (out0.ndims() != -1) {
+        VCHECK_INVALID_SHAPE(validate(inferred_output_shape, out0.vdims()),
+                "%s, inferred out shape and output shape are not compatible",
+                op_t::kind2str(n->get_kind()).c_str());
+    }
+
     set_shape_and_strides(*outputs[0], inferred_output_shape);
     return status::success;
 }
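
The change above resolves the old TODO: it fixes the value input index (inputs[2] instead of inputs[1]), validates the given shapes, and makes the output rule explicit. With Q as [batch, heads_q, seq_q, head_qk], a pre-transposed K as [batch, heads_q, head_qk, seq_kv], and V as [batch, heads_v, seq_kv, head_v], the output is [batch, heads_q, seq_q, head_v]. A minimal standalone sketch of the same rule, with illustrative helper names rather than the oneDNN graph API:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    using dims = std::vector<std::int64_t>;

    // Infers the SDPA output shape from Q [B, Hq, Sq, Dqk], K [B, Hq, Dqk, Skv]
    // (K is already transposed) and V [B, Hv, Skv, Dv]; returns an empty vector
    // when the inputs are inconsistent.
    dims infer_sdpa_out_shape(const dims &q, const dims &k, const dims &v) {
        if (q.size() != 4 || k.size() != 4 || v.size() != 4) return {};
        if (q[3] != k[2]) return {}; // Q head size must match K head size
        if (k[3] != v[2]) return {}; // K seq_len_kv must match V seq_len_kv
        return {q[0], q[1], q[2], v[3]};
    }

    int main() {
        dims q {2, 16, 384, 64}, k {2, 16, 64, 512}, v {2, 16, 512, 64};
        dims out = infer_sdpa_out_shape(q, k, v);
        assert((out == dims {2, 16, 384, 64})); // [B, Hq, Sq, Dv]
        return out.empty() ? 1 : 0;
    }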

src/graph/backend/dnnl/kernels/large_partition.cpp

+1-1
@@ -142,7 +142,7 @@ void larger_partition_kernel_t::setup_pipeline_stage2(pass_pipeline_t &pipeline,
     }
     BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_src_transpose_to_matmul);
-    BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
+    BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
     BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
     BACKEND_DNNL_ADD_PASS(pipeline, common_reorder_elimination);
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_adjacent_reorders);

src/graph/backend/dnnl/kernels/matmul.cpp

+1-1
@@ -110,7 +110,7 @@ status_t matmul_t<quantized>::compile_impl(const dnnl_partition_impl_t *part,
     }
 
     BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
-    BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
+    BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
    BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
 
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_adjacent_reorders);

src/graph/backend/dnnl/kernels/mqa_decomp.cpp

+1-1
@@ -87,7 +87,7 @@ status_t mqa_decomp_kernel_t<quantized, dt>::compile_impl(
         BACKEND_DNNL_ADD_PASS(pipeline, remove_quant_data_with_no_effect);
     }
     pipeline.reset_visualize_arg(true, false);
-    BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
+    BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
     BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
 
     // Run the added passes

src/graph/backend/dnnl/kernels/sdp_decomp.cpp

+1-1
@@ -86,7 +86,7 @@ status_t sdp_decomp_kernel_t<quantized, dt>::compile_impl(
         BACKEND_DNNL_ADD_PASS(pipeline, remove_quant_data_with_no_effect);
     }
     pipeline.reset_visualize_arg(true, false);
-    BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
+    BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
     BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
 
     // Run the added passes

src/graph/backend/dnnl/kernels/sdp_primitive.cpp

+1-1
@@ -92,7 +92,7 @@ status_t sdp_primitive_kernel_t<quantized>::compile_impl(
 
     pipeline.reset_visualize_arg(true, false);
     BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
-    BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
+    BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
     BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
 
     // bind the memory for each op

src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp

+22-5
@@ -166,11 +166,28 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
 
 status_t sdp_primitive_config_t::initial_check(
         const std::shared_ptr<subgraph_t> &sg,
-        const std::vector<logical_tensor_t> &inputs) {
+        const std::vector<logical_tensor_t> &inputs, bool v1_kenrel) {
     // At least 3 inputs: Q, K, V
     VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
             "At least 3 inputs are required");
 
+    VCHECK_SDP_PRIMITIVE(inputs[0].data_type != dnnl_data_type_t::dnnl_f32,
+            status::invalid_arguments,
+            "SDPA ukernel doesn't support f32 datatype now");
+
+    // Note: sdpa_primitive_v1 kenrel currently don't support legacy GQA pattern.
+    if (v1_kenrel) {
+        for (auto &cur_op : sg->get_ops()) {
+            if (cur_op->get_kind() == graph::op_kind::StaticReshape) {
+                auto in = cur_op->get_input_value(0)->get_logical_tensor();
+                auto out = cur_op->get_output_value(0)->get_logical_tensor();
+                if (ltw(in).ndims() == 5 || ltw(out).ndims() == 5) {
+                    return status::unimplemented;
+                }
+            }
+        }
+    }
+
     // step1(pattern check): Not support sdpa variants with select as mask
     // We already have a pattern matcher to ensure that the sdpa patterns
     // dispatch to here are knows ones, and we have quant check in sdpa base
@@ -268,10 +285,10 @@ status_t sdp_primitive_config_t::initial_check(
 
     VCHECK_SDP_PRIMITIVE(q_id != -1 && k_id != -1 && v_id != -1,
             status::unimplemented, "Q, K, V are not found");
-    VCHECK_SDP_PRIMITIVE(ltw(inputs[q_id]).vdims().size() == 4
-            && ltw(inputs[k_id]).vdims().size() == 4
-            && ltw(inputs[v_id]).vdims().size() == 4,
-            status::unimplemented, "Q, K, V should be 4-dims");
+    // VCHECK_SDP_PRIMITIVE(ltw(inputs[q_id]).vdims().size() == 4
+    //         && ltw(inputs[k_id]).vdims().size() == 4
+    //         && ltw(inputs[v_id]).vdims().size() == 4,
+    //         status::unimplemented, "Q, K, V should be 4-dims");
 
     // sdp_primitive only supports single scale value.
     if (scale) {
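
The new v1_kenrel flag lets the v1 kernel skip the legacy GQA pattern, which is expressed through 4-D to 5-D StaticReshape ops around the attention block: if any StaticReshape in the subgraph touches a 5-D tensor, initial_check returns status::unimplemented and this kernel bails out of compilation. A rough sketch of that scan with hypothetical stand-in types (not the graph IR):

    #include <vector>

    // Stand-ins for the graph IR; only the rank information the check needs.
    enum class op_kind_t { static_reshape, other };

    struct sketch_op_t {
        op_kind_t kind;
        int in_ndims;
        int out_ndims;
    };

    // A StaticReshape touching a 5-D tensor marks the legacy (reshape-based)
    // GQA pattern, which the sdp_primitive_v1 kernel does not handle yet.
    bool uses_legacy_gqa(const std::vector<sketch_op_t> &ops) {
        for (const auto &op : ops) {
            if (op.kind != op_kind_t::static_reshape) continue;
            if (op.in_ndims == 5 || op.out_ndims == 5) return true;
        }
        return false;
    }

    int main() {
        std::vector<sketch_op_t> sg {{op_kind_t::static_reshape, 4, 5},
                {op_kind_t::other, 4, 4}};
        return uses_legacy_gqa(sg) ? 0 : 1; // detected -> skip the v1 kernel
    }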

src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp

+2-1
@@ -82,7 +82,8 @@ struct sdp_primitive_config_t {
     // 2. only support fp16 data type
     // 3. only support 4-dims tensor
     status_t initial_check(const std::shared_ptr<subgraph_t> &sg,
-            const std::vector<logical_tensor_t> &inputs);
+            const std::vector<logical_tensor_t> &inputs,
+            bool v1_kenrel = false);
 
     // Initialize parameters and primitive.
     status_t init(std::shared_ptr<subgraph_t> &sg, const dnnl::engine &p_engine,
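
Because the new parameter defaults to false, pre-existing call sites of initial_check keep compiling unchanged, while sdp_primitive_v1.cpp (below) opts in by passing true. A simplified illustration of that call pattern with placeholder types, not the real kernel classes:

    #include <iostream>

    // Simplified config type; the real initial_check also receives the
    // subgraph and the input logical tensors.
    struct config_sketch_t {
        const char *initial_check(bool v1_kernel = false) const {
            // The extra legacy-GQA scan only runs when the v1 kernel opts in.
            return v1_kernel ? "base checks + legacy-GQA skip" : "base checks";
        }
    };

    int main() {
        config_sketch_t cfg;
        std::cout << cfg.initial_check() << "\n";     // existing kernels
        std::cout << cfg.initial_check(true) << "\n"; // sdp_primitive_v1
        return 0;
    }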

src/graph/backend/dnnl/kernels/sdp_primitive_v1.cpp

+11-7
@@ -61,7 +61,7 @@ status_t sdp_primitive_v1_kernel_t<quantized>::compile_impl(
             p_engine_, part->get_fpmath_mode(), false, true);
     CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs));
 
-    CHECK(cfg_.initial_check(subgraph_, inputs));
+    CHECK(cfg_.initial_check(subgraph_, inputs, true));
 
     subgraph_visualizer_t vis(part->id(), [this](const value_t *val) {
         return this->memory_planner_.get_memory_info(val);
@@ -76,10 +76,8 @@ status_t sdp_primitive_v1_kernel_t<quantized>::compile_impl(
     BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_src_transpose_to_matmul);
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_sdpa);
+    BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
     BACKEND_DNNL_ADD_PASS(pipeline, insert_reshape_for_sdpa);
-
-    // TODO(GX):add fuse dst transpose to sdpa
-    // BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
     BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
 
     // bind the memory for each op
@@ -145,7 +143,9 @@ status_t sdp_primitive_v1_kernel_t<quantized>::execute_impl(
 
     // Micro kernel doesn't use scratchpad memory, here we force-set size as
     // zero to avoid redundant memory allocation and deallocation.
-    temporary_scratchpad_t scratchpad(memory_planner_.total_internal_temporary_size(), p_engine_, *g_alloc_);
+    temporary_scratchpad_t scratchpad(
+            memory_planner_.total_internal_temporary_size(), p_engine_,
+            *g_alloc_);
     prepare_args_set(res, inputs, outputs, scratchpad);
 
     for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
@@ -177,7 +177,9 @@ status_t sdp_primitive_v1_kernel_t<quantized>::sycl_execute_impl(
 
     // Micro kernel doesn't use scratchpad memory, here we force-set size as
     // zero to avoid redundant memory allocation and deallocation.
-    temporary_scratchpad_t scratchpad(memory_planner_.total_internal_temporary_size(), p_engine_, *g_alloc_);
+    temporary_scratchpad_t scratchpad(
+            memory_planner_.total_internal_temporary_size(), p_engine_,
+            *g_alloc_);
     prepare_args_set(res, inputs, outputs, scratchpad);
 
     for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
@@ -215,7 +217,9 @@ status_t sdp_primitive_v1_kernel_t<quantized>::ocl_execute_impl(
 
     // Micro kernel doesn't use scratchpad memory, here we force-set size as
    // zero to avoid redundant memory allocation and deallocation.
-    temporary_scratchpad_t scratchpad(memory_planner_.total_internal_temporary_size(), p_engine_, *g_alloc_);
+    temporary_scratchpad_t scratchpad(
+            memory_planner_.total_internal_temporary_size(), p_engine_,
+            *g_alloc_);
     prepare_args_set(res, inputs, outputs, scratchpad);
 
     for (size_t i = 0; i < subgraph_->execs_.size(); i++) {

src/graph/backend/dnnl/layout_propagator.cpp

+32-4
@@ -1576,11 +1576,39 @@ status_t layout_propagator_for_sdpa(std::shared_ptr<op_t> &op,
     UNUSED(mgr);
     UNUSED(pd_cache);
     UNUSED(rewriter);
-    auto dst_md = make_dnnl_memory_desc(
-            op->get_output_value(0)->get_logical_tensor());
+
     value_ptr dst_val = op->get_output_value(0);
-    dst_val->set_strides(get_dense_strides(dst_md.get_dims()));
-    status_t status = fill_layout_info(dst_val, dst_md);
+    const logical_tensor_t &out_lt = dst_val->get_logical_tensor();
+    dnnl::memory::desc expected_md;
+
+    if (ltw(out_lt).is_any()) {
+        // For GQA, we need to check the layout of the dnnl_reshape output
+        // following dnnl_sdpa, which is given by the user.
+        if (!dst_val->get_consumers().empty()) {
+            const auto &consumer_op = dst_val->get_consumers()[0].get_op();
+            const auto &consumer_out = ltw(
+                    consumer_op.get_output_value(0)->get_logical_tensor());
+            if (consumer_op.get_kind() == op_kind::dnnl_reshape
+                    && consumer_out.ndims() == 5 && consumer_out.is_strided()) {
+                const auto &ori_strides = consumer_out.vstrides();
+                std::vector<dim_t> strides = {ori_strides[0], ori_strides[2],
+                        ori_strides[3], ori_strides[4]};
+                dnnl::memory::desc tmp_md {ltw(out_lt).vdims(),
+                        static_cast<dnnl::memory::data_type>(
+                                ltw(out_lt).data_type()),
+                        strides};
+                expected_md = tmp_md;
+            }
+        } else {
+            dnnl::memory::desc expected_md {ltw(out_lt).vdims(),
+                    static_cast<dnnl::memory::data_type>(
+                            ltw(out_lt).data_type()),
+                    dnnl::memory::format_tag::acbd};
+        }
+    } else {
+        expected_md = make_dnnl_memory_desc(out_lt);
+    }
+    status_t status = fill_layout_info(dst_val, expected_md);
 
     // fill scratchpads dimensions and data type to scratchpad value_t
     value_ptr scratchpad_val = op->get_output_value(1);
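
When the SDPA destination layout is "any" and a user-strided 5-D dnnl_reshape follows it (the GQA case), the pass above derives 4-D strides from the 5-D layout by keeping {s0, s2, s3, s4}, which amounts to merging the two head axes of the 5-D view so the reshape becomes a metadata-only change. A small self-contained illustration of that stride folding, assuming the 5-D view splits the head axis into two leading head axes (plain vectors, not the dnnl memory descriptor API):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    using dim_t = std::int64_t;

    // Folds strides of a 5-D view [B, G, H/G, S, D] into strides of the 4-D
    // SDPA output [B, H, S, D]: the merged head axis reuses the inner head
    // stride, so only {s0, s2, s3, s4} survive.
    std::vector<dim_t> fold_gqa_strides(const std::vector<dim_t> &s5) {
        assert(s5.size() == 5);
        return {s5[0], s5[2], s5[3], s5[4]};
    }

    int main() {
        // Dense 5-D tensor [2, 4, 8, 128, 64].
        std::vector<dim_t> s5 {4 * 8 * 128 * 64, 8 * 128 * 64, 128 * 64, 64, 1};
        // The merged 4-D view [2, 32, 128, 64] keeps the same element layout.
        std::vector<dim_t> s4 = fold_gqa_strides(s5);
        assert((s4 == std::vector<dim_t> {32 * 128 * 64, 128 * 64, 64, 1}));
        return 0;
    }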

src/graph/backend/dnnl/passes/insert_ops.cpp

+1
@@ -662,6 +662,7 @@ status_t insert_reshape_for_sdpa(std::shared_ptr<subgraph_t> &sg) {
         reshape_output->set_attr<bool>(op_attr::special_zero, false);
         reshape_output->set_attr<std::vector<int64_t>>(
                 op_attr::shape, expected_output_dims);
+
         rewriter.insert_op_after(reshape_output, cur_op, 0);
     }
     rewriter.run();

src/graph/backend/dnnl/passes/transform.cpp

+17-10
@@ -3835,13 +3835,16 @@ impl::status_t fuse_src_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg) {
     return impl::status::success;
 }
 
-impl::status_t fuse_dst_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg) {
+impl::status_t fuse_dst_transpose_to_predecessor(
+        std::shared_ptr<subgraph_t> &sg) {
     std::vector<op_ptr> transpose_ops;
     for (auto &cur_op : sg->get_ops()) {
         if (cur_op->get_kind() == op_kind::dnnl_transpose
                 && cur_op->get_input_value(0)->has_producer()
-                && cur_op->get_input_value(0)->get_producer().get_kind()
-                        == op_kind::dnnl_matmul
+                && (cur_op->get_input_value(0)->get_producer().get_kind()
+                                == op_kind::dnnl_matmul
+                        || cur_op->get_input_value(0)->get_producer().get_kind()
+                                == op_kind::dnnl_sdpa)
                 && !cur_op->get_output_value(0)->get_consumers().empty()
                 && (cur_op->get_output_value(0)
                             ->get_consumers()[0]
@@ -3894,13 +3897,17 @@ impl::status_t fuse_dst_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg) {
         dnnl::memory::desc expected_out_md = out_md.permute_axes(axes);
         // Special check to avoid low matmul performance with adbc layout.
         // TODO: remove this once the performance is improved.
-        if (get_format_tag(expected_out_md) == dnnl::memory::format_tag::adbc) {
+        if (in_val->get_producer().get_kind() == op_kind::dnnl_matmul
+                && get_format_tag(expected_out_md)
+                        == dnnl::memory::format_tag::adbc) {
             break;
         }
         const auto &strides = expected_out_md.get_strides();
         in_val->set_strides(strides);
-        auto &matmul = transpose_op->get_input_value(0)->get_producer();
-        matmul.set_attr(op_attr::keep_dst_layout, true);
+        if (in_val->get_producer().get_kind() == op_kind::dnnl_matmul) {
+            auto &matmul = in_val->get_producer();
+            matmul.set_attr(op_attr::keep_dst_layout, true);
+        }
     }
     rewriter.run();
     return impl::status::success;
@@ -4182,12 +4189,12 @@ status_t fuse_sdpa(std::shared_ptr<subgraph_t> &sg) {
         switch (walker->get_kind()) {
            case op_kind::dnnl_matmul: {
                if (pattern_ops.size() == 1) {
-                }
+                }
                // Finish pattern match process after second matmul
                else {
                    valid_pattern = (pattern_ops.size() >= 3);
                    finished = true;
-                }
+                }
                break;
            }
            case op_kind::dnnl_binary: {
@@ -4256,8 +4263,8 @@
                auto alg = static_cast<dnnl::algorithm>(
                        op->get_attr<int64_t>(op_attr::alg_kind));
                // handle scale
-                if (alg == dnnl::algorithm::binary_mul ||
-                        alg == dnnl::algorithm::binary_div) {
+                if (alg == dnnl::algorithm::binary_mul
+                        || alg == dnnl::algorithm::binary_div) {
                    auto scale_val = op->get_input_value(1);
                    scale_val->remove_consumer(*op, 1);
                    sdpa_op->connect_input(input_idx++, scale_val);
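
The pass renamed from fuse_dst_transpose_to_matmul to fuse_dst_transpose_to_predecessor now also lets a fused dnnl_sdpa op absorb a following dst transpose: the producer writes its destination with the strides the transposed tensor would have, so the transpose is resolved at layout-propagation time instead of executing as a real reorder, while the adbc performance guard and the keep_dst_layout attribute remain matmul-only. A conceptual sketch of that stride permutation, with plain structs standing in for the memory descriptors:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    using dim_t = std::int64_t;

    // Plain stand-in for a strided memory descriptor.
    struct plain_md_t {
        std::vector<dim_t> dims;
        std::vector<dim_t> strides;
    };

    // Applies a transpose permutation (output axis i reads input axis perm[i])
    // to dims and strides only; no data is moved.
    plain_md_t permute(const plain_md_t &md, const std::vector<int> &perm) {
        plain_md_t out;
        for (int axis : perm) {
            out.dims.push_back(md.dims[axis]);
            out.strides.push_back(md.strides[axis]);
        }
        return out;
    }

    int main() {
        // Producer dst [B, H, S, D] stored densely.
        plain_md_t dst {{2, 16, 384, 64}, {16 * 384 * 64, 384 * 64, 64, 1}};
        // A (0, 2, 1, 3) transpose to [B, S, H, D] only swaps two strides, so
        // the producer can emit this layout directly and the transpose op
        // disappears from the executed subgraph.
        plain_md_t transposed = permute(dst, {0, 2, 1, 3});
        assert((transposed.dims == std::vector<dim_t> {2, 384, 16, 64}));
        assert((transposed.strides
                == std::vector<dim_t> {16 * 384 * 64, 64, 384 * 64, 1}));
        return 0;
    }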

src/graph/backend/dnnl/passes/transform.hpp

+2-1
@@ -206,7 +206,8 @@ impl::status_t fuse_src_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg);
 
 // This pass will compute matmul with the dst layout of following transpose if
 // the operator after transpose need a dense layout
-impl::status_t fuse_dst_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg);
+impl::status_t fuse_dst_transpose_to_predecessor(
+        std::shared_ptr<subgraph_t> &sg);
 
 // This pass will fuse all the reshape to its lead op for GQA.
 impl::status_t fuse_reshape_for_gqa(std::shared_ptr<subgraph_t> &sg);
