Skip to content

Commit b89f51b

Browse files
committed Mar 21, 2025
graph: dnnl: add sdpa primitive ukernel v1
1 parent b5d090f commit b89f51b

12 files changed

+524
-17
lines changed
 

‎src/graph/backend/dnnl/kernels/large_partition.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ void larger_partition_kernel_t::setup_pipeline_stage2(pass_pipeline_t &pipeline,
142142
}
143143
BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
144144
BACKEND_DNNL_ADD_PASS(pipeline, fuse_src_transpose_to_matmul);
145-
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
145+
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
146146
BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
147147
BACKEND_DNNL_ADD_PASS(pipeline, common_reorder_elimination);
148148
BACKEND_DNNL_ADD_PASS(pipeline, fuse_adjacent_reorders);

‎src/graph/backend/dnnl/kernels/matmul.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ status_t matmul_t<quantized>::compile_impl(const dnnl_partition_impl_t *part,
110110
}
111111

112112
BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
113-
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
113+
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
114114
BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
115115

116116
BACKEND_DNNL_ADD_PASS(pipeline, fuse_adjacent_reorders);

‎src/graph/backend/dnnl/kernels/mqa_decomp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ status_t mqa_decomp_kernel_t<quantized, dt>::compile_impl(
8787
BACKEND_DNNL_ADD_PASS(pipeline, remove_quant_data_with_no_effect);
8888
}
8989
pipeline.reset_visualize_arg(true, false);
90-
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
90+
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
9191
BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
9292

9393
// Run the added passes

‎src/graph/backend/dnnl/kernels/sdp.hpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2024 Intel Corporation
2+
* Copyright 2024-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@
2727
#include "graph/backend/dnnl/kernels/large_partition.hpp"
2828
#include "graph/backend/dnnl/kernels/sdp_decomp.hpp"
2929
#include "graph/backend/dnnl/kernels/sdp_primitive.hpp"
30+
#include "graph/backend/dnnl/kernels/sdp_primitive_v1.hpp"
3031

3132
#include "graph/backend/dnnl/dnnl_partition_impl.hpp"
3233

