Skip to content

Commit e463ae4

Browse files
add test
1 parent e0af233 commit e463ae4

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed

src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp

+95
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include "openvino/core/node_vector.hpp"
1212
#include "openvino/opsets/opset1.hpp"
1313
#include "openvino/opsets/opset3.hpp"
14+
#include "openvino/pass/visualize_tree.hpp"
15+
#include "ov_ops/rms.hpp"
1416
#include "ov_ops/rotary_positional_embeddings.hpp"
1517
#include "ov_ops/type_relaxed.hpp"
1618
#include "transformations/utils/gen_pattern.hpp"
@@ -1452,3 +1454,96 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM4_PagedAttention) {
14521454
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, input1});
14531455
}
14541456
}
1457+
1458+
TEST_F(TransformationTestsF, ConvertToROPE_chatGLM4_PagedAttention_GPU) {
1459+
using namespace ov;
1460+
{
1461+
disable_rt_info_check();
1462+
auto input0 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1, 4608});
1463+
auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1, 32, 2});
1464+
1465+
auto ListUnpack =
1466+
makeOP<opset1::VariadicSplit>({input0, -1, {4096, 256, 256}});
1467+
auto aten_transpose = makeOP<opset1::Reshape>(
1468+
{ListUnpack->output(0), {-1, 32, 1, 128}},
1469+
{{"special_zero", false}});
1470+
auto aten_slice_Slice =
1471+
makeOP<opset1::StridedSlice>({aten_transpose,
1472+
{0, 0, 0, 0},
1473+
{0, 0, 0, 64},
1474+
{1, 1, 1, 1}},
1475+
{{"begin_mask", {1, 1, 1, 0}},
1476+
{"end_mask", {1, 1, 1, 0}},
1477+
{"new_axis_mask", {}},
1478+
{"shrink_axis_mask", {}},
1479+
{"ellipsis_mask", {}}});
1480+
auto aten_reshape_Reshape =
1481+
makeOP<opset1::Reshape>({aten_slice_Slice, {0, 32, 0, 32, 2}}, {{"special_zero", true}});
1482+
auto aten_select_Gather = makeOP<opset8::Gather>({aten_reshape_Reshape, 0, -1}, {{"batch_dims", 0}});
1483+
auto aten_slice_Slice_2 = makeOP<opset1::StridedSlice>({input1, {0, 0}, {0, 1}, {1, 1}},
1484+
{{"begin_mask", {1, 0}},
1485+
{"end_mask", {1, 0}},
1486+
{"new_axis_mask", {}},
1487+
{"shrink_axis_mask", {}},
1488+
{"ellipsis_mask", {}}});
1489+
auto aten_view_Reshape =
1490+
makeOP<opset1::Reshape>({aten_slice_Slice_2, {-1, 1, 1, 32, 2}}, {{"special_zero", false}});
1491+
auto aten_select_Gather_1 = makeOP<opset8::Gather>({aten_view_Reshape, 0, -1}, {{"batch_dims", 0}});
1492+
auto aten_mul_Multiply_1 =
1493+
makeOP<opset1::Multiply>({aten_select_Gather, aten_select_Gather_1}, {{"auto_broadcast", "numpy"}});
1494+
auto aten_select_Gather_2 = makeOP<opset8::Gather>({aten_reshape_Reshape, 1, -1}, {{"batch_dims", 0}});
1495+
auto aten_select_Gather_3 = makeOP<opset8::Gather>({aten_view_Reshape, 1, -1}, {{"batch_dims", 0}});
1496+
auto aten_mul_Multiply_2 =
1497+
makeOP<opset1::Multiply>({aten_select_Gather_2, aten_select_Gather_3}, {{"auto_broadcast", "numpy"}});
1498+
auto Constant_56849 = makeConst(element::f16, ov::Shape({}), {-1});
1499+
auto Multiply_56850 =
1500+
makeOP<opset1::Multiply>({aten_mul_Multiply_2, Constant_56849}, {{"auto_broadcast", "numpy"}});
1501+
auto aten_sub_Subtract =
1502+
makeOP<opset1::Add>({aten_mul_Multiply_1, Multiply_56850}, {{"auto_broadcast", "numpy"}});
1503+
auto Unsqueeze_81695 =
1504+
makeOP<opset1::Reshape>({aten_sub_Subtract, {-1, 32, 1, 32, 1}}, {{"special_zero", false}});
1505+
auto aten_mul_Multiply_3 =
1506+
makeOP<opset1::Multiply>({aten_select_Gather_2, aten_select_Gather_1}, {{"auto_broadcast", "numpy"}});
1507+
auto aten_mul_Multiply_4 =
1508+
makeOP<opset1::Multiply>({aten_select_Gather, aten_select_Gather_3}, {{"auto_broadcast", "numpy"}});
1509+
auto aten_add_Add =
1510+
makeOP<opset1::Add>({aten_mul_Multiply_3, aten_mul_Multiply_4}, {{"auto_broadcast", "numpy"}});
1511+
auto Unsqueeze_81696 = makeOP<opset1::Reshape>({aten_add_Add, {-1, 32, 1, 32, 1}}, {{"special_zero", false}});
1512+
auto aten_stack = makeOP<opset1::Concat>({Unsqueeze_81695, Unsqueeze_81696}, {{"axis", -1}});
1513+
auto aten_flatten_Reshape = makeOP<opset1::Reshape>({aten_stack, {0, 32, 0, 64}}, {{"special_zero", true}});
1514+
auto aten_slice_Slice_3 =
1515+
makeOP<opset1::StridedSlice>({aten_transpose,
1516+
{0, 0, 0, 64},
1517+
{0, 0, 0, INT_MAX},
1518+
{1, 1, 1, 1}},
1519+
{{"begin_mask", {1, 1, 1, 0}},
1520+
{"end_mask", {1, 1, 1, 0}},
1521+
{"new_axis_mask", {}},
1522+
{"shrink_axis_mask", {}},
1523+
{"ellipsis_mask", {}}});
1524+
auto aten_cat_Concat = makeOP<opset1::Concat>({aten_flatten_Reshape, aten_slice_Slice_3}, {{"axis", -1}});
1525+
1526+
model = std::make_shared<ov::Model>(ov::NodeVector{aten_cat_Concat}, ov::ParameterVector{input0, input1});
1527+
}
1528+
manager.register_pass<ov::pass::RoPEFusion>(true);
1529+
{
1530+
auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1, 4608});
1531+
auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1, 32, 2});
1532+
1533+
auto rope = makeOP<ov::op::internal::RoPE>({input, input1, input1},
1534+
{{"config.slice_start", 0},
1535+
{"config.slice_stop", 4096},
1536+
{"config.input_trans0213", false},
1537+
{"config.output_trans0213", false},
1538+
{"config.is_interleaved", false},
1539+
{"config.rotary_ndims", 64},
1540+
{"config.is_chatglm", true},
1541+
{"config.support_2d_rope", true},
1542+
{"config.is_qwen", false},
1543+
{"config.head_cnt", 32},
1544+
{"config.head_size", 128},
1545+
{"config.gather_position_arg_id", 0}});
1546+
1547+
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, input1});
1548+
}
1549+
}

0 commit comments

Comments
 (0)