@@ -1578,13 +1578,38 @@ status_t layout_propagator_for_sdpa(std::shared_ptr<op_t> &op,
1578
1578
1579
1579
value_ptr dst_val = op->get_output_value (0 );
1580
1580
const logical_tensor_t &out_lt = dst_val->get_logical_tensor ();
1581
-
1582
1581
dnnl::memory::desc expected_md;
1583
- // Set default output layout format for sdpa as acbd
1582
+
1584
1583
if (ltw (out_lt).is_any ()) {
1585
- expected_md = {ltw (out_lt).vdims (),
1586
- static_cast <dnnl::memory::data_type>(ltw (out_lt).data_type ()),
1587
- dnnl::memory::format_tag::acbd};
1584
+ // For GQA, we need to check the layout of the dnnl_reshape output
1585
+ // following dnnl_sdpa, which is given by the user.
1586
+ if (!dst_val->get_consumers ().empty ()) {
1587
+ const auto &consumer_op = dst_val->get_consumers ()[0 ].get_op ();
1588
+ const logical_tensor_t &consumer_out
1589
+ = consumer_op.get_output_value (0 )->get_logical_tensor ();
1590
+ if (consumer_op.get_kind () == op_kind::dnnl_reshape
1591
+ && ltw (consumer_out).ndims () == 5
1592
+ && ltw (consumer_out).is_strided ()) {
1593
+ const auto &ori_strides = ltw (consumer_out).vstrides ();
1594
+ std::vector<dim_t > strides = {ori_strides[0 ], ori_strides[2 ],
1595
+ ori_strides[3 ], ori_strides[4 ]};
1596
+ dnnl::memory::desc tmp_md {ltw (out_lt).vdims (),
1597
+ static_cast <dnnl::memory::data_type>(
1598
+ ltw (out_lt).data_type ()),
1599
+ strides};
1600
+ expected_md = tmp_md;
1601
+ } else {
1602
+ expected_md = {ltw (out_lt).vdims (),
1603
+ static_cast <dnnl::memory::data_type>(
1604
+ ltw (out_lt).data_type ()),
1605
+ dnnl::memory::format_tag::acbd};
1606
+ }
1607
+ } else {
1608
+ expected_md = {ltw (out_lt).vdims (),
1609
+ static_cast <dnnl::memory::data_type>(
1610
+ ltw (out_lt).data_type ()),
1611
+ dnnl::memory::format_tag::acbd};
1612
+ }
1588
1613
} else {
1589
1614
expected_md = make_dnnl_memory_desc (out_lt);
1590
1615
}
0 commit comments