Commit 35098b4

graph: dnnl: refine GQA pattern and ukernel support
1 parent 9d91c19

5 files changed: +478 -3 lines

src/graph/backend/dnnl/kernels/sdp_primitive_v1.cpp

+1 -2
```diff
@@ -70,14 +70,13 @@ status_t sdp_primitive_v1_kernel_t<quantized>::compile_impl(
 
     BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_implicit_causal_mask);
-    BACKEND_DNNL_ADD_PASS(pipeline, fuse_reshape_for_gqa);
-    BACKEND_DNNL_ADD_PASS(pipeline, binary_canonicalization);
     BACKEND_DNNL_ADD_PASS(pipeline, insert_permute_for_matmul);
 
     pipeline.reset_visualize_arg(true, false);
     BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_src_transpose_to_matmul);
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_sdpa);
+    BACKEND_DNNL_ADD_PASS(pipeline, insert_reshape_for_sdpa);
 
     // TODO(GX):add fuse dst transpose to sdpa
     // BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
```
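With this reordering, GQA graphs are no longer flattened by the dedicated fuse_reshape_for_gqa pass during lowering (the binary_canonicalization it required goes with it); instead, fuse_sdpa first forms the fused dnnl_sdpa op, and the new insert_reshape_for_sdpa pass then adapts its 5D operands. Since that pass matches on dnnl_sdpa ops, it has to be registered after the pass that creates them. A minimal stand-alone sketch of that registration-order behavior; pipeline_t and pass_fn below are illustrative stand-ins, not the backend's actual pipeline types:

```cpp
#include <functional>
#include <string>
#include <utility>
#include <vector>

struct subgraph_t {}; // stand-in for the backend's subgraph type

using pass_fn = std::function<void(subgraph_t &)>;

// Illustrative pipeline: passes run strictly in registration order.
struct pipeline_t {
    std::vector<std::pair<std::string, pass_fn>> passes_;
    void add_pass(std::string name, pass_fn fn) {
        passes_.emplace_back(std::move(name), std::move(fn));
    }
    void run(subgraph_t &sg) {
        // A pass that rewrites dnnl_sdpa ops only ever sees them if it is
        // registered after the pass that produces them, which is why
        // insert_reshape_for_sdpa sits after fuse_sdpa in the hunk above.
        for (auto &p : passes_)
            p.second(sg);
    }
};
```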

src/graph/backend/dnnl/layout_propagator.cpp

+1
```diff
@@ -1579,6 +1579,7 @@ status_t layout_propagator_for_sdpa(std::shared_ptr<op_t> &op,
     auto dst_md = make_dnnl_memory_desc(
             op->get_output_value(0)->get_logical_tensor());
     value_ptr dst_val = op->get_output_value(0);
+    dst_val->set_strides(get_dense_strides(dst_md.get_dims()));
     status_t status = fill_layout_info(dst_val, dst_md);
 
     // fill scratchpads dimensions and data type to scratchpad value_t
```
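The added set_strides call pins the fused SDPA destination to a plain dense layout before fill_layout_info runs. For reference, row-major dense strides over a dims vector are conventionally computed as in the sketch below; this is the standard convention and an assumption about what get_dense_strides returns, not the library's actual implementation:

```cpp
#include <cstdint>
#include <vector>

using dims = std::vector<int64_t>;

// Row-major dense strides: the innermost dimension is contiguous and each
// outer stride is the product of all inner dimensions.
// Example: dims {2, 32, 128, 64} -> strides {262144, 8192, 64, 1}.
dims dense_strides(const dims &d) {
    dims strides(d.size(), 1);
    for (int i = static_cast<int>(d.size()) - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * d[i + 1];
    return strides;
}
```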

src/graph/backend/dnnl/passes/insert_ops.cpp

+97
```diff
@@ -571,6 +571,103 @@ status_t insert_reshape_for_ndx2d_matmul(std::shared_ptr<subgraph_t> &sg) {
     return infer_shape(sg);
 }
 
+status_t insert_reshape_for_sdpa(std::shared_ptr<subgraph_t> &sg) {
+    subgraph_rewriter_t rewriter(sg);
+
+    for (auto &cur_op : sg->get_ops()) {
+        if (cur_op->get_kind() != op_kind::dnnl_sdpa) continue;
+
+        int32_t query_ndims
+                = cur_op->get_input_value(0)->get_logical_tensor().ndims;
+        if (query_ndims != 5) continue;
+
+        // Insert reshape for Query
+        auto query_dims = logical_tensor_wrapper_t(
+                cur_op->get_input_value(0)->get_logical_tensor())
+                                  .vdims();
+        dims expected_query_dims {
+                query_dims[0], -1, query_dims[3], query_dims[4]};
+        op_ptr reshape_query = std::make_shared<op_t>(op_kind::dnnl_reshape);
+        reshape_query->set_attr<bool>(op_attr::special_zero, false);
+        reshape_query->set_attr<std::vector<int64_t>>(
+                op_attr::shape, expected_query_dims);
+        rewriter.insert_op_before(reshape_query, cur_op, 0);
+
+        // Insert reshape for Key
+        auto key_dims = logical_tensor_wrapper_t(
+                cur_op->get_input_value(1)->get_logical_tensor())
+                                .vdims();
+        dims expected_key_dims {key_dims[0], -1, key_dims[3], key_dims[4]};
+        op_ptr reshape_key = std::make_shared<op_t>(op_kind::dnnl_reshape);
+        reshape_key->set_attr<bool>(op_attr::special_zero, false);
+        reshape_key->set_attr<std::vector<int64_t>>(
+                op_attr::shape, expected_key_dims);
+        rewriter.insert_op_before(reshape_key, cur_op, 1);
+
+        // Insert reshape for value
+        auto value_dims = logical_tensor_wrapper_t(
+                cur_op->get_input_value(2)->get_logical_tensor())
+                                  .vdims();
+        dims expected_value_dims {
+                value_dims[0], -1, value_dims[3], value_dims[4]};
+        op_ptr reshape_value = std::make_shared<op_t>(op_kind::dnnl_reshape);
+        reshape_value->set_attr<bool>(op_attr::special_zero, false);
+        reshape_value->set_attr<std::vector<int64_t>>(
+                op_attr::shape, expected_value_dims);
+        rewriter.insert_op_before(reshape_value, cur_op, 2);
+
+        // Insert reshape for scale
+        if (cur_op->get_attr<bool>(op_attr::with_scale)) {
+            int32_t scale_ndims
+                    = cur_op->get_input_value(3)->get_logical_tensor().ndims;
+            if (scale_ndims == 5) {
+                auto scale_dims = logical_tensor_wrapper_t(
+                        cur_op->get_input_value(3)->get_logical_tensor())
+                                          .vdims();
+                dims expected_scale_dims {
+                        scale_dims[0], -1, scale_dims[3], scale_dims[4]};
+                op_ptr reshape_scale
+                        = std::make_shared<op_t>(op_kind::dnnl_reshape);
+                reshape_scale->set_attr<bool>(op_attr::special_zero, false);
+                reshape_scale->set_attr<std::vector<int64_t>>(
+                        op_attr::shape, expected_scale_dims);
+                rewriter.insert_op_before(reshape_scale, cur_op, 3);
+            }
+        }
+        // Insert reshape for mask
+        if (cur_op->get_attr<bool>(op_attr::with_mask)) {
+            int32_t mask_ndims
+                    = cur_op->get_input_value(4)->get_logical_tensor().ndims;
+            if (mask_ndims == 5) {
+                auto mask_dims = logical_tensor_wrapper_t(
+                        cur_op->get_input_value(4)->get_logical_tensor())
+                                         .vdims();
+                dims expected_mask_dims {
+                        mask_dims[0], -1, mask_dims[3], mask_dims[4]};
+                op_ptr reshape_mask
+                        = std::make_shared<op_t>(op_kind::dnnl_reshape);
+                reshape_mask->set_attr<bool>(op_attr::special_zero, false);
+                reshape_mask->set_attr<std::vector<int64_t>>(
+                        op_attr::shape, expected_mask_dims);
+                rewriter.insert_op_before(reshape_mask, cur_op, 4);
+            }
+        }
+
+        // Insert reshape for output
+        auto output_dims = logical_tensor_wrapper_t(
+                cur_op->get_output_value(0)->get_logical_tensor())
+                                   .vdims();
+        dims expected_output_dims {output_dims};
+        op_ptr reshape_output = std::make_shared<op_t>(op_kind::dnnl_reshape);
+        reshape_output->set_attr<bool>(op_attr::special_zero, false);
+        reshape_output->set_attr<std::vector<int64_t>>(
+                op_attr::shape, expected_output_dims);
+        rewriter.insert_op_after(reshape_output, cur_op, 0);
+    }
+    rewriter.run();
+    return infer_shape(sg);
+}
+
 status_t insert_unsqueeze_and_squeeze_for_matmul(
         std::shared_ptr<subgraph_t> &sg) {
     subgraph_rewriter_t rewriter(sg);
```
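Each inserted dnnl_reshape keeps dims 0, 3, and 4 and folds dims 1 and 2 into a single -1 entry, which the reshape infers from the element count (with special_zero set to false, all other entries are taken literally). The sketch below replays that arithmetic on a hypothetical 5D GQA query laid out as {batch, kv_heads, group_size, seq_len, head_size}; the layout interpretation and the numbers are illustrative assumptions, not taken from the commit:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

using dims = std::vector<int64_t>;

// The pass's shape rule: keep dims 0, 3, 4; -1 absorbs the two head axes.
dims to_4d(const dims &d5) { return {d5[0], -1, d5[3], d5[4]}; }

// Resolve a single -1 entry so element counts match, as reshape does.
dims resolve(const dims &src, dims shape) {
    int64_t total = std::accumulate(
            src.begin(), src.end(), int64_t {1}, std::multiplies<int64_t>());
    int64_t known = 1;
    for (int64_t d : shape)
        if (d != -1) known *= d;
    for (int64_t &d : shape)
        if (d == -1) d = total / known;
    return shape;
}

int main() {
    dims query5 {2, 8, 4, 128, 64}; // batch, kv_heads, group, seq, head
    dims query4 = resolve(query5, to_4d(query5)); // -1 resolves to 8*4 = 32
    for (int64_t d : query4)
        std::cout << d << ' ';
    std::cout << '\n'; // prints: 2 32 128 64
}
```

The final reshape_output then simply reapplies the 5D dims recorded on the op's output logical tensor, restoring the original rank after the 4D SDPA kernel runs.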

src/graph/backend/dnnl/passes/insert_ops.hpp

+6 -1
```diff
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2021-2024 Intel Corporation
+* Copyright 2021-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -58,6 +58,11 @@ status_t insert_permute_for_matmul(std::shared_ptr<subgraph_t> &sg);
 /// 2) reshape dst back to nd after compilation
 status_t insert_reshape_for_ndx2d_matmul(std::shared_ptr<subgraph_t> &sg);
 
+/// Insert reshape for 5D sdpa. sdpa only support 4D input/output
+/// 1) reshape Q/K/V/scale/mask from 5D to 4D
+/// 2) reshape output from 4D to 5D
+status_t insert_reshape_for_sdpa(std::shared_ptr<subgraph_t> &sg);
+
 // Insert an unsqueeze-squeeze pair for matmul
 //
 // The usage of unsqueeze op:
```