@@ -65,7 +66,15 @@ struct sdp_base_t : public kernel_base_t {
6566

6667
status_t ret = status::unimplemented;
6768

69+
// SDPA Ukernel v1 with fused internal sdpa solution. Support float sdpa
70+
// only.
71+
// TODO(GX): Support quantized sdpa and merge with sdp_primitive_kernel_t.
6872
if (enable_ukernel) {
73+
kernel = std::make_shared<sdp_primitive_v1_kernel_t<quantized>>();
74+
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
75+
}
76+
77+
if (ret != status::success && enable_ukernel) {
6978
kernel = std::make_shared<sdp_primitive_kernel_t<quantized>>();
7079
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
7180
}

‎src/graph/backend/dnnl/kernels/sdp_decomp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ status_t sdp_decomp_kernel_t<quantized, dt>::compile_impl(
8686
BACKEND_DNNL_ADD_PASS(pipeline, remove_quant_data_with_no_effect);
8787
}
8888
pipeline.reset_visualize_arg(true, false);
89-
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
89+
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
9090
BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
9191

9292
// Run the added passes

‎src/graph/backend/dnnl/kernels/sdp_primitive.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ status_t sdp_primitive_kernel_t<quantized>::compile_impl(
9292

9393
pipeline.reset_visualize_arg(true, false);
9494
BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
95-
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_matmul);
95+
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
9696
BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
9797

9898
// bind the memory for each op

‎src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp

+19-1
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,29 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
166166

167167
status_t sdp_primitive_config_t::initial_check(
168168
const std::shared_ptr<subgraph_t> &sg,
169-
const std::vector<logical_tensor_t> &inputs) {
169+
const std::vector<logical_tensor_t> &inputs, bool v1_kenrel) {
170170
// At least 3 inputs: Q, K, V
171171
VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
172172
"At least 3 inputs are required");
173173

174+
// Ukernel doesn't support f32 datatype now
175+
VCHECK_SDP_PRIMITIVE(inputs[0].data_type != dnnl_data_type_t::dnnl_f32,
176+
status::invalid_arguments,
177+
"SDPA ukernel doesn't support f32 datatype now");
178+
179+
// Note: the sdpa_primitive_v1 kernel currently doesn't support the legacy GQA pattern.
180+
if (v1_kenrel) {
181+
for (auto &cur_op : sg->get_ops()) {
182+
if (cur_op->get_kind() == graph::op_kind::StaticReshape) {
183+
auto in = cur_op->get_input_value(0)->get_logical_tensor();
184+
auto out = cur_op->get_output_value(0)->get_logical_tensor();
185+
if (ltw(in).ndims() == 5 || ltw(out).ndims() == 5) {
186+
return status::unimplemented;
187+
}
188+
}
189+
}
190+
}
191+
174192
// step1(pattern check): Not support sdpa variants with select as mask
175193
// We already have a pattern matcher to ensure that the sdpa patterns
176194
// dispatch to here are knows ones, and we have quant check in sdpa base

‎src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ struct sdp_primitive_config_t {
8282
// 2. only support fp16 data type
8383
// 3. only support 4-dims tensor
8484
status_t initial_check(const std::shared_ptr<subgraph_t> &sg,
85-
const std::vector<logical_tensor_t> &inputs);
85+
const std::vector<logical_tensor_t> &inputs,
86+
bool v1_kenrel = false);
8687

8788
// Initialize parameters and primitive.
8889
status_t init(std::shared_ptr<subgraph_t> &sg, const dnnl::engine &p_engine,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
/*******************************************************************************
2+
* Copyright 2024-2025 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#include "graph/backend/dnnl/kernels/sdp_primitive_v1.hpp"
18+
19+
#include "common/sdpa_pd.hpp"
20+
21+
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
22+
#include "gpu/intel/ocl/stream.hpp"
23+
#elif DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
24+
#include "gpu/intel/sycl/stream.hpp"
25+
#endif
26+
27+
#include "graph/backend/dnnl/passes/compile_ops.hpp"
28+
#include "graph/backend/dnnl/passes/constant_propagation.hpp"
29+
#include "graph/backend/dnnl/passes/insert_ops.hpp"
30+
#include "graph/backend/dnnl/passes/layout_propagation.hpp"
31+
#include "graph/backend/dnnl/passes/lower.hpp"
32+
#include "graph/backend/dnnl/passes/memory_planning.hpp"
33+
#include "graph/backend/dnnl/passes/transform.hpp"
34+
#include "graph/backend/dnnl/passes/utils.hpp"
35+
36+
#include "graph/backend/dnnl/op_executable.hpp"
37+
38+
namespace dnnl {
39+
namespace impl {
40+
namespace graph {
41+
namespace dnnl_impl {
42+
43+
template <bool quantized>
44+
status_t sdp_primitive_v1_kernel_t<quantized>::compile_impl(
45+
const dnnl_partition_impl_t *part, const engine_t *g_engine,
46+
const std::vector<logical_tensor_t> &inputs,
47+
const std::vector<logical_tensor_t> &outputs) {
48+
// sdp_primitive_v1_kernel_t only supports Intel GPU.
49+
#if defined(DNNL_WITH_SYCL) && DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL
50+
return status::unimplemented;
51+
#endif
52+
if (quantized) { return status::unimplemented; }
53+
54+
p_engine_ = make_dnnl_engine(*g_engine);
55+
g_alloc_
56+
= reinterpret_cast<graph::allocator_t *>(g_engine->get_allocator());
57+
58+
// First, dry run on a deep copy
59+
subgraph_
60+
= std::make_shared<subgraph_t>(graph_t::deep_copy(part->get_ops()),
61+
p_engine_, part->get_fpmath_mode(), false, true);
62+
CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs));
63+
64+
CHECK(cfg_.initial_check(subgraph_, inputs, true));
65+
66+
subgraph_visualizer_t vis(part->id(), [this](const value_t *val) {
67+
return this->memory_planner_.get_memory_info(val);
68+
});
69+
pass_pipeline_t pipeline = pass_pipeline_t(vis);
70+
71+
BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
72+
BACKEND_DNNL_ADD_PASS(pipeline, fuse_implicit_causal_mask);
73+
BACKEND_DNNL_ADD_PASS(pipeline, binary_canonicalization);
74+
BACKEND_DNNL_ADD_PASS(pipeline, insert_permute_for_matmul);
75+
76+
pipeline.reset_visualize_arg(true, false);
77+
BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
78+
BACKEND_DNNL_ADD_PASS(pipeline, fuse_src_transpose_to_matmul);
79+
BACKEND_DNNL_ADD_PASS(pipeline, fuse_sdpa);
80+
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
81+
BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
82+
83+
// bind the memory for each op
84+
auto memory_plan = [&](std::shared_ptr<subgraph_t> &sg) {
85+
return memory_planner_.run(sg);
86+
};
87+
pipeline.reset_visualize_arg(true, true);
88+
BACKEND_DNNL_ADD_PASS(pipeline, memory_plan);
89+
BACKEND_DNNL_ADD_PASS(pipeline, compile_ops);
90+
91+
// Run the added passes
92+
BACKEND_DNNL_CHECK(pipeline.run(subgraph_));
93+
94+
// fill information for inputs logical tensors
95+
for (size_t i = 0; i < inputs.size(); i++) {
96+
auto &in = const_cast<logical_tensor_t &>(inputs[i]);
97+
in = subgraph_->ins_[i];
98+
}
99+
100+
// fill information for outputs logical tensors
101+
for (size_t i = 0; i < outputs.size(); i++) {
102+
auto &out = const_cast<logical_tensor_t &>(outputs[i]);
103+
out = subgraph_->outs_[i];
104+
}
105+
106+
resource_ctor_ = [this]() {
107+
return this->memory_planner_.get_exec_args_set().clone();
108+
};
109+
110+
return status::success;
111+
}
112+
113+
template <bool quantized>
114+
void sdp_primitive_v1_kernel_t<quantized>::prepare_args_set(
115+
const execution_args_set_t *res, const std::vector<tensor_t> &inputs,
116+
const std::vector<tensor_t> &outputs, const scratchpad_t &scratchpad) {
117+
// update the data of partition in/outputs args
118+
for (const auto &mem_idx : res->get_mems_use_external_inputs()) {
119+
mem_idx.first.set_data_handle(inputs[mem_idx.second].get_data_handle());
120+
}
121+
for (const auto &mem_idx : res->get_mems_use_external_outputs()) {
122+
mem_idx.first.set_data_handle(
123+
outputs[mem_idx.second].get_data_handle());
124+
}
125+
126+
grantor_t var_grantor = memory_planner_.internal_temporary_grantor(
127+
scratchpad.get_buffer());
128+
129+
for (auto &mem_offkey : res->get_mems_use_internal_temporary()) {
130+
mem_offkey.first.set_data_handle(var_grantor.get(mem_offkey.second));
131+
}
132+
}
133+
134+
template <bool quantized>
135+
status_t sdp_primitive_v1_kernel_t<quantized>::execute_impl(
136+
const stream_t *g_stream, const std::vector<tensor_t> &inputs,
137+
const std::vector<tensor_t> &outputs) {
138+
dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream);
139+
140+
thread_local_cache_t<execution_args_set_t> res_cache;
141+
execution_args_set_t *res = res_cache.get_or_add(
142+
reinterpret_cast<size_t>(this), resource_ctor_);
143+
144+
temporary_scratchpad_t scratchpad(
145+
memory_planner_.total_internal_temporary_size(), p_engine_,
146+
*g_alloc_);
147+
prepare_args_set(res, inputs, outputs, scratchpad);
148+
149+
for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
150+
subgraph_->execs_[i]->execute(p_stream, res->get_exec_args()[i]);
151+
}
152+
153+
return status::success;
154+
}
155+
156+
#ifdef DNNL_WITH_SYCL
157+
template <bool quantized>
158+
status_t sdp_primitive_v1_kernel_t<quantized>::sycl_execute_impl(
159+
const stream_t *g_stream, const std::vector<tensor_t> &inputs,
160+
const std::vector<tensor_t> &outputs,
161+
const std::vector<::sycl::event> &sycl_deps,
162+
::sycl::event *sycl_event) {
163+
// sdp_primitive_v1_kernel_t only supports Intel GPU.
164+
#if DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL
165+
return status::unimplemented;
166+
#endif
167+
auto deps = sycl_deps;
168+
::sycl::event returned_event;
169+
170+
dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream);
171+
172+
thread_local_cache_t<execution_args_set_t> res_cache;
173+
execution_args_set_t *res = res_cache.get_or_add(
174+
reinterpret_cast<size_t>(this), resource_ctor_);
175+
176+
temporary_scratchpad_t scratchpad(
177+
memory_planner_.total_internal_temporary_size(), p_engine_,
178+
*g_alloc_);
179+
prepare_args_set(res, inputs, outputs, scratchpad);
180+
181+
for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
182+
if (subgraph_->is_constant_[i]) continue;
183+
returned_event = subgraph_->execs_[i]->execute_sycl(
184+
p_stream, res->get_exec_args()[i], deps);
185+
deps = {returned_event};
186+
}
187+
188+
scratchpad.set_deps(returned_event);
189+
if (sycl_event) *sycl_event = returned_event;
190+
191+
return status::success;
192+
}
193+
#endif
194+
195+
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
196+
template <bool quantized>
197+
status_t sdp_primitive_v1_kernel_t<quantized>::ocl_execute_impl(
198+
const stream_t *g_stream, const std::vector<tensor_t> &inputs,
199+
const std::vector<tensor_t> &outputs,
200+
const std::vector<cl_event> &cl_deps, cl_event *ret_event) {
201+
// sdp_primitive_v1_kernel_t only supports Intel GPU.
202+
#if DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL
203+
return status::unimplemented;
204+
#endif
205+
auto deps = cl_deps;
206+
cl_event returned_event {};
207+
208+
dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream);
209+
210+
thread_local_cache_t<execution_args_set_t> res_cache;
211+
execution_args_set_t *res = res_cache.get_or_add(
212+
reinterpret_cast<size_t>(this), resource_ctor_);
213+
214+
temporary_scratchpad_t scratchpad(
215+
memory_planner_.total_internal_temporary_size(), p_engine_,
216+
*g_alloc_);
217+
prepare_args_set(res, inputs, outputs, scratchpad);
218+
219+
for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
220+
if (subgraph_->is_constant_[i]) continue;
221+
returned_event = subgraph_->execs_[i]->execute_ocl(
222+
p_stream, res->get_exec_args()[i], deps);
223+
deps = {returned_event};
224+
}
225+
226+
scratchpad.set_deps(returned_event);
227+
if (ret_event) *ret_event = returned_event;
228+
229+
return status::success;
230+
}
231+
#endif
232+
233+
template struct sdp_primitive_v1_kernel_t<false>;
234+
template struct sdp_primitive_v1_kernel_t<true>;
235+
236+
} // namespace dnnl_impl
237+
} // namespace graph
238+
} // namespace impl
239+
} // namespace dnnl

0 commit comments

Comments
 (0)