Skip to content

Commit 2bfce18

Browse files
committed
graph: backend: dnnl: ukernel sdpa only supports f32 intermediates
1 parent 0269362 commit 2bfce18

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ status_t sdp_primitive_config_t::initial_check(
180180
graph::op_kind::Add, graph::op_kind::Select,
181181
graph::op_kind::SoftMax};
182182
op_ptr mm1 = nullptr, mm2 = nullptr, scale = nullptr;
183+
bool f32_inter = true;
183184
for (const auto &cur_op : sg->get_ops()) {
184185
const auto &op_kind = cur_op->get_kind();
185186
if (op_kind == graph::op_kind::DynamicDequantize
@@ -213,6 +214,10 @@ status_t sdp_primitive_config_t::initial_check(
213214
auto post_op = get_post_op(cur_op);
214215
if (post_op && mm1_post_op_kind.count(post_op->get_kind())) {
215216
mm1 = cur_op;
217+
const auto &lt_score
218+
= mm1->get_output_value(0)->get_logical_tensor();
219+
f32_inter = f32_inter
220+
&& (ltw(lt_score).data_type() == data_type::f32);
216221
// Not support select between mm1 and scale(optional)
217222
// GPT-J:[mm1] --> [select] --> [scale]* --> [mask]* --> ...
218223
VCHECK_SDP_PRIMITIVE(post_op->get_kind() != graph::op_kind::Select,
@@ -224,11 +229,20 @@ status_t sdp_primitive_config_t::initial_check(
224229
// Scale exists, update post_op and traverse to next op
225230
scale = post_op;
226231
post_op = get_post_op(post_op);
232+
const auto &lt_ss
233+
= scale->get_output_value(0)->get_logical_tensor();
234+
f32_inter = f32_inter
235+
&& (ltw(lt_ss).data_type() == data_type::f32);
227236
}
228237
// mask
229238
if (post_op) {
230239
if (post_op->get_kind() == graph::op_kind::Add) {
231240
// Mask exists, update post_op and traverse to next op
241+
const auto mask = post_op;
242+
const auto &lt_ms
243+
= mask->get_output_value(0)->get_logical_tensor();
244+
f32_inter = f32_inter
245+
&& (ltw(lt_ms).data_type() == data_type::f32);
232246
post_op = get_post_op(post_op);
233247
}
234248
// Not support select after scale(optional) and mask(optional)
@@ -245,6 +259,9 @@ status_t sdp_primitive_config_t::initial_check(
245259
}
246260
}
247261

262+
VCHECK_SDP_PRIMITIVE(f32_inter, status::invalid_graph,
263+
"only supports f32 intermediates.");
264+
248265
auto find_graph_inport = [&inputs](const std::shared_ptr<value_t> &val) {
249266
auto tmp_val = val;
250267
while (tmp_val->has_producer()) {

0 commit comments

Comments
 (0)