Commit c3152d3

[GPU] Fix lack of support for 3D (and smaller) shapes in TransposeFusion pass (#26440)
### Details:
- Fix lack of support for 3D (and smaller) shapes in TransposeFusion pass
- Update IncreasePositionIdsPrecision to support both MatMul and Gemm operations

### Tickets:
- [CVS-146889](https://jira.devtools.intel.com/browse/CVS-146889)
1 parent 8c9d4be commit c3152d3

File tree

- src/plugins/intel_gpu/src/plugin/transformations/increase_position_ids_precision.cpp
- src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.cpp
- src/plugins/intel_gpu/tests/unit/transformations/increase_position_ids_precision_test.cpp
- src/plugins/intel_gpu/tests/unit/transformations/transpose_matmul_fusion_test.cpp

4 files changed: +117 -6 lines changed

src/plugins/intel_gpu/src/plugin/transformations/increase_position_ids_precision.cpp

+6 -6

@@ -29,8 +29,8 @@ IncreasePositionIdsPrecision::IncreasePositionIdsPrecision() {
     using namespace ov::pass::pattern;
     using ov::pass::pattern::op::Or;
 
-    auto gemm = wrap_type<ov::intel_gpu::op::Gemm>();
-    auto concat = wrap_type<ov::op::v0::Concat>({gemm, gemm});
+    auto gemm_or_matmul = wrap_type<ov::intel_gpu::op::Gemm, ov::op::v0::MatMul>();
+    auto concat = wrap_type<ov::op::v0::Concat>({gemm_or_matmul, gemm_or_matmul});
     auto sin = wrap_type<ov::op::v0::Sin>({concat});
     auto cos = wrap_type<ov::op::v0::Cos>({concat});
 
@@ -50,15 +50,15 @@ IncreasePositionIdsPrecision::IncreasePositionIdsPrecision() {
     ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
 
-        auto gemm_node = std::dynamic_pointer_cast<ov::intel_gpu::op::Gemm>(pattern_map.at(gemm).get_node_shared_ptr());
+        auto matmul_node = std::dynamic_pointer_cast<ov::op::v0::MatMul>(pattern_map.at(gemm_or_matmul).get_node_shared_ptr());
         auto cos_node = std::dynamic_pointer_cast<ov::op::v0::Cos>(pattern_map.at(cos).get_node_shared_ptr());
         auto sin_node = std::dynamic_pointer_cast<ov::op::v0::Sin>(pattern_map.at(sin).get_node_shared_ptr());
 
-        if (!gemm_node || transformation_callback(gemm_node))
+        if (!matmul_node || transformation_callback(matmul_node))
             return false;
 
         const auto desired_et = ov::element::f32;
-        const auto original_et = gemm_node->get_output_element_type(0);
+        const auto original_et = matmul_node->get_output_element_type(0);
         if (original_et == desired_et)
             return false;
 
@@ -112,7 +112,7 @@ IncreasePositionIdsPrecision::IncreasePositionIdsPrecision() {
             }
         };
 
-        bool is_changed = insert_converts_before_if_needed(gemm_node);
+        bool is_changed = insert_converts_before_if_needed(matmul_node);
 
         if (is_changed) {
             insert_converts_after_if_needed(cos_node);
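The effect of the updated pass on a matched subgraph is unchanged; only the anchor node may now be a plain MatMul as well as a plugin Gemm. Its inputs are converted to f32, Sin/Cos run in f32, and their outputs are converted back to the original element type. Below is a minimal sketch of the input-side rewiring, assuming standard OpenVINO node APIs; the function name and structure are illustrative, not the plugin's actual insert_converts_before_if_needed helper.

```cpp
// Sketch only: route every non-f32 input of the matched MatMul/Gemm through a
// Convert to f32 so the position-id angles feeding Sin/Cos are computed in full
// precision. Names here are assumptions for illustration.
#include <memory>

#include <openvino/core/node.hpp>
#include <openvino/core/type/element_type.hpp>
#include <openvino/op/convert.hpp>

bool route_inputs_through_f32(const std::shared_ptr<ov::Node>& node) {
    bool changed = false;
    for (size_t i = 0; i < node->get_input_size(); ++i) {
        const auto src = node->input_value(i);
        if (src.get_element_type() == ov::element::f32)
            continue;  // this input is already in the desired precision
        // Insert a Convert between the producer and the matched node's input port
        auto convert = std::make_shared<ov::op::v0::Convert>(src, ov::element::f32);
        node->input(i).replace_source_output(convert);
        changed = true;
    }
    return changed;
}
```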

src/plugins/intel_gpu/src/plugin/transformations/transpose_fusion.cpp

+7
@@ -52,6 +52,13 @@ bool has_optimized_version(const ov::Output<ov::Node>& output, bool supports_immad
         {3, 0, 1, 2},
     };
 
+    const auto expected_dims_num = 4;
+    const auto original_dims_num = transpose_order.size();
+    if (original_dims_num < expected_dims_num) {
+        transpose_order.resize(expected_dims_num);
+        std::iota(transpose_order.begin() + original_dims_num, transpose_order.end(), original_dims_num);
+    }
+
     if (!cldnn::one_of(transpose_order, allowed_orders))
         return false;
 
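The added block pads transpose orders with fewer than four axes with the identity permutation for the missing trailing dimensions before checking them against the allowed 4D orders. A self-contained sketch of that padding (the uint16_t element type of transpose_order is an assumption here):

```cpp
// Standalone illustration of the dimension padding used above.
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
    std::vector<uint16_t> transpose_order = {0, 2, 1};  // 3D order from the model

    const auto expected_dims_num = 4u;
    const auto original_dims_num = transpose_order.size();
    if (original_dims_num < expected_dims_num) {
        transpose_order.resize(expected_dims_num);
        // Fill the appended tail with consecutive axis indices (here just 3)
        std::iota(transpose_order.begin() + original_dims_num, transpose_order.end(), original_dims_num);
    }

    for (auto axis : transpose_order)
        std::cout << axis << ' ';  // prints: 0 2 1 3
    std::cout << '\n';
}
```

With this padding, the 3D order {0, 2, 1} from TranposeMatmulFusion5 below is checked as {0, 2, 1, 3}, which is why that test now expects the fusion to happen.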

src/plugins/intel_gpu/tests/unit/transformations/increase_position_ids_precision_test.cpp

+46
@@ -130,3 +130,49 @@ TEST_F(TransformationTestsF, IncreasePositionIdsPrecisionWithUnsqueeze) {
     }
     comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
 }
+
+TEST_F(TransformationTestsF, IncreasePositionIdsMatmulWithoutUnsqueeze) {
+    {
+        auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1, -1 });
+        auto input_convert = std::make_shared<ov::op::v0::Convert>(input, ov::element::i32);
+        auto input_unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(input_convert, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1}));
+        auto input_convert_fp = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f16);
+        auto rotary_embd_const = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{1, 64, 1});
+
+        auto matmul = std::make_shared<ov::op::v0::MatMul>(input_convert_fp, rotary_embd_const);
+        auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{matmul, matmul}, 2);
+
+        auto cos = std::make_shared<ov::op::v0::Cos>(concat);
+        auto sin = std::make_shared<ov::op::v0::Sin>(concat);
+
+        auto rope_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape::dynamic(4));
+        auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{rope_input, cos, sin}, ov::op::internal::RoPE::Config());
+
+        model = std::make_shared<ov::Model>(ov::NodeVector{ rope }, ov::ParameterVector{ input, rope_input });
+        manager.register_pass<IncreasePositionIdsPrecision>();
+    }
+    {
+        auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1, -1 });
+        auto input_convert = std::make_shared<ov::op::v0::Convert>(input, ov::element::i32);
+        auto input_unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(input_convert, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1}));
+        auto rotary_embd_const = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{1, 64, 1});
+
+        auto input_convert_f32 = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f32);
+        auto rotary_embd_const_convert_f32 = std::make_shared<ov::op::v0::Convert>(rotary_embd_const, ov::element::f32);
+
+        auto matmul = std::make_shared<ov::op::v0::MatMul>(input_convert_f32, rotary_embd_const_convert_f32);
+        auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{matmul, matmul}, 2);
+
+        auto cos = std::make_shared<ov::op::v0::Cos>(concat);
+        auto sin = std::make_shared<ov::op::v0::Sin>(concat);
+
+        auto cos_convert = std::make_shared<ov::op::v0::Convert>(cos, ov::element::f16);
+        auto sin_convert = std::make_shared<ov::op::v0::Convert>(sin, ov::element::f16);
+
+        auto rope_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape::dynamic(4));
+        auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{rope_input, cos_convert, sin_convert}, ov::op::internal::RoPE::Config());
+
+        model_ref = std::make_shared<ov::Model>(ov::NodeVector{ rope }, ov::ParameterVector{ input, rope_input });
+    }
+    comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
+}
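The new test mirrors the pattern extension above: the original model drives the rotary embedding through a plain MatMul in f16, and the reference model shows the expected result of the pass, with both MatMul inputs converted to f32 and the Sin/Cos outputs converted back to f16 before the RoPE op.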

src/plugins/intel_gpu/tests/unit/transformations/transpose_matmul_fusion_test.cpp

+58
@@ -124,6 +124,64 @@ TEST_F(TransformationTestsF, TranposeMatmulFusion4) {
     }
 }
 
+TEST_F(TransformationTestsF, TranposeMatmulFusion5) {
+    {
+        auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic(3));
+        auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic(3));
+        auto matmul = std::make_shared<ov::op::v0::MatMul>(input_a, input_b);
+        auto tranpose_c_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 2, 1});
+        auto tranpose_c = std::make_shared<ov::op::v1::Transpose>(matmul, tranpose_c_const);
+
+        model = std::make_shared<ov::Model>(ov::NodeVector{ tranpose_c }, ov::ParameterVector{ input_a, input_b });
+
+        const auto supports_immad = false;
+        manager.register_pass<TransposeFusion>(supports_immad);
+    }
+    {
+        std::vector<int64_t> order_a = {0, 1, 2};
+        std::vector<int64_t> order_b = {0, 1, 2};
+        std::vector<int64_t> order_c = {0, 2, 1};
+        auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic(3));
+        auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic(3));
+        auto gemm = std::make_shared<ov::intel_gpu::op::Gemm>(input_a, input_b, order_a, order_b, order_c, ov::element::undefined);
+
+        model_ref = std::make_shared<ov::Model>(ov::NodeVector{ gemm }, ov::ParameterVector{ input_a, input_b });
+        comparator.enable(FunctionsComparator::ATTRIBUTES);
+    }
+}
+
+TEST_F(TransformationTestsF, TranposeMatmulFusion6) {
+    auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic(2));
+    auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic(2));
+    auto matmul = std::make_shared<ov::op::v0::MatMul>(input_a, input_b);
+    auto tranpose_c_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, {1, 0});
+    auto tranpose_c = std::make_shared<ov::op::v1::Transpose>(matmul, tranpose_c_const);
+
+    model = std::make_shared<ov::Model>(ov::NodeVector{ tranpose_c }, ov::ParameterVector{ input_a, input_b });
+
+    const auto supports_immad = false;
+    manager.register_pass<TransposeFusion>(supports_immad);
+
+    model_ref = model->clone();
+    comparator.enable(FunctionsComparator::ATTRIBUTES);
+}
+
+TEST_F(TransformationTestsF, TranposeMatmulFusion7) {
+    auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{2, 4});
+    auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{4, 2});
+    auto matmul = std::make_shared<ov::op::v0::MatMul>(input_a, input_b);
+    auto tranpose_c_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, {1, 0});
+    auto tranpose_c = std::make_shared<ov::op::v1::Transpose>(matmul, tranpose_c_const);
+
+    model = std::make_shared<ov::Model>(ov::NodeVector{ tranpose_c }, ov::ParameterVector{ input_a, input_b });
+
+    const auto supports_immad = false;
+    manager.register_pass<TransposeFusion>(supports_immad);
+
+    model_ref = model->clone();
+    comparator.enable(FunctionsComparator::ATTRIBUTES);
+}
+
 TEST_F(TransformationTestsF, TranposeMatmulFusion_Illegal_1) {
     {
         auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{10, 20});
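The three new tests cover the shape-rank fix: TranposeMatmulFusion5 checks that a 3D MatMul + Transpose (output order {0, 2, 1}) is now fused into a single intel_gpu Gemm, while TranposeMatmulFusion6 and TranposeMatmulFusion7 check that the 2D case (order {1, 0}), with dynamic and static shapes respectively, is left untouched, the reference model being a plain clone of the original.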
