Skip to content

Commit ddca382

Browse files
committed
graph: dnnl: refine check and layout propagation
1 parent 9252ebc commit ddca382

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp

+9-4
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,15 @@ status_t sdp_primitive_config_t::initial_check(
286286

287287
VCHECK_SDP_PRIMITIVE(q_id != -1 && k_id != -1 && v_id != -1,
288288
status::unimplemented, "Q, K, V are not found");
289-
VCHECK_SDP_PRIMITIVE(ltw(inputs[q_id]).vdims().size() == 4
290-
&& ltw(inputs[k_id]).vdims().size() == 4
291-
&& ltw(inputs[v_id]).vdims().size() == 4,
292-
status::unimplemented, "Q, K, V should be 4-dims");
289+
290+
// Note: sdpa_primitive_v1 kernel accepts 5D GQA pattern, and will reshape to
291+
// 4D in a later compilation pass.
292+
if (!v1_kenrel) {
293+
VCHECK_SDP_PRIMITIVE(ltw(inputs[q_id]).vdims().size() == 4
294+
&& ltw(inputs[k_id]).vdims().size() == 4
295+
&& ltw(inputs[v_id]).vdims().size() == 4,
296+
status::unimplemented, "Q, K, V should be 4-dims");
297+
}
293298

294299
// sdp_primitive only supports single scale value.
295300
if (scale) {

src/graph/backend/dnnl/layout_propagator.cpp

+30-5
Original file line numberDiff line numberDiff line change
@@ -1578,13 +1578,38 @@ status_t layout_propagator_for_sdpa(std::shared_ptr<op_t> &op,
15781578

15791579
value_ptr dst_val = op->get_output_value(0);
15801580
const logical_tensor_t &out_lt = dst_val->get_logical_tensor();
1581-
15821581
dnnl::memory::desc expected_md;
1583-
// Set default output layout format for sdpa as acbd
1582+
15841583
if (ltw(out_lt).is_any()) {
1585-
expected_md = {ltw(out_lt).vdims(),
1586-
static_cast<dnnl::memory::data_type>(ltw(out_lt).data_type()),
1587-
dnnl::memory::format_tag::acbd};
1584+
// For GQA, we need to check the layout of the dnnl_reshape output
1585+
// following dnnl_sdpa, which is given by the user.
1586+
if (!dst_val->get_consumers().empty()) {
1587+
const auto &consumer_op = dst_val->get_consumers()[0].get_op();
1588+
const logical_tensor_t &consumer_out
1589+
= consumer_op.get_output_value(0)->get_logical_tensor();
1590+
if (consumer_op.get_kind() == op_kind::dnnl_reshape
1591+
&& ltw(consumer_out).ndims() == 5
1592+
&& ltw(consumer_out).is_strided()) {
1593+
const auto &ori_strides = ltw(consumer_out).vstrides();
1594+
std::vector<dim_t> strides = {ori_strides[0], ori_strides[2],
1595+
ori_strides[3], ori_strides[4]};
1596+
dnnl::memory::desc tmp_md {ltw(out_lt).vdims(),
1597+
static_cast<dnnl::memory::data_type>(
1598+
ltw(out_lt).data_type()),
1599+
strides};
1600+
expected_md = tmp_md;
1601+
} else {
1602+
expected_md = {ltw(out_lt).vdims(),
1603+
static_cast<dnnl::memory::data_type>(
1604+
ltw(out_lt).data_type()),
1605+
dnnl::memory::format_tag::acbd};
1606+
}
1607+
} else {
1608+
expected_md = {ltw(out_lt).vdims(),
1609+
static_cast<dnnl::memory::data_type>(
1610+
ltw(out_lt).data_type()),
1611+
dnnl::memory::format_tag::acbd};
1612+
}
15881613
} else {
15891614
expected_md = make_dnnl_memory_desc(out_lt);
15901615
}

0 commit comments

Comments
 (0)