Skip to content

Commit 91fe1a8

Browse files
[GPU] Prevent dynamic padding propagation to transposed gemm layer
1 parent c664ca7 commit 91fe1a8

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

src/plugins/intel_gpu/src/graph/include/reshape_inst.h

+23-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "openvino/core/partial_shape.hpp"
99
#include "crop_inst.h"
1010
#include "rope_inst.h"
11+
#include "gemm_inst.h"
1112
#include "primitive_inst.h"
1213

1314
#include <string>
@@ -43,9 +44,30 @@ struct typed_program_node<reshape> : public typed_program_node_base<reshape> {
4344
return false;
4445

4546
// TODO: If user is RoPE and dynamic padding exists, ouput padding propagation is not supported in the base mode
46-
if (get_users().size() == 1 && get_users().front()->is_type<rope>())
47+
auto user = get_users().front();
48+
if (get_users().size() == 1 && user->is_type<rope>())
4749
return false;
4850

51+
// TODO: Support transpose-fused gemm with dynamic_pad
52+
if (get_users().size() == 1 && user->is_type<gemm>()) {
53+
auto desc = user->as<gemm>().get_primitive();
54+
55+
auto input_order_transposed = [&]() -> bool {
56+
for (size_t i = 0; i < desc->input0_transpose_order.size(); i++) {
57+
if (desc->input0_transpose_order[i] != static_cast<int64_t>(i))
58+
return true;
59+
}
60+
for (size_t i = 0; i < desc->input1_transpose_order.size(); i++) {
61+
if (desc->input1_transpose_order[i] != static_cast<int64_t>(i))
62+
return true;
63+
}
64+
return false;
65+
};
66+
67+
if (input_order_transposed())
68+
return false;
69+
}
70+
4971
auto axis = input().as<crop>().get_primitive()->axis;
5072
const auto& input_pshape = input().get_output_layout(false).get_partial_shape();
5173
auto input_rank = input_pshape.size();

0 commit comments

Comments
 (0)