11
11
#include " openvino/core/node_vector.hpp"
12
12
#include " openvino/opsets/opset1.hpp"
13
13
#include " openvino/opsets/opset3.hpp"
14
+ #include " ov_ops/rms.hpp"
14
15
#include " ov_ops/rotary_positional_embeddings.hpp"
15
16
#include " ov_ops/type_relaxed.hpp"
16
17
#include " transformations/utils/gen_pattern.hpp"
@@ -1452,3 +1453,88 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM4_PagedAttention) {
1452
1453
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, input1});
1453
1454
}
1454
1455
}
1456
+
1457
+ TEST_F (TransformationTestsF, ConvertToROPE_chatGLM4_PagedAttention_GPU) {
1458
+ using namespace ov ;
1459
+ {
1460
+ disable_rt_info_check ();
1461
+ auto input0 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1 , 4608 });
1462
+ auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1 , 32 , 2 });
1463
+
1464
+ auto ListUnpack = makeOP<opset1::VariadicSplit>({input0, -1 , {4096 , 256 , 256 }});
1465
+ auto aten_transpose =
1466
+ makeOP<opset1::Reshape>({ListUnpack->output (0 ), {-1 , 32 , 1 , 128 }}, {{" special_zero" , false }});
1467
+ auto aten_slice_Slice =
1468
+ makeOP<opset1::StridedSlice>({aten_transpose, {0 , 0 , 0 , 0 }, {0 , 0 , 0 , 64 }, {1 , 1 , 1 , 1 }},
1469
+ {{" begin_mask" , {1 , 1 , 1 , 0 }},
1470
+ {" end_mask" , {1 , 1 , 1 , 0 }},
1471
+ {" new_axis_mask" , {}},
1472
+ {" shrink_axis_mask" , {}},
1473
+ {" ellipsis_mask" , {}}});
1474
+ auto aten_reshape_Reshape =
1475
+ makeOP<opset1::Reshape>({aten_slice_Slice, {0 , 32 , 0 , 32 , 2 }}, {{" special_zero" , true }});
1476
+ auto aten_select_Gather = makeOP<opset8::Gather>({aten_reshape_Reshape, 0 , -1 }, {{" batch_dims" , 0 }});
1477
+ auto aten_slice_Slice_2 = makeOP<opset1::StridedSlice>({input1, {0 , 0 }, {0 , 1 }, {1 , 1 }},
1478
+ {{" begin_mask" , {1 , 0 }},
1479
+ {" end_mask" , {1 , 0 }},
1480
+ {" new_axis_mask" , {}},
1481
+ {" shrink_axis_mask" , {}},
1482
+ {" ellipsis_mask" , {}}});
1483
+ auto aten_view_Reshape =
1484
+ makeOP<opset1::Reshape>({aten_slice_Slice_2, {-1 , 1 , 1 , 32 , 2 }}, {{" special_zero" , false }});
1485
+ auto aten_select_Gather_1 = makeOP<opset8::Gather>({aten_view_Reshape, 0 , -1 }, {{" batch_dims" , 0 }});
1486
+ auto aten_mul_Multiply_1 =
1487
+ makeOP<opset1::Multiply>({aten_select_Gather, aten_select_Gather_1}, {{" auto_broadcast" , " numpy" }});
1488
+ auto aten_select_Gather_2 = makeOP<opset8::Gather>({aten_reshape_Reshape, 1 , -1 }, {{" batch_dims" , 0 }});
1489
+ auto aten_select_Gather_3 = makeOP<opset8::Gather>({aten_view_Reshape, 1 , -1 }, {{" batch_dims" , 0 }});
1490
+ auto aten_mul_Multiply_2 =
1491
+ makeOP<opset1::Multiply>({aten_select_Gather_2, aten_select_Gather_3}, {{" auto_broadcast" , " numpy" }});
1492
+ auto Constant_56849 = makeConst (element::f16, ov::Shape ({}), {-1 });
1493
+ auto Multiply_56850 =
1494
+ makeOP<opset1::Multiply>({aten_mul_Multiply_2, Constant_56849}, {{" auto_broadcast" , " numpy" }});
1495
+ auto aten_sub_Subtract =
1496
+ makeOP<opset1::Add>({aten_mul_Multiply_1, Multiply_56850}, {{" auto_broadcast" , " numpy" }});
1497
+ auto Unsqueeze_81695 =
1498
+ makeOP<opset1::Reshape>({aten_sub_Subtract, {-1 , 32 , 1 , 32 , 1 }}, {{" special_zero" , false }});
1499
+ auto aten_mul_Multiply_3 =
1500
+ makeOP<opset1::Multiply>({aten_select_Gather_2, aten_select_Gather_1}, {{" auto_broadcast" , " numpy" }});
1501
+ auto aten_mul_Multiply_4 =
1502
+ makeOP<opset1::Multiply>({aten_select_Gather, aten_select_Gather_3}, {{" auto_broadcast" , " numpy" }});
1503
+ auto aten_add_Add =
1504
+ makeOP<opset1::Add>({aten_mul_Multiply_3, aten_mul_Multiply_4}, {{" auto_broadcast" , " numpy" }});
1505
+ auto Unsqueeze_81696 = makeOP<opset1::Reshape>({aten_add_Add, {-1 , 32 , 1 , 32 , 1 }}, {{" special_zero" , false }});
1506
+ auto aten_stack = makeOP<opset1::Concat>({Unsqueeze_81695, Unsqueeze_81696}, {{" axis" , -1 }});
1507
+ auto aten_flatten_Reshape = makeOP<opset1::Reshape>({aten_stack, {0 , 32 , 0 , 64 }}, {{" special_zero" , true }});
1508
+ auto aten_slice_Slice_3 =
1509
+ makeOP<opset1::StridedSlice>({aten_transpose, {0 , 0 , 0 , 64 }, {0 , 0 , 0 , INT_MAX}, {1 , 1 , 1 , 1 }},
1510
+ {{" begin_mask" , {1 , 1 , 1 , 0 }},
1511
+ {" end_mask" , {1 , 1 , 1 , 0 }},
1512
+ {" new_axis_mask" , {}},
1513
+ {" shrink_axis_mask" , {}},
1514
+ {" ellipsis_mask" , {}}});
1515
+ auto aten_cat_Concat = makeOP<opset1::Concat>({aten_flatten_Reshape, aten_slice_Slice_3}, {{" axis" , -1 }});
1516
+
1517
+ model = std::make_shared<ov::Model>(ov::NodeVector{aten_cat_Concat}, ov::ParameterVector{input0, input1});
1518
+ }
1519
+ manager.register_pass <ov::pass::RoPEFusion>(true );
1520
+ {
1521
+ auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1 , 4608 });
1522
+ auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1 , 32 , 2 });
1523
+
1524
+ auto rope = makeOP<ov::op::internal::RoPE>({input, input1, input1},
1525
+ {{" config.slice_start" , 0 },
1526
+ {" config.slice_stop" , 4096 },
1527
+ {" config.input_trans0213" , false },
1528
+ {" config.output_trans0213" , false },
1529
+ {" config.is_interleaved" , false },
1530
+ {" config.rotary_ndims" , 64 },
1531
+ {" config.is_chatglm" , true },
1532
+ {" config.support_2d_rope" , true },
1533
+ {" config.is_qwen" , false },
1534
+ {" config.head_cnt" , 32 },
1535
+ {" config.head_size" , 128 },
1536
+ {" config.gather_position_arg_id" , 0 }});
1537
+
1538
+ model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, input1});
1539
+ }
1540
+ }
0 commit comments