Skip to content

Commit 9b8918d

Browse files
committed
graph: backend: kernels: sdp verbose enhancement
1 parent 7e450f8 commit 9b8918d

File tree

5 files changed

+66
-41
lines changed

5 files changed

+66
-41
lines changed

src/graph/backend/dnnl/kernels/sdp.hpp

+11
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030

3131
#include "graph/backend/dnnl/dnnl_partition_impl.hpp"
3232

33+
// Verbose helper for SDP kernel selection: forwards `msg` and any optional
// arguments to VINFO with the fixed tags graph / create / dispatch / compile.
// `##__VA_ARGS__` (GNU extension) drops the preceding comma when no extra
// arguments are supplied.
#define VDISPATCH_GRAPH_SDP(msg, ...) \
34+
VINFO(graph, create, dispatch, compile, msg, ##__VA_ARGS__)
35+
3336
namespace dnnl {
3437
namespace impl {
3538
namespace graph {
@@ -65,17 +68,25 @@ struct sdp_base_t : public kernel_base_t {
6568
if (enable_ukernel) {
6669
kernel = std::make_shared<sdp_primitive_kernel_t<quantized>>();
6770
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
71+
if (ret == status::success)
72+
VDISPATCH_GRAPH_SDP("dispatch to sdp_primitive_kernel");
6873
}
6974

7075
if (ret != status::success && enable_decomp) {
7176
kernel = std::make_shared<sdp_decomp_kernel_t<quantized, dt>>();
7277
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
78+
if (ret == status::success)
79+
VDISPATCH_GRAPH_SDP("dispatch to sdp_decomp_kernel");
7380
}
7481

7582
if (ret != status::success) {
7683
kernel = std::make_shared<larger_partition_kernel_t>();
7784
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
85+
if (ret == status::success)
86+
VDISPATCH_GRAPH_SDP("dispatch to larger_partition_kernel");
7887
}
88+
if (ret != status::success)
89+
VDISPATCH_GRAPH_SDP("fail to dispatch to sdp kernel");
7990
return ret;
8091
}
8192

src/graph/backend/dnnl/kernels/sdp_decomp.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@ status_t sdp_decomp_kernel_t<quantized, dt>::compile_impl(
5151
BACKEND_DNNL_CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs));
5252

5353
// Check if it's supported by decomposition kernel
54-
if (!sdp_cfg_.initial_check(subgraph_, inputs))
55-
return status::unimplemented;
54+
VCONDCHECK(graph, create, check, sdp_decomp,
55+
sdp_cfg_.initial_check(subgraph_, inputs), status::unimplemented,
56+
"sdp_decomp_kernel_t: initial check failed");
5657

5758
subgraph_visualizer_t vis(part->id(), [this](const value_t *val) {
5859
return this->memory_planner_.get_memory_info(val);

src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp

+16-11
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616

1717
#include "graph/backend/dnnl/kernels/sdp_decomp_config.hpp"
1818

19+
// Condition check for the SDP decomposition kernel: forwards to VCONDCHECK
// with the fixed tags graph / create / check / sdp_decomp.
//   cond   - expression expected to hold; parenthesised before forwarding.
//   status - value handed to VCONDCHECK for the failure path (the call sites
//            pass the enclosing function's failure return value, e.g.
//            `false` or status::unimplemented).
//   msg    - message, followed by optional extra arguments
//            (`##__VA_ARGS__` drops the comma when there are none).
// The expansion intentionally carries no trailing semicolon: every call site
// supplies its own, so the macro acts as a single statement and stays usable
// in an unbraced if/else.
#define VCHECK_SDP_DECOMP(cond, status, msg, ...) \
    VCONDCHECK(graph, create, check, sdp_decomp, (cond), status, msg, \
            ##__VA_ARGS__)
22+
1923
namespace dnnl {
2024
namespace impl {
2125
namespace graph {
@@ -25,8 +29,8 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
2529
const std::vector<logical_tensor_t> &inputs) {
2630
// The order of input logical tensors in inputs is not certain, we need
2731
// to record the input offset in a certain order of ops.
28-
auto op_status = record_input_offset(sg, inputs);
29-
if (op_status != status::success) return false;
32+
VCHECK_SDP_DECOMP(record_input_offset(sg, inputs) == status::success, false,
33+
"Failed to record input offset");
3034
dims src1_user_dims = ltw(inputs[graph_inport[0]]).vdims();
3135
if (src1_user_dims.size() != 4) return false;
3236

@@ -41,9 +45,9 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
4145

4246
// Check batch size compatibility.
4347
dims wei2_user_dims = ltw(inputs[graph_inport[4]]).vdims();
44-
if (batch_size != wei1_user_dims[0] || batch_size != wei2_user_dims[0]) {
45-
return false;
46-
}
48+
VCHECK_SDP_DECOMP(
49+
batch_size == wei1_user_dims[0] && batch_size == wei2_user_dims[0],
50+
false, "Batch size mismatch");
4751

4852
// Check scale size
4953
if (graph_inport[2] != -1) {
@@ -451,9 +455,9 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
451455
// TODO(xxx): Currently, p2 is not supported by decomp kernel.
452456
// p1: [matmul] --> [scale] --> [select] --> [mask] --> ...
453457
// p2: [matmul] --> [select] --> [scale] --> [mask] --> ...
454-
if (post_op->get_kind() == graph::op_kind::Select) {
455-
return status::unimplemented;
456-
}
458+
VCHECK_SDP_DECOMP(post_op->get_kind() != graph::op_kind::Select,
459+
status::unimplemented,
460+
"Not support select between mm1 and scale");
457461
// find scale
458462
if (post_op->get_kind() == graph::op_kind::Divide
459463
|| post_op->get_kind() == graph::op_kind::Multiply) {
@@ -478,8 +482,8 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
478482
mm2 = cur_op;
479483
}
480484
}
481-
if (impl::utils::one_of(nullptr, mm1, mm2)) return status::invalid_graph;
482-
485+
VCHECK_SDP_DECOMP(mm1 != nullptr && mm2 != nullptr, status::invalid_graph,
486+
"Failed to find mm1 or mm2");
483487
int src1_id = find_graph_inport(mm1->get_input_value(0));
484488
graph_inport.emplace_back(src1_id);
485489
int wei1_id = find_graph_inport(mm1->get_input_value(1));
@@ -534,7 +538,8 @@ impl::status_t sdp_decomp_config_t::record_sdp_ops(
534538
auto post_op = get_post_op(cur_op);
535539
if (!post_op || post_op->get_kind() != op_kind::dnnl_softmax) continue;
536540
auto ppost_op = get_post_op(post_op);
537-
if (!ppost_op) return status::invalid_graph;
541+
VCHECK_SDP_DECOMP(ppost_op != nullptr, status::invalid_graph,
542+
"Failed to find post post op");
538543

539544
op_ptr reorder1;
540545
op_ptr reorder2;

src/graph/backend/dnnl/kernels/sdp_primitive.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,9 @@ status_t sdp_primitive_kernel_t<quantized>::get_prim_exec_args(
181181
&& res->find_value_mem_map(
182182
cfg_.v_zero_points_.get(), mem_storage[9]);
183183

184-
if (!ok) return status::runtime_error;
184+
VCONDCHECK(graph, exec, check, sdp_primitive_kernel, ok,
185+
status::runtime_error,
186+
"sdp_primitive_kernel get_prim_exec_args failed");
185187

186188
memory_arg_t mem_arg_q = {mem_storage[0].get(), true};
187189
memory_arg_t mem_arg_k = {mem_storage[1].get(), true};

src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp

+33-27
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919

2020
#include "common/compiler_workarounds.hpp"
2121

22+
// Condition check for the SDP primitive kernel configuration: forwards to
// VCONDCHECK with the fixed tags graph / create / check / sdp_primitive.
//   cond   - expression expected to hold; parenthesised before forwarding.
//   status - value handed to VCONDCHECK for the failure path (the call sites
//            pass the status code the enclosing function should report,
//            e.g. status::unimplemented or status::invalid_graph).
//   msg    - message, followed by optional extra arguments
//            (`##__VA_ARGS__` drops the comma when there are none).
// The expansion intentionally carries no trailing semicolon: every call site
// supplies its own, so the macro acts as a single statement and stays usable
// in an unbraced if/else.
#define VCHECK_SDP_PRIMITIVE(cond, status, msg, ...) \
    VCONDCHECK(graph, create, check, sdp_primitive, (cond), status, msg, \
            ##__VA_ARGS__)
25+
2226
namespace dnnl {
2327
namespace impl {
2428
namespace graph {
@@ -63,7 +67,8 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
6367
if (post_op && mm1_post_op_kind.count(post_op->get_kind())) {
6468
// Locate mm1 and all post ops(scale and mask) here.
6569
// 1. locate mm1
66-
if (mm1) return status::unimplemented;
70+
VCHECK_SDP_PRIMITIVE(mm1 == nullptr, status::unimplemented,
71+
"Multiple mm1 found");
6772
mm1 = cur_op;
6873
// At least one of scale and mask exists
6974
if (post_op->get_kind() == op_kind::dnnl_binary) {
@@ -84,15 +89,18 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
8489
}
8590
}
8691
} else {
87-
if (mm2) return status::unimplemented;
92+
VCHECK_SDP_PRIMITIVE(mm2 == nullptr, status::unimplemented,
93+
"Multiple mm2 found");
8894
mm2 = cur_op;
8995
}
9096
}
9197

9298
// Locate input/outputs: Q, K, V, dst, scale, mask
9399
mm1_ = mm1;
94100
mm2_ = mm2;
95-
if (!mm1 || !mm2 || !final_op) return status::unimplemented;
101+
VCHECK_SDP_PRIMITIVE((mm1 && mm2 && final_op), status::unimplemented,
102+
"Not all ops are found");
103+
96104
q_ = mm1->get_input_value(0);
97105
k_ = mm1->get_input_value(1);
98106
v_ = mm2->get_input_value(1);
@@ -136,7 +144,8 @@ status_t sdp_primitive_config_t::initial_check(
136144
const std::shared_ptr<subgraph_t> &sg,
137145
const std::vector<logical_tensor_t> &inputs) {
138146
// At least 3 inputs: Q, K, V
139-
if (inputs.size() < 3) return status::invalid_arguments;
147+
VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
148+
"At least 3 inputs are required");
140149

141150
// step1(pattern check): Not support sdpa variants with select as mask
142151
// We already have a pattern matcher to ensure that the sdpa patterns
@@ -175,9 +184,9 @@ status_t sdp_primitive_config_t::initial_check(
175184
mm1 = cur_op;
176185
// Not support select between mm1 and scale(optional)
177186
// GPT-J:[mm1] --> [select] --> [scale]* --> [mask]* --> ...
178-
if (post_op->get_kind() == graph::op_kind::Select) {
179-
return status::unimplemented;
180-
}
187+
VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
188+
status::unimplemented,
189+
"Not support select between mm1 and scale(optional)");
181190
// scale
182191
if (post_op->get_kind() == graph::op_kind::Divide
183192
|| post_op->get_kind() == graph::op_kind::Multiply) {
@@ -193,9 +202,10 @@ status_t sdp_primitive_config_t::initial_check(
193202

194203
// Not support select after scale(optional) and mask(optional)
195204
// Distill-Bert:[mm1] --> [scale]* --> [mask]* --> [select] --> ...
196-
if (post_op->get_kind() == graph::op_kind::Select) {
197-
return status::unimplemented;
198-
}
205+
VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
206+
status::unimplemented,
207+
"Not support select after scale(optional) and "
208+
"mask(optional)");
199209
} else {
200210
mm2 = cur_op;
201211
}
@@ -214,27 +224,29 @@ status_t sdp_primitive_config_t::initial_check(
214224
return -1;
215225
};
216226

217-
if (impl::utils::one_of(nullptr, mm1, mm2)) return status::invalid_graph;
227+
VCHECK_SDP_PRIMITIVE(
228+
mm1 && mm2, status::invalid_graph, "mm1 or mm2 is not found");
218229

219230
// step3(dims check): only support 4-dims now.
220231
int q_id = find_graph_inport(mm1->get_input_value(0));
221232
int k_id = find_graph_inport(mm1->get_input_value(1));
222233
int v_id = find_graph_inport(mm2->get_input_value(1));
223234

224-
bool ok = true;
225-
ok = ok && (q_id != -1) && (k_id != -1) && (v_id != -1);
226-
if (!ok) return status::unimplemented;
227-
ok = ok && ltw(inputs[q_id]).vdims().size() == 4
228-
&& ltw(inputs[k_id]).vdims().size() == 4
229-
&& ltw(inputs[v_id]).vdims().size() == 4;
235+
VCHECK_SDP_PRIMITIVE(q_id != -1 && k_id != -1 && v_id != -1,
236+
status::unimplemented, "Q, K, V are not found");
237+
VCHECK_SDP_PRIMITIVE(ltw(inputs[q_id]).vdims().size() == 4
238+
&& ltw(inputs[k_id]).vdims().size() == 4
239+
&& ltw(inputs[v_id]).vdims().size() == 4,
240+
status::unimplemented, "Q, K, V should be 4-dims");
230241

231242
// sdp_primitive only supports single scale value.
232243
if (scale) {
233244
const auto &s = scale->get_input_value(1)->get_logical_tensor();
234-
if (ltw(s).nelems() != 1) return status::unimplemented;
245+
VCHECK_SDP_PRIMITIVE(ltw(s).nelems() == 1, status::unimplemented,
246+
"Scale should be single value");
235247
}
236248

237-
return ok ? status::success : status::unimplemented;
249+
return status::success;
238250
}
239251

240252
status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,
@@ -281,14 +293,8 @@ status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,
281293

282294
auto status = sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get());
283295

284-
if (status != status::success) {
285-
if (get_verbose(verbose_t::create_dispatch, component_t::graph)) {
286-
verbose_printf(
287-
"graph,create:dispatch,sdpa,could not create primitive, "
288-
"falling back\n");
289-
}
290-
}
291-
296+
VCONDCHECK(graph, create, dispatch, sdp, status == status::success, status,
297+
"could not create primitive, falling back\n");
292298
return status;
293299
}
294300

0 commit comments

Comments
 (0)