Skip to content

Commit 45ce076

Browse files
[GPU] Prevent dynamic padding propagation to transposed gemm layer
1 parent a9c0047 commit 45ce076

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

src/plugins/intel_gpu/src/graph/include/reshape_inst.h

+23-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "openvino/core/partial_shape.hpp"
99
#include "crop_inst.h"
1010
#include "rope_inst.h"
11+
#include "gemm_inst.h"
1112
#include "primitive_inst.h"
1213

1314
#include <string>
@@ -40,9 +41,30 @@ struct typed_program_node<reshape> : public typed_program_node_base<reshape> {
4041
return false;
4142

4243
// TODO: If user is RoPE and dynamic padding exists, ouput padding propagation is not supported in the base mode
43-
if (get_users().size() == 1 && get_users().front()->is_type<rope>())
44+
auto user = get_users().front();
45+
if (get_users().size() == 1 && user->is_type<rope>())
4446
return false;
4547

48+
// TODO: Support transpose-fused gemm with dynamic_pad
49+
if (get_users().size() == 1 && user->is_type<gemm>()) {
50+
auto desc = user->as<gemm>().get_primitive();
51+
52+
auto input_order_transposed = [&]() -> bool {
53+
for (size_t i = 0; i < desc->input0_transpose_order.size(); i++) {
54+
if (desc->input0_transpose_order[i] != static_cast<int64_t>(i))
55+
return true;
56+
}
57+
for (size_t i = 0; i < desc->input1_transpose_order.size(); i++) {
58+
if (desc->input1_transpose_order[i] != static_cast<int64_t>(i))
59+
return true;
60+
}
61+
return false;
62+
};
63+
64+
if (input_order_transposed())
65+
return false;
66+
}
67+
4668
auto axis = input().as<crop>().get_primitive()->axis;
4769
const auto& input_pshape = input().get_output_layout(false).get_partial_shape();
4870
auto input_rank = input_pshape.size();

0 commit comments

Comments
 (0)