Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit eb3d730

Browse files
committed Dec 25, 2024
graph: backend: kernels: sdp verbose enhancement
1 parent 22fdb3a commit eb3d730

File tree

4 files changed

+77
-50
lines changed

4 files changed

+77
-50
lines changed
 

‎src/graph/backend/dnnl/kernels/sdp.hpp

+8
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,9 @@
3030

3131
#include "graph/backend/dnnl/dnnl_partition_impl.hpp"
3232

33+
#define VDISPATCH_GRAPH_SDP(msg, ...) \
34+
VINFO(graph, create, dispatch, compile, msg, ##__VA_ARGS__)
35+
3336
namespace dnnl {
3437
namespace impl {
3538
namespace graph {
@@ -76,6 +79,11 @@ struct sdp_base_t : public kernel_base_t {
7679
kernel = std::make_shared<larger_partition_kernel_t>();
7780
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
7881
}
82+
if (ret == status::success)
83+
VDISPATCH_GRAPH_SDP(
84+
"sdpa is dispatched to (%s)", kernel->str().c_str());
85+
else
86+
VDISPATCH_GRAPH_SDP("sdpa is failed to dispatch");
7987
return ret;
8088
}
8189

‎src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp

+33-22
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,10 @@
1616

1717
#include "graph/backend/dnnl/kernels/sdp_decomp_config.hpp"
1818

19+
#define VCHECK_SDP_DECOMP(cond, status, msg, ...) \
20+
VCONDCHECK(graph, create, check, sdp_decomp_kernel_t, (cond), status, msg, \
21+
##__VA_ARGS__);
22+
1923
namespace dnnl {
2024
namespace impl {
2125
namespace graph {
@@ -25,10 +29,10 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
2529
const std::vector<logical_tensor_t> &inputs) {
2630
// The order of input logical tensors in inputs is not certain, we need
2731
// to record the input offset in a certain order of ops.
28-
auto op_status = record_input_offset(sg, inputs);
29-
if (op_status != status::success) return false;
32+
CHECK_BOOL(record_input_offset(sg, inputs));
3033
dims src1_user_dims = ltw(inputs[graph_inport[0]]).vdims();
31-
if (src1_user_dims.size() != 4) return false;
34+
VCHECK_SDP_DECOMP(src1_user_dims.size() == 4, false,
35+
"Input dims should be 4, but got %zu", src1_user_dims.size());
3236

3337
// Initialize SDP input dimension according to the src of mm1
3438
batch_size = src1_user_dims[0];
@@ -41,14 +45,17 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
4145

4246
// Check batch size compatibility.
4347
dims wei2_user_dims = ltw(inputs[graph_inport[4]]).vdims();
44-
if (batch_size != wei1_user_dims[0] || batch_size != wei2_user_dims[0]) {
45-
return false;
46-
}
48+
VCHECK_SDP_DECOMP(
49+
batch_size == wei1_user_dims[0] && batch_size == wei2_user_dims[0],
50+
false,
51+
"Batch size mismatch, batch_size: %lld, wei1: %lld, wei2: %lld",
52+
batch_size, wei1_user_dims[0], wei2_user_dims[0]);
4753

4854
// Check scale size
4955
if (graph_inport[2] != -1) {
5056
auto scale_sz = ltw(inputs[graph_inport[2]]).nelems();
51-
if (scale_sz != 1) return false;
57+
VCHECK_SDP_DECOMP(scale_sz == 1, false,
58+
"Only supports single scale value, but got %lld", scale_sz);
5259
}
5360

5461
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP
@@ -65,10 +72,12 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
6572
#define RATIO 2
6673
// Initialize nthr with current threads num
6774
nthr = dnnl_get_current_num_threads();
68-
return batch_size * num_head_q > RATIO * nthr;
69-
#else
70-
return true;
75+
VCHECK_SDP_DECOMP(batch_size * num_head_q > RATIO * nthr, false,
76+
"Doesn't meet conditions for decompose: Batch size * num_head_q "
77+
"should be larger than %d * nthr",
78+
RATIO);
7179
#endif
80+
return true;
7281
}
7382

7483
template <bool quantized, memory::data_type dt>
@@ -78,7 +87,7 @@ impl::status_t sdp_decomp_config_t::construct_params(
7887
const std::vector<logical_tensor_t> &inputs) {
7988

8089
// Record the ops inside of SDP pattern for later usage
81-
record_sdp_ops(sg, quantized);
90+
CHECK(record_sdp_ops(sg, quantized));
8291

8392
// Update SDPA input params. Sequence length for query and key/value are
8493
// NOT always same.
@@ -435,11 +444,12 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
435444
graph::op_kind::SoftMax};
436445
for (const auto &cur_op : sg->get_ops()) {
437446
const auto &op_kind = cur_op->get_kind();
438-
if (op_kind == graph::op_kind::DynamicDequantize
439-
&& cur_op->get_attr<std::string>(op_attr::qtype)
440-
== "per_group") {
441-
return status::unimplemented;
442-
}
447+
VCHECK_SDP_DECOMP(
448+
!(op_kind == graph::op_kind::DynamicDequantize
449+
&& cur_op->get_attr<std::string>(op_attr::qtype)
450+
== "per_group"),
451+
status::unimplemented,
452+
"Not support per_group DynamicDequantize");
443453
// both mm1 and mm2 are found.
444454
if (mm1 && mm2) break;
445455
if (op_kind != graph::op_kind::MatMul) continue;
@@ -451,9 +461,9 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
451461
// TODO(xxx): Currently, p2 is not supported by decomp kernel.
452462
// p1: [matmul] --> [scale] --> [select] --> [mask] --> ...
453463
// p2: [matmul] --> [select] --> [scale] --> [mask] --> ...
454-
if (post_op->get_kind() == graph::op_kind::Select) {
455-
return status::unimplemented;
456-
}
464+
VCHECK_SDP_DECOMP(post_op->get_kind() != graph::op_kind::Select,
465+
status::unimplemented,
466+
"Not support select between matmul1 and scale");
457467
// find scale
458468
if (post_op->get_kind() == graph::op_kind::Divide
459469
|| post_op->get_kind() == graph::op_kind::Multiply) {
@@ -478,8 +488,8 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
478488
mm2 = cur_op;
479489
}
480490
}
481-
if (impl::utils::one_of(nullptr, mm1, mm2)) return status::invalid_graph;
482-
491+
VCHECK_SDP_DECOMP(mm1 != nullptr && mm2 != nullptr, status::invalid_graph,
492+
"Failed to find matmul1 or matmul2");
483493
int src1_id = find_graph_inport(mm1->get_input_value(0));
484494
graph_inport.emplace_back(src1_id);
485495
int wei1_id = find_graph_inport(mm1->get_input_value(1));
@@ -534,7 +544,8 @@ impl::status_t sdp_decomp_config_t::record_sdp_ops(
534544
auto post_op = get_post_op(cur_op);
535545
if (!post_op || post_op->get_kind() != op_kind::dnnl_softmax) continue;
536546
auto ppost_op = get_post_op(post_op);
537-
if (!ppost_op) return status::invalid_graph;
547+
VCHECK_SDP_DECOMP(ppost_op != nullptr, status::invalid_graph,
548+
"Failed to find post post op for matmul");
538549

539550
op_ptr reorder1;
540551
op_ptr reorder2;

‎src/graph/backend/dnnl/kernels/sdp_primitive.cpp

+3-1
Original file line number | Diff line number | Diff line change
@@ -181,7 +181,9 @@ status_t sdp_primitive_kernel_t<quantized>::get_prim_exec_args(
181181
&& res->find_value_mem_map(
182182
cfg_.v_zero_points_.get(), mem_storage[9]);
183183

184-
if (!ok) return status::runtime_error;
184+
VCONDCHECK(graph, exec, check, sdp_primitive_kernel, ok,
185+
status::runtime_error,
186+
"sdp_primitive_kernel get_prim_exec_args failed");
185187

186188
memory_arg_t mem_arg_q = {mem_storage[0].get(), true};
187189
memory_arg_t mem_arg_k = {mem_storage[1].get(), true};

‎src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp

+33-27
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,10 @@
1919

2020
#include "common/compiler_workarounds.hpp"
2121

22+
#define VCHECK_SDP_PRIMITIVE(cond, status, msg, ...) \
23+
VCONDCHECK(graph, create, check, sdp_primitive_kernel_t, (cond), status, \
24+
msg, ##__VA_ARGS__);
25+
2226
namespace dnnl {
2327
namespace impl {
2428
namespace graph {
@@ -63,7 +67,8 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
6367
if (post_op && mm1_post_op_kind.count(post_op->get_kind())) {
6468
// Locate mm1 and all post ops(scale and mask) here.
6569
// 1. locate mm1
66-
if (mm1) return status::unimplemented;
70+
VCHECK_SDP_PRIMITIVE(mm1 == nullptr, status::unimplemented,
71+
"Multiple mm1 found");
6772
mm1 = cur_op;
6873
// At least one of scale and mask exists
6974
if (post_op->get_kind() == op_kind::dnnl_binary) {
@@ -84,15 +89,18 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
8489
}
8590
}
8691
} else {
87-
if (mm2) return status::unimplemented;
92+
VCHECK_SDP_PRIMITIVE(mm2 == nullptr, status::unimplemented,
93+
"Multiple mm2 found");
8894
mm2 = cur_op;
8995
}
9096
}
9197

9298
// Locate input/outputs: Q, K, V, dst, scale, mask
9399
mm1_ = mm1;
94100
mm2_ = mm2;
95-
if (!mm1 || !mm2 || !final_op) return status::unimplemented;
101+
VCHECK_SDP_PRIMITIVE((mm1 && mm2 && final_op), status::unimplemented,
102+
"Not all ops are found");
103+
96104
q_ = mm1->get_input_value(0);
97105
k_ = mm1->get_input_value(1);
98106
v_ = mm2->get_input_value(1);
@@ -136,7 +144,8 @@ status_t sdp_primitive_config_t::initial_check(
136144
const std::shared_ptr<subgraph_t> &sg,
137145
const std::vector<logical_tensor_t> &inputs) {
138146
// At least 3 inputs: Q, K, V
139-
if (inputs.size() < 3) return status::invalid_arguments;
147+
VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
148+
"At least 3 inputs are required");
140149

141150
// step1(pattern check): Not support sdpa variants with select as mask
142151
// We already have a pattern matcher to ensure that the sdpa patterns
@@ -175,9 +184,9 @@ status_t sdp_primitive_config_t::initial_check(
175184
mm1 = cur_op;
176185
// Not support select between mm1 and scale(optional)
177186
// GPT-J:[mm1] --> [select] --> [scale]* --> [mask]* --> ...
178-
if (post_op->get_kind() == graph::op_kind::Select) {
179-
return status::unimplemented;
180-
}
187+
VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
188+
status::unimplemented,
189+
"Not support select between mm1 and scale(optional)");
181190
// scale
182191
if (post_op->get_kind() == graph::op_kind::Divide
183192
|| post_op->get_kind() == graph::op_kind::Multiply) {
@@ -193,9 +202,10 @@ status_t sdp_primitive_config_t::initial_check(
193202

194203
// Not support select after scale(optional) and mask(optional)
195204
// Distill-Bert:[mm1] --> [scale]* --> [mask]* --> [select] --> ...
196-
if (post_op->get_kind() == graph::op_kind::Select) {
197-
return status::unimplemented;
198-
}
205+
VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
206+
status::unimplemented,
207+
"Not support select after scale(optional) and "
208+
"mask(optional)");
199209
} else {
200210
mm2 = cur_op;
201211
}
@@ -214,27 +224,29 @@ status_t sdp_primitive_config_t::initial_check(
214224
return -1;
215225
};
216226

217-
if (impl::utils::one_of(nullptr, mm1, mm2)) return status::invalid_graph;
227+
VCHECK_SDP_PRIMITIVE(
228+
mm1 && mm2, status::invalid_graph, "mm1 or mm2 is not found");
218229

219230
// step3(dims check): only support 4-dims now.
220231
int q_id = find_graph_inport(mm1->get_input_value(0));
221232
int k_id = find_graph_inport(mm1->get_input_value(1));
222233
int v_id = find_graph_inport(mm2->get_input_value(1));
223234

224-
bool ok = true;
225-
ok = ok && (q_id != -1) && (k_id != -1) && (v_id != -1);
226-
if (!ok) return status::unimplemented;
227-
ok = ok && ltw(inputs[q_id]).vdims().size() == 4
228-
&& ltw(inputs[k_id]).vdims().size() == 4
229-
&& ltw(inputs[v_id]).vdims().size() == 4;
235+
VCHECK_SDP_PRIMITIVE(q_id != -1 && k_id != -1 && v_id != -1,
236+
status::unimplemented, "Q, K, V are not found");
237+
VCHECK_SDP_PRIMITIVE(ltw(inputs[q_id]).vdims().size() == 4
238+
&& ltw(inputs[k_id]).vdims().size() == 4
239+
&& ltw(inputs[v_id]).vdims().size() == 4,
240+
status::unimplemented, "Q, K, V should be 4-dims");
230241

231242
// sdp_primitive only supports single scale value.
232243
if (scale) {
233244
const auto &s = scale->get_input_value(1)->get_logical_tensor();
234-
if (ltw(s).nelems() != 1) return status::unimplemented;
245+
VCHECK_SDP_PRIMITIVE(ltw(s).nelems() == 1, status::unimplemented,
246+
"Scale should be single value");
235247
}
236248

237-
return ok ? status::success : status::unimplemented;
249+
return status::success;
238250
}
239251

240252
status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,
@@ -281,14 +293,8 @@ status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,
281293

282294
auto status = sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get());
283295

284-
if (status != status::success) {
285-
if (get_verbose(verbose_t::create_dispatch, component_t::graph)) {
286-
verbose_printf(
287-
"graph,create:dispatch,sdpa,could not create primitive, "
288-
"falling back\n");
289-
}
290-
}
291-
296+
VCONDCHECK(graph, create, dispatch, sdp, status == status::success, status,
297+
"could not create sdp primitive, falling back\n");
292298
return status;
293299
}
294300

0 commit comments

Comments
 (0)
Please sign in to comment.