19
19
20
20
#include " common/compiler_workarounds.hpp"
21
21
22
// Condition check helper used by the sdp_primitive configuration code.
//
// Parameters:
//   cond   - expression that is expected to hold.
//   status - status value handed to VCONDCHECK for the failure case.
//   msg    - message (format string) describing the failed check.
//   ...    - optional arguments for the format string in msg.
//
// Forwards everything to VCONDCHECK with the fixed tags
// graph / create / check / sdp_primitive.
// NOTE(review): presumably VCONDCHECK reports msg through the verbose
// machinery and returns `status` from the enclosing function when `cond`
// is false - confirm against the VCONDCHECK definition.
//
// The macro name must be followed immediately by '(' (no whitespace),
// otherwise the preprocessor defines an object-like macro and call sites
// do not expand as intended.
//
// The expansion already ends with ';', so the ';' written at a call site
// produces an extra empty statement; wrap uses in braces when placing the
// macro directly under an if/else.
#define VCHECK_SDP_PRIMITIVE(cond, status, msg, ...) \
    VCONDCHECK(graph, create, check, sdp_primitive, (cond), status, msg, \
            ##__VA_ARGS__);
25
+
22
26
namespace dnnl {
23
27
namespace impl {
24
28
namespace graph {
@@ -63,7 +67,8 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
63
67
if (post_op && mm1_post_op_kind.count (post_op->get_kind ())) {
64
68
// Locate mm1 and all post ops(scale and mask) here.
65
69
// 1. locate mm1
66
- if (mm1) return status::unimplemented;
70
+ VCHECK_SDP_PRIMITIVE (mm1 == nullptr , status::unimplemented,
71
+ " Multiple mm1 found" );
67
72
mm1 = cur_op;
68
73
// At least one of scale and mask exists
69
74
if (post_op->get_kind () == op_kind::dnnl_binary) {
@@ -84,15 +89,18 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
84
89
}
85
90
}
86
91
} else {
87
- if (mm2) return status::unimplemented;
92
+ VCHECK_SDP_PRIMITIVE (mm2 == nullptr , status::unimplemented,
93
+ " Multiple mm2 found" );
88
94
mm2 = cur_op;
89
95
}
90
96
}
91
97
92
98
// Locate input/outputs: Q, K, V, dst, scale, mask
93
99
mm1_ = mm1;
94
100
mm2_ = mm2;
95
- if (!mm1 || !mm2 || !final_op) return status::unimplemented;
101
+ VCHECK_SDP_PRIMITIVE ((mm1 && mm2 && final_op), status::unimplemented,
102
+ " Not all ops are found" );
103
+
96
104
q_ = mm1->get_input_value (0 );
97
105
k_ = mm1->get_input_value (1 );
98
106
v_ = mm2->get_input_value (1 );
@@ -136,7 +144,8 @@ status_t sdp_primitive_config_t::initial_check(
136
144
const std::shared_ptr<subgraph_t > &sg,
137
145
const std::vector<logical_tensor_t > &inputs) {
138
146
// At least 3 inputs: Q, K, V
139
- if (inputs.size () < 3 ) return status::invalid_arguments;
147
+ VCHECK_SDP_PRIMITIVE (inputs.size () >= 3 , status::invalid_arguments,
148
+ " At least 3 inputs are required" );
140
149
141
150
// step1 (pattern check): sdpa variants with select as mask are not supported
142
151
// We already have a pattern matcher to ensure that the sdpa patterns
@@ -175,9 +184,9 @@ status_t sdp_primitive_config_t::initial_check(
175
184
mm1 = cur_op;
176
185
// Select between mm1 and scale (optional) is not supported
177
186
// GPT-J:[mm1] --> [select] --> [scale]* --> [mask]* --> ...
178
- if (post_op->get_kind () == graph::op_kind::Select) {
179
- return status::unimplemented;
180
- }
187
+ VCHECK_SDP_PRIMITIVE (post_op->get_kind () != graph::op_kind::Select,
188
+ status::unimplemented,
189
+ " Not support select between mm1 and scale(optional) " );
181
190
// scale
182
191
if (post_op->get_kind () == graph::op_kind::Divide
183
192
|| post_op->get_kind () == graph::op_kind::Multiply) {
@@ -193,9 +202,10 @@ status_t sdp_primitive_config_t::initial_check(
193
202
194
203
// Select after scale (optional) and mask (optional) is not supported
195
204
// Distill-Bert:[mm1] --> [scale]* --> [mask]* --> [select] --> ...
196
- if (post_op->get_kind () == graph::op_kind::Select) {
197
- return status::unimplemented;
198
- }
205
+ VCHECK_SDP_PRIMITIVE (post_op->get_kind () != graph::op_kind::Select,
206
+ status::unimplemented,
207
+ " Not support select after scale(optional) and "
208
+ " mask(optional)" );
199
209
} else {
200
210
mm2 = cur_op;
201
211
}
@@ -214,27 +224,29 @@ status_t sdp_primitive_config_t::initial_check(
214
224
return -1 ;
215
225
};
216
226
217
- if (impl::utils::one_of (nullptr , mm1, mm2)) return status::invalid_graph;
227
+ VCHECK_SDP_PRIMITIVE (
228
+ mm1 && mm2, status::invalid_graph, " mm1 or mm2 is not found" );
218
229
219
230
// step3 (dims check): only 4-dimensional tensors are supported for now.
220
231
int q_id = find_graph_inport (mm1->get_input_value (0 ));
221
232
int k_id = find_graph_inport (mm1->get_input_value (1 ));
222
233
int v_id = find_graph_inport (mm2->get_input_value (1 ));
223
234
224
- bool ok = true ;
225
- ok = ok && (q_id != - 1 ) && (k_id != - 1 ) && (v_id != - 1 );
226
- if (!ok) return status::unimplemented;
227
- ok = ok && ltw (inputs[q_id ]).vdims ().size () == 4
228
- && ltw (inputs[k_id ]).vdims ().size () == 4
229
- && ltw (inputs[v_id]). vdims (). size () == 4 ;
235
+ VCHECK_SDP_PRIMITIVE (q_id != - 1 && k_id != - 1 && v_id != - 1 ,
236
+ status::unimplemented, " Q, K, V are not found " );
237
+ VCHECK_SDP_PRIMITIVE ( ltw (inputs[q_id]). vdims (). size () == 4
238
+ && ltw (inputs[k_id ]).vdims ().size () == 4
239
+ && ltw (inputs[v_id ]).vdims ().size () == 4 ,
240
+ status::unimplemented, " Q, K, V should be 4-dims " ) ;
230
241
231
242
// sdp_primitive only supports a single scale value.
232
243
if (scale) {
233
244
const auto &s = scale->get_input_value (1 )->get_logical_tensor ();
234
- if (ltw (s).nelems () != 1 ) return status::unimplemented;
245
+ VCHECK_SDP_PRIMITIVE (ltw (s).nelems () == 1 , status::unimplemented,
246
+ " Scale should be single value" );
235
247
}
236
248
237
- return ok ? status::success : status::unimplemented ;
249
+ return status::success;
238
250
}
239
251
240
252
status_t sdp_primitive_config_t::init (std::shared_ptr<subgraph_t > &sg,
@@ -281,14 +293,8 @@ status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,
281
293
282
294
auto status = sdpa_pd_->create_primitive (sdpa_prim_, p_engine.get ());
283
295
284
- if (status != status::success) {
285
- if (get_verbose (verbose_t ::create_dispatch, component_t ::graph)) {
286
- verbose_printf (
287
- " graph,create:dispatch,sdpa,could not create primitive, "
288
- " falling back\n " );
289
- }
290
- }
291
-
296
+ VCONDCHECK (graph, create, dispatch, sdp, status == status::success, status,
297
+ " could not create primitive, falling back\n " );
292
298
return status;
293
299
}
294
300
0 commit comments