16
16
17
17
#include " graph/backend/dnnl/kernels/sdp_decomp_config.hpp"
18
18
19
// Condition check helper used by sdp_decomp_config_t in this file.
// Forwards to VCONDCHECK with the fixed leading arguments
// (graph, create, check, sdp_decomp_kernel_t), followed by the parenthesized
// condition, the value to return on failure, and a printf-style message with
// its optional arguments.
//   cond   - expression that must hold for the check to pass.
//   status - value handed to VCONDCHECK for the failure path; call sites pass
//            either `false` or a status:: code depending on the enclosing
//            function's return type.
//   msg    - printf-style format string; extra arguments go through
//            ##__VA_ARGS__.
// NOTE(review): VCONDCHECK is defined elsewhere; presumably it logs `msg` and
// returns `status` from the calling function when `cond` is false - confirm.
// The expansion already ends with ';'.
+ #define VCHECK_SDP_DECOMP (cond, status, msg, ...) \
20
+ VCONDCHECK (graph, create, check, sdp_decomp_kernel_t , (cond), status, msg, \
21
+ ##__VA_ARGS__);
22
+
19
23
namespace dnnl {
20
24
namespace impl {
21
25
namespace graph {
@@ -25,10 +29,10 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
25
29
const std::vector<logical_tensor_t > &inputs) {
26
30
// The order of input logical tensors in inputs is not certain, we need
27
31
// to record the input offset in a certain order of ops.
28
- auto op_status = record_input_offset (sg, inputs);
29
- if (op_status != status::success) return false ;
32
+ CHECK_BOOL (record_input_offset (sg, inputs));
30
33
dims src1_user_dims = ltw (inputs[graph_inport[0 ]]).vdims ();
31
- if (src1_user_dims.size () != 4 ) return false ;
34
+ VCHECK_SDP_DECOMP (src1_user_dims.size () == 4 , false ,
35
+ " Input dims should be 4, but got %zu" , src1_user_dims.size ());
32
36
33
37
// Initialize SDP input dimension according to the src of mm1
34
38
batch_size = src1_user_dims[0 ];
@@ -41,14 +45,17 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
41
45
42
46
// Check batch size compatibility.
43
47
dims wei2_user_dims = ltw (inputs[graph_inport[4 ]]).vdims ();
44
- if (batch_size != wei1_user_dims[0 ] || batch_size != wei2_user_dims[0 ]) {
45
- return false ;
46
- }
48
+ VCHECK_SDP_DECOMP (
49
+ batch_size == wei1_user_dims[0 ] && batch_size == wei2_user_dims[0 ],
50
+ false ,
51
+ " Batch size mismatch, batch_size: %lld, wei1: %lld, wei2: %lld" ,
52
+ batch_size, wei1_user_dims[0 ], wei2_user_dims[0 ]);
47
53
48
54
// Check scale size
49
55
if (graph_inport[2 ] != -1 ) {
50
56
auto scale_sz = ltw (inputs[graph_inport[2 ]]).nelems ();
51
- if (scale_sz != 1 ) return false ;
57
+ VCHECK_SDP_DECOMP (scale_sz == 1 , false ,
58
+ " Only supports single scale value, but got %lld" , scale_sz);
52
59
}
53
60
54
61
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP
@@ -65,10 +72,13 @@ bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
65
72
#define RATIO 2
66
73
// Initialize nthr with current threads num
67
74
nthr = dnnl_get_current_num_threads ();
68
- return batch_size * num_head_q > RATIO * nthr;
69
- #else
70
- return true ;
75
// Threading heuristic: the check passes only when
// batch_size * num_head_q > RATIO * nthr; otherwise `false` is passed as the
// failure value. The message reports all four operands so a rejected case can
// be diagnosed from the log.
+ VCHECK_SDP_DECOMP (batch_size * num_head_q > RATIO * nthr, false ,
76
+ " Doesn't meet condition for decompose: Batch size * num_head_q "
77
+ " should be larger than ratio * nthr, but got batch_size %lld, "
78
+ " num_head_q %lld, ratio %d , nthr %d" ,
79
+ batch_size, num_head_q, RATIO, nthr);
71
80
#endif
81
+ return true ;
72
82
}
73
83
74
84
template <bool quantized, memory::data_type dt>
@@ -78,7 +88,7 @@ impl::status_t sdp_decomp_config_t::construct_params(
78
88
const std::vector<logical_tensor_t > &inputs) {
79
89
80
90
// Record the ops inside of SDP pattern for later usage
81
- record_sdp_ops (sg, quantized);
91
+ CHECK ( record_sdp_ops (sg, quantized) );
82
92
83
93
// Update SDPA input params. Sequence length for query and key/value are
84
94
// NOT always the same.
@@ -435,11 +445,12 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
435
445
graph::op_kind::SoftMax};
436
446
for (const auto &cur_op : sg->get_ops ()) {
437
447
const auto &op_kind = cur_op->get_kind ();
438
- if (op_kind == graph::op_kind::DynamicDequantize
439
- && cur_op->get_attr <std::string>(op_attr::qtype)
440
- == " per_group" ) {
441
- return status::unimplemented;
442
- }
448
+ VCHECK_SDP_DECOMP (
449
+ !(op_kind == graph::op_kind::DynamicDequantize
450
+ && cur_op->get_attr <std::string>(op_attr::qtype)
451
+ == " per_group" ),
452
+ status::unimplemented,
453
+ " Not support per_group DynamicDequantize" );
443
454
// both mm1 and mm2 are found.
444
455
if (mm1 && mm2) break ;
445
456
if (op_kind != graph::op_kind::MatMul) continue ;
@@ -451,9 +462,9 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
451
462
// TODO(xxx): Currently, p2 is not supported by decomp kernel.
452
463
// p1: [matmul] --> [scale] --> [select] --> [mask] --> ...
453
464
// p2: [matmul] --> [select] --> [scale] --> [mask] --> ...
454
- if (post_op->get_kind () == graph::op_kind::Select) {
455
- return status::unimplemented;
456
- }
465
+ VCHECK_SDP_DECOMP (post_op->get_kind () != graph::op_kind::Select,
466
+ status::unimplemented,
467
+ " Not support select between matmul1 and scale " );
457
468
// find scale
458
469
if (post_op->get_kind () == graph::op_kind::Divide
459
470
|| post_op->get_kind () == graph::op_kind::Multiply) {
@@ -478,8 +489,8 @@ impl::status_t sdp_decomp_config_t::record_input_offset(
478
489
mm2 = cur_op;
479
490
}
480
491
}
481
- if ( impl::utils::one_of ( nullptr , mm1, mm2)) return status::invalid_graph;
482
-
492
+ VCHECK_SDP_DECOMP (mm1 != nullptr && mm2 != nullptr , status::invalid_graph,
493
+ " Failed to find matmul1 or matmul2 " );
483
494
int src1_id = find_graph_inport (mm1->get_input_value (0 ));
484
495
graph_inport.emplace_back (src1_id);
485
496
int wei1_id = find_graph_inport (mm1->get_input_value (1 ));
@@ -534,7 +545,8 @@ impl::status_t sdp_decomp_config_t::record_sdp_ops(
534
545
auto post_op = get_post_op (cur_op);
535
546
if (!post_op || post_op->get_kind () != op_kind::dnnl_softmax) continue ;
536
547
auto ppost_op = get_post_op (post_op);
537
- if (!ppost_op) return status::invalid_graph;
548
+ VCHECK_SDP_DECOMP (ppost_op != nullptr , status::invalid_graph,
549
+ " Failed to find post post op for matmul" );
538
550
539
551
op_ptr reorder1;
540
552
op_ptr reorder2;
0 commit comments