@@ -180,6 +180,7 @@ status_t sdp_primitive_config_t::initial_check(
180
180
graph::op_kind::Add, graph::op_kind::Select,
181
181
graph::op_kind::SoftMax};
182
182
op_ptr mm1 = nullptr , mm2 = nullptr , scale = nullptr ;
183
+ bool f32_inter = true ;
183
184
for (const auto &cur_op : sg->get_ops ()) {
184
185
const auto &op_kind = cur_op->get_kind ();
185
186
if (op_kind == graph::op_kind::DynamicDequantize
@@ -213,6 +214,10 @@ status_t sdp_primitive_config_t::initial_check(
213
214
auto post_op = get_post_op (cur_op);
214
215
if (post_op && mm1_post_op_kind.count (post_op->get_kind ())) {
215
216
mm1 = cur_op;
217
+ const auto &lt_score
218
+ = mm1->get_output_value (0 )->get_logical_tensor ();
219
+ f32_inter = f32_inter
220
+ && (ltw (lt_score).data_type () == data_type::f32);
216
221
// Not support select between mm1 and scale(optional)
217
222
// GPT-J:[mm1] --> [select] --> [scale]* --> [mask]* --> ...
218
223
VCHECK_SDP_PRIMITIVE (post_op->get_kind () != graph::op_kind::Select,
@@ -224,11 +229,20 @@ status_t sdp_primitive_config_t::initial_check(
224
229
// Scale exists, update post_op and traverse to next op
225
230
scale = post_op;
226
231
post_op = get_post_op (post_op);
232
+ const auto &lt_ss
233
+ = scale->get_output_value (0 )->get_logical_tensor ();
234
+ f32_inter = f32_inter
235
+ && (ltw (lt_ss).data_type () == data_type::f32);
227
236
}
228
237
// mask
229
238
if (post_op) {
230
239
if (post_op->get_kind () == graph::op_kind::Add) {
231
240
// Mask exists, update post_op and traverse to next op
241
+ const auto mask = post_op;
242
+ const auto &lt_ms
243
+ = mask->get_output_value (0 )->get_logical_tensor ();
244
+ f32_inter = f32_inter
245
+ && (ltw (lt_ms).data_type () == data_type::f32);
232
246
post_op = get_post_op (post_op);
233
247
}
234
248
// Not support select after scale(optional) and mask(optional)
@@ -245,6 +259,9 @@ status_t sdp_primitive_config_t::initial_check(
245
259
}
246
260
}
247
261
262
+ VCHECK_SDP_PRIMITIVE (f32_inter, status::invalid_graph,
263
+ " only supports f32 intermediates." );
264
+
248
265
auto find_graph_inport = [&inputs](const std::shared_ptr<value_t > &val) {
249
266
auto tmp_val = val;
250
267
while (tmp_val->has_producer ()) {
0 commit comments