Skip to content

Commit 115eb23

Browse files
[TRANSFORMATIONS] Fix RoPEFusion pattern for GLM4 model on GPU (#29459)
### Details: On GPU, the RoPEFusionChatGLM pattern uses a Constant as the stop value for StridedSlice instead of a subgraph, as it does on CPU. Cover this pattern as well. ### Tickets: - [CVS-164016](https://jira.devtools.intel.com/browse/CVS-164016) Signed-off-by: Andrii Staikov <andrii.staikov@intel.com>
1 parent 7a528c9 commit 115eb23

File tree

2 files changed

+88
-1
lines changed

2 files changed

+88
-1
lines changed

src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,8 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
635635
auto ScatterUpdate = makePattern<opset3::ScatterUpdate>({{0, 0}, {1}, seq_length, {0}}, {});
636636
auto slice_Slice_449_1d = makePattern<ov::opset8::Slice>({cos_sin_cache, {0}, seq_length, {1}, {1}});
637637
auto slice_Slice_449_2d = makePattern<ov::opset8::Slice>({cos_sin_cache, {0, 0}, ScatterUpdate, {1, 1}, {0}});
638-
auto slice_StridedSlice_449 = GenStridedSlice(cos_sin_cache, {0, 0}, ScatterUpdate, {1, 1}, 1);
638+
auto ss_stop = makePattern<opset1::Constant>({}, {});
639+
auto slice_StridedSlice_449 = GenStridedSlice(cos_sin_cache, {0, 0}, ss_stop | ScatterUpdate, {1, 1}, 1);
639640

640641
// [batch, 1, seq_length, half_rotary_dims, 2]
641642
view_Reshape_460 = makePattern<opset1::Reshape>(

src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp

+86
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "openvino/core/node_vector.hpp"
1212
#include "openvino/opsets/opset1.hpp"
1313
#include "openvino/opsets/opset3.hpp"
14+
#include "ov_ops/rms.hpp"
1415
#include "ov_ops/rotary_positional_embeddings.hpp"
1516
#include "ov_ops/type_relaxed.hpp"
1617
#include "transformations/utils/gen_pattern.hpp"
@@ -1452,3 +1453,88 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM4_PagedAttention) {
14521453
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, input1});
14531454
}
14541455
}
1456+
1457+
// Test for the ChatGLM4 PagedAttention RoPE graph in its "GPU" form: the StridedSlice over
// input1 is built with literal begin/stop/stride values ({0, 0}, {0, 1}, {1, 1}) rather than a
// computed stop input. The first scope builds the input graph (`model`), RoPEFusion is then
// registered, and the second scope builds the expected graph (`model_ref`): a single internal
// RoPE node.
// NOTE(review): the comparison of `model` against `model_ref` is presumably performed by the
// TransformationTestsF fixture — confirm against the fixture implementation.
TEST_F(TransformationTestsF, ConvertToROPE_chatGLM4_PagedAttention_GPU) {
1458+
using namespace ov;
1459+
{
1460+
disable_rt_info_check();
1461+
auto input0 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1, 4608});
1462+
auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1, 32, 2});
1463+
1464+
// Split the last axis of input0 (4608) into chunks of 4096, 256 and 256; only output 0 is used below.
auto ListUnpack = makeOP<opset1::VariadicSplit>({input0, -1, {4096, 256, 256}});
1465+
// View the 4096-wide chunk as {-1, 32, 1, 128}.
auto aten_transpose =
1466+
makeOP<opset1::Reshape>({ListUnpack->output(0), {-1, 32, 1, 128}}, {{"special_zero", false}});
1467+
// Slice with stop value 64 on the last axis; the begin/end masks are 1 on the first three axes
// (per the StridedSlice-1 spec a mask bit of 1 means that axis' begin/end value is ignored — confirm).
auto aten_slice_Slice =
1468+
makeOP<opset1::StridedSlice>({aten_transpose, {0, 0, 0, 0}, {0, 0, 0, 64}, {1, 1, 1, 1}},
1469+
{{"begin_mask", {1, 1, 1, 0}},
1470+
{"end_mask", {1, 1, 1, 0}},
1471+
{"new_axis_mask", {}},
1472+
{"shrink_axis_mask", {}},
1473+
{"ellipsis_mask", {}}});
1474+
// Regroup the sliced data so that the last axis has size 2 (32 pairs).
auto aten_reshape_Reshape =
1475+
makeOP<opset1::Reshape>({aten_slice_Slice, {0, 32, 0, 32, 2}}, {{"special_zero", true}});
1476+
// Index 0 along the last axis of the reshaped input.
auto aten_select_Gather = makeOP<opset8::Gather>({aten_reshape_Reshape, 0, -1}, {{"batch_dims", 0}});
1477+
// StridedSlice over input1 with the literal stop value {0, 1}: this is the variant the test targets.
auto aten_slice_Slice_2 = makeOP<opset1::StridedSlice>({input1, {0, 0}, {0, 1}, {1, 1}},
1478+
{{"begin_mask", {1, 0}},
1479+
{"end_mask", {1, 0}},
1480+
{"new_axis_mask", {}},
1481+
{"shrink_axis_mask", {}},
1482+
{"ellipsis_mask", {}}});
1483+
auto aten_view_Reshape =
1484+
makeOP<opset1::Reshape>({aten_slice_Slice_2, {-1, 1, 1, 32, 2}}, {{"special_zero", false}});
1485+
// Index 0 along the last axis of the reshaped input1 data.
auto aten_select_Gather_1 = makeOP<opset8::Gather>({aten_view_Reshape, 0, -1}, {{"batch_dims", 0}});
1486+
auto aten_mul_Multiply_1 =
1487+
makeOP<opset1::Multiply>({aten_select_Gather, aten_select_Gather_1}, {{"auto_broadcast", "numpy"}});
1488+
// Index 1 along the last axis of both reshaped tensors.
auto aten_select_Gather_2 = makeOP<opset8::Gather>({aten_reshape_Reshape, 1, -1}, {{"batch_dims", 0}});
1489+
auto aten_select_Gather_3 = makeOP<opset8::Gather>({aten_view_Reshape, 1, -1}, {{"batch_dims", 0}});
1490+
auto aten_mul_Multiply_2 =
1491+
makeOP<opset1::Multiply>({aten_select_Gather_2, aten_select_Gather_3}, {{"auto_broadcast", "numpy"}});
1492+
// The subtraction is expressed as Add(a, b * -1).
auto Constant_56849 = makeConst(element::f16, ov::Shape({}), {-1});
1493+
auto Multiply_56850 =
1494+
makeOP<opset1::Multiply>({aten_mul_Multiply_2, Constant_56849}, {{"auto_broadcast", "numpy"}});
1495+
auto aten_sub_Subtract =
1496+
makeOP<opset1::Add>({aten_mul_Multiply_1, Multiply_56850}, {{"auto_broadcast", "numpy"}});
1497+
auto Unsqueeze_81695 =
1498+
makeOP<opset1::Reshape>({aten_sub_Subtract, {-1, 32, 1, 32, 1}}, {{"special_zero", false}});
1499+
auto aten_mul_Multiply_3 =
1500+
makeOP<opset1::Multiply>({aten_select_Gather_2, aten_select_Gather_1}, {{"auto_broadcast", "numpy"}});
1501+
auto aten_mul_Multiply_4 =
1502+
makeOP<opset1::Multiply>({aten_select_Gather, aten_select_Gather_3}, {{"auto_broadcast", "numpy"}});
1503+
auto aten_add_Add =
1504+
makeOP<opset1::Add>({aten_mul_Multiply_3, aten_mul_Multiply_4}, {{"auto_broadcast", "numpy"}});
1505+
auto Unsqueeze_81696 = makeOP<opset1::Reshape>({aten_add_Add, {-1, 32, 1, 32, 1}}, {{"special_zero", false}});
1506+
// Join the two results along a trailing axis, then flatten back to a 64-wide last axis.
auto aten_stack = makeOP<opset1::Concat>({Unsqueeze_81695, Unsqueeze_81696}, {{"axis", -1}});
1507+
auto aten_flatten_Reshape = makeOP<opset1::Reshape>({aten_stack, {0, 32, 0, 64}}, {{"special_zero", true}});
1508+
// Second slice of the same reshaped input: begin 64, stop INT_MAX on the last axis.
auto aten_slice_Slice_3 =
1509+
makeOP<opset1::StridedSlice>({aten_transpose, {0, 0, 0, 64}, {0, 0, 0, INT_MAX}, {1, 1, 1, 1}},
1510+
{{"begin_mask", {1, 1, 1, 0}},
1511+
{"end_mask", {1, 1, 1, 0}},
1512+
{"new_axis_mask", {}},
1513+
{"shrink_axis_mask", {}},
1514+
{"ellipsis_mask", {}}});
1515+
// Concatenate the computed part with the second slice along the last axis.
auto aten_cat_Concat = makeOP<opset1::Concat>({aten_flatten_Reshape, aten_slice_Slice_3}, {{"axis", -1}});
1516+
1517+
model = std::make_shared<ov::Model>(ov::NodeVector{aten_cat_Concat}, ov::ParameterVector{input0, input1});
1518+
}
1519+
// NOTE(review): the `true` argument presumably enables 2D RoPE support (the reference below
// expects config.support_2d_rope = true) — confirm against the RoPEFusion constructor.
manager.register_pass<ov::pass::RoPEFusion>(true);
1520+
{
1521+
auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1, 4608});
1522+
auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1, 32, 2});
1523+
1524+
// Expected graph: one RoPE node taking `input` and `input1` (input1 is passed twice).
auto rope = makeOP<ov::op::internal::RoPE>({input, input1, input1},
1525+
{{"config.slice_start", 0},
1526+
{"config.slice_stop", 4096},
1527+
{"config.input_trans0213", false},
1528+
{"config.output_trans0213", false},
1529+
{"config.is_interleaved", false},
1530+
{"config.rotary_ndims", 64},
1531+
{"config.is_chatglm", true},
1532+
{"config.support_2d_rope", true},
1533+
{"config.is_qwen", false},
1534+
{"config.head_cnt", 32},
1535+
{"config.head_size", 128},
1536+
{"config.gather_position_arg_id", 0}});
1537+
1538+
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, input1});
1539+
}
1540+
}

0 commit comments

Comments
 (0)