Skip to content

Commit 9f084f6

Browse files
committed
graph: dnnl: comments fix
1 parent 6145d95 commit 9f084f6

9 files changed

+46
-43
lines changed

src/graph/backend/dnnl/dnnl_op_def.hpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,9 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_mask, 1,
11341134
.SET_EXECUTABLE_CREATOR(executable_creator<memory_reparser_t>)
11351135
.SET_ARG_INDICES_GETTER(memory_reparser_t))
11361136

1137+
// The data types of query/key/value/mask/output must be consistent, and only
1138+
// f16/bf16 are supported. The data type of scale must be consistent with other
1139+
// input and output data types or fp32.
11371140
DNNL_GRAPH_OP_SCHEMA(dnnl_sdpa, 1,
11381141
op_schema_t()
11391142
.set_inputs_option(op_schema_t::param_num_option::variadic)
@@ -1152,8 +1155,6 @@ DNNL_GRAPH_OP_SCHEMA(dnnl_sdpa, 1,
11521155
.set_attr(op_attr::with_mask, true, attribute_kind::b)
11531156
// with_causal attribute supports top-left mask type only
11541157
.set_attr(op_attr::with_causal, true, attribute_kind::b)
1155-
.set_attr(op_attr::fusion_info_key, false, attribute_kind::i,
1156-
(int64_t)-1)
11571158
.set_shape_inference_function(infer_dnnl_sdpa_output_shape)
11581159
.SET_LAYOUT_PROPAGATOR(layout_propagator_for_sdpa)
11591160
.SET_EXECUTABLE_CREATOR(executable_creator<sdpa_executable_t>)

src/graph/backend/dnnl/kernels/sdp.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ struct sdp_base_t : public kernel_base_t {
6969
// SDPA Ukernel v1 with fused internal sdpa solution. Supports float sdpa
7070
// only.
7171
// TODO(GX): Support quantized sdpa and merge with sdp_primitive_kernel_t.
72-
if (enable_ukernel) {
73-
kernel = std::make_shared<sdp_primitive_v1_kernel_t<quantized>>();
72+
if (enable_ukernel && !quantized) {
73+
kernel = std::make_shared<sdp_primitive_v1_kernel_t>();
7474
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
7575
}
7676

src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
166166

167167
status_t sdp_primitive_config_t::initial_check(
168168
const std::shared_ptr<subgraph_t> &sg,
169-
const std::vector<logical_tensor_t> &inputs, bool v1_kenrel) {
169+
const std::vector<logical_tensor_t> &inputs, bool v1_kernel) {
170170
// At least 3 inputs: Q, K, V
171171
VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
172172
"At least 3 inputs are required");
@@ -177,7 +177,7 @@ status_t sdp_primitive_config_t::initial_check(
177177
"SDPA ukernel doesn't support f32 datatype now");
178178

179179
// Note: sdpa_primitive_v1 kernel currently doesn't support legacy GQA pattern.
180-
if (v1_kenrel) {
180+
if (v1_kernel) {
181181
for (auto &cur_op : sg->get_ops()) {
182182
if (cur_op->get_kind() == graph::op_kind::StaticReshape) {
183183
auto in = cur_op->get_input_value(0)->get_logical_tensor();

src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ struct sdp_primitive_config_t {
8383
// 3. only support 4-dims tensor
8484
status_t initial_check(const std::shared_ptr<subgraph_t> &sg,
8585
const std::vector<logical_tensor_t> &inputs,
86-
bool v1_kenrel = false);
86+
bool v1_kernel = false);
8787

8888
// Initialize parameters and primitive.
8989
status_t init(std::shared_ptr<subgraph_t> &sg, const dnnl::engine &p_engine,

src/graph/backend/dnnl/kernels/sdp_primitive_v1.cpp

+10-17
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2024-2025 Intel Corporation
2+
* Copyright 2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -40,16 +40,14 @@ namespace impl {
4040
namespace graph {
4141
namespace dnnl_impl {
4242

43-
template <bool quantized>
44-
status_t sdp_primitive_v1_kernel_t<quantized>::compile_impl(
43+
status_t sdp_primitive_v1_kernel_t::compile_impl(
4544
const dnnl_partition_impl_t *part, const engine_t *g_engine,
4645
const std::vector<logical_tensor_t> &inputs,
4746
const std::vector<logical_tensor_t> &outputs) {
4847
// sdp_primitive_v1_kernel_t only supports Intel GPU.
4948
#if defined(DNNL_WITH_SYCL) && DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL
5049
return status::unimplemented;
5150
#endif
52-
if (quantized) { return status::unimplemented; }
5351

5452
p_engine_ = make_dnnl_engine(*g_engine);
5553
g_alloc_
@@ -110,8 +108,7 @@ status_t sdp_primitive_v1_kernel_t<quantized>::compile_impl(
110108
return status::success;
111109
}
112110

113-
template <bool quantized>
114-
void sdp_primitive_v1_kernel_t<quantized>::prepare_args_set(
111+
void sdp_primitive_v1_kernel_t::prepare_args_set(
115112
const execution_args_set_t *res, const std::vector<tensor_t> &inputs,
116113
const std::vector<tensor_t> &outputs, const scratchpad_t &scratchpad) {
117114
// update the data of partition in/outputs args
@@ -131,9 +128,8 @@ void sdp_primitive_v1_kernel_t<quantized>::prepare_args_set(
131128
}
132129
}
133130

134-
template <bool quantized>
135-
status_t sdp_primitive_v1_kernel_t<quantized>::execute_impl(
136-
const stream_t *g_stream, const std::vector<tensor_t> &inputs,
131+
status_t sdp_primitive_v1_kernel_t::execute_impl(const stream_t *g_stream,
132+
const std::vector<tensor_t> &inputs,
137133
const std::vector<tensor_t> &outputs) {
138134
dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream);
139135

@@ -154,9 +150,8 @@ status_t sdp_primitive_v1_kernel_t<quantized>::execute_impl(
154150
}
155151

156152
#ifdef DNNL_WITH_SYCL
157-
template <bool quantized>
158-
status_t sdp_primitive_v1_kernel_t<quantized>::sycl_execute_impl(
159-
const stream_t *g_stream, const std::vector<tensor_t> &inputs,
153+
status_t sdp_primitive_v1_kernel_t::sycl_execute_impl(const stream_t *g_stream,
154+
const std::vector<tensor_t> &inputs,
160155
const std::vector<tensor_t> &outputs,
161156
const std::vector<::sycl::event> &sycl_deps,
162157
::sycl::event *sycl_event) {
@@ -193,9 +188,8 @@ status_t sdp_primitive_v1_kernel_t<quantized>::sycl_execute_impl(
193188
#endif
194189

195190
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
196-
template <bool quantized>
197-
status_t sdp_primitive_v1_kernel_t<quantized>::ocl_execute_impl(
198-
const stream_t *g_stream, const std::vector<tensor_t> &inputs,
191+
status_t sdp_primitive_v1_kernel_t::ocl_execute_impl(const stream_t *g_stream,
192+
const std::vector<tensor_t> &inputs,
199193
const std::vector<tensor_t> &outputs,
200194
const std::vector<cl_event> &cl_deps, cl_event *ret_event) {
201195
// sdp_primitive_v1_kernel_t only supports Intel GPU.
@@ -230,8 +224,7 @@ status_t sdp_primitive_v1_kernel_t<quantized>::ocl_execute_impl(
230224
}
231225
#endif
232226

233-
template struct sdp_primitive_v1_kernel_t<false>;
234-
template struct sdp_primitive_v1_kernel_t<true>;
227+
struct sdp_primitive_v1_kernel_t;
235228

236229
} // namespace dnnl_impl
237230
} // namespace graph

src/graph/backend/dnnl/kernels/sdp_primitive_v1.hpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2024-2025 Intel Corporation
2+
* Copyright 2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
1414
* limitations under the License.
1515
*******************************************************************************/
1616

17-
#ifndef GRAPH_BACKEND_DNNL_KERNELS_sdp_primitive_v1_HPP
18-
#define GRAPH_BACKEND_DNNL_KERNELS_sdp_primitive_v1_HPP
17+
#ifndef GRAPH_BACKEND_DNNL_KERNELS_SDP_PRIMITIVE_V1_HPP
18+
#define GRAPH_BACKEND_DNNL_KERNELS_SDP_PRIMITIVE_V1_HPP
1919

2020
#include <algorithm>
2121
#include <memory>
@@ -40,7 +40,6 @@ namespace impl {
4040
namespace graph {
4141
namespace dnnl_impl {
4242

43-
template <bool quantized>
4443
struct sdp_primitive_v1_kernel_t : public kernel_base_t {
4544
private:
4645
allocator_t *g_alloc_ = nullptr;

src/graph/backend/dnnl/layout_propagator.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -1580,7 +1580,8 @@ status_t layout_propagator_for_sdpa(std::shared_ptr<op_t> &op,
15801580
const logical_tensor_t &out_lt = dst_val->get_logical_tensor();
15811581

15821582
dnnl::memory::desc expected_md;
1583-
// Set default output layout format for sdpa as acbd
1583+
// Set default output layout format for sdpa as acbd if user doesn't specify
1584+
// the layout since no reorder will be required after sdpa.
15841585
if (ltw(out_lt).is_any()) {
15851586
expected_md = {ltw(out_lt).vdims(),
15861587
static_cast<dnnl::memory::data_type>(ltw(out_lt).data_type()),

src/graph/backend/dnnl/op_executable.hpp

+11-3
Original file line numberDiff line numberDiff line change
@@ -2689,12 +2689,19 @@ struct sdpa_executable_t : public op_executable_t {
26892689
md_k.get(), md_v.get(), md_dst.get(), md_mask.get(), scale_dt,
26902690
is_invert_scale_, kv_head_number, is_causal_mask_, attr.get());
26912691
if (s != dnnl::impl::status::success) {
2692-
throw std::runtime_error("create_sdpa_pd failed");
2692+
is_initialized_ = false;
2693+
} else {
2694+
status_t s = sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get());
2695+
if (s != dnnl::impl::status::success) {
2696+
is_initialized_ = false;
2697+
} else {
2698+
is_initialized_ = true;
2699+
}
26932700
}
2694-
2695-
sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get());
26962701
}
26972702

2703+
// Returns is_initialized_, which the creation code above sets to true only
// when both create_sdpa_pd and create_primitive report success.
bool is_initialized() { return is_initialized_; }
2704+
26982705
void execute(const stream &stream,
26992706
const std::unordered_map<int, memory> &args) const override {
27002707
exec_args_t exec_args;
@@ -2819,6 +2826,7 @@ struct sdpa_executable_t : public op_executable_t {
28192826
bool with_mask_;
28202827
bool is_invert_scale_;
28212828
bool is_causal_mask_;
2829+
bool is_initialized_;
28222830
};
28232831

28242832
} // namespace dnnl_impl

src/graph/backend/dnnl/passes/compile_ops.cpp

+12-11
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,20 @@ status_t compile_ops(std::shared_ptr<subgraph_t> &sg) {
6060
auto creator = opm->get_additional_item<executable_creator_func>(
6161
"executable_creator");
6262

63-
try {
64-
std::shared_ptr<op_executable_t> exec
65-
= creator(cur_op, p_engine, mgr, pd_cache);
66-
VCHECK_COMPILE_OPS(exec != nullptr, status::invalid_graph_op,
67-
"unimplemented op, can't compile op %s",
63+
std::shared_ptr<op_executable_t> exec
64+
= creator(cur_op, p_engine, mgr, pd_cache);
65+
VCHECK_COMPILE_OPS(exec != nullptr, status::invalid_graph_op,
66+
"unimplemented op, can't compile op %s",
67+
op->get_name().c_str());
68+
if (cur_op->get_kind() == op_kind::dnnl_sdpa) {
69+
auto sdpa_exec = std::dynamic_pointer_cast<sdpa_executable_t>(exec);
70+
VCHECK_COMPILE_OPS(sdpa_exec->is_initialized(),
71+
status::unimplemented,
72+
"failed to create executable for op %s",
6873
op->get_name().c_str());
69-
70-
sg->execs_.emplace_back(exec);
71-
} catch (const std::runtime_error &e) {
72-
VCHECK_COMPILE_OPS(false, status::unimplemented,
73-
"failed to create executable for op %s: %s",
74-
op->get_name().c_str(), e.what());
7574
}
75+
sg->execs_.emplace_back(exec);
76+
7677
sg->is_constant_.push_back(op->has_attr(op_attr::is_constant)
7778
&& op->get_attr<bool>(op_attr::is_constant));
7879
return status::success;

0 commit comments

Comments
 (0)