11
11
#include " openvino/core/node_vector.hpp"
12
12
#include " openvino/opsets/opset1.hpp"
13
13
#include " openvino/opsets/opset3.hpp"
14
+ #include " openvino/pass/visualize_tree.hpp"
15
+ #include " ov_ops/rms.hpp"
14
16
#include " ov_ops/rotary_positional_embeddings.hpp"
15
17
#include " ov_ops/type_relaxed.hpp"
16
18
#include " transformations/utils/gen_pattern.hpp"
@@ -1452,3 +1454,96 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM4_PagedAttention) {
1452
1454
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, input1});
1453
1455
}
1454
1456
}
1457
+
1458
+ TEST_F (TransformationTestsF, ConvertToROPE_chatGLM4_PagedAttention_GPU) {
1459
+ using namespace ov ;
1460
+ {
1461
+ disable_rt_info_check ();
1462
+ auto input0 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1 , 4608 });
1463
+ auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1 , 32 , 2 });
1464
+
1465
+ auto ListUnpack =
1466
+ makeOP<opset1::VariadicSplit>({input0, -1 , {4096 , 256 , 256 }});
1467
+ auto aten_transpose = makeOP<opset1::Reshape>(
1468
+ {ListUnpack->output (0 ), {-1 , 32 , 1 , 128 }},
1469
+ {{" special_zero" , false }});
1470
+ auto aten_slice_Slice =
1471
+ makeOP<opset1::StridedSlice>({aten_transpose,
1472
+ {0 , 0 , 0 , 0 },
1473
+ {0 , 0 , 0 , 64 },
1474
+ {1 , 1 , 1 , 1 }},
1475
+ {{" begin_mask" , {1 , 1 , 1 , 0 }},
1476
+ {" end_mask" , {1 , 1 , 1 , 0 }},
1477
+ {" new_axis_mask" , {}},
1478
+ {" shrink_axis_mask" , {}},
1479
+ {" ellipsis_mask" , {}}});
1480
+ auto aten_reshape_Reshape =
1481
+ makeOP<opset1::Reshape>({aten_slice_Slice, {0 , 32 , 0 , 32 , 2 }}, {{" special_zero" , true }});
1482
+ auto aten_select_Gather = makeOP<opset8::Gather>({aten_reshape_Reshape, 0 , -1 }, {{" batch_dims" , 0 }});
1483
+ auto aten_slice_Slice_2 = makeOP<opset1::StridedSlice>({input1, {0 , 0 }, {0 , 1 }, {1 , 1 }},
1484
+ {{" begin_mask" , {1 , 0 }},
1485
+ {" end_mask" , {1 , 0 }},
1486
+ {" new_axis_mask" , {}},
1487
+ {" shrink_axis_mask" , {}},
1488
+ {" ellipsis_mask" , {}}});
1489
+ auto aten_view_Reshape =
1490
+ makeOP<opset1::Reshape>({aten_slice_Slice_2, {-1 , 1 , 1 , 32 , 2 }}, {{" special_zero" , false }});
1491
+ auto aten_select_Gather_1 = makeOP<opset8::Gather>({aten_view_Reshape, 0 , -1 }, {{" batch_dims" , 0 }});
1492
+ auto aten_mul_Multiply_1 =
1493
+ makeOP<opset1::Multiply>({aten_select_Gather, aten_select_Gather_1}, {{" auto_broadcast" , " numpy" }});
1494
+ auto aten_select_Gather_2 = makeOP<opset8::Gather>({aten_reshape_Reshape, 1 , -1 }, {{" batch_dims" , 0 }});
1495
+ auto aten_select_Gather_3 = makeOP<opset8::Gather>({aten_view_Reshape, 1 , -1 }, {{" batch_dims" , 0 }});
1496
+ auto aten_mul_Multiply_2 =
1497
+ makeOP<opset1::Multiply>({aten_select_Gather_2, aten_select_Gather_3}, {{" auto_broadcast" , " numpy" }});
1498
+ auto Constant_56849 = makeConst (element::f16, ov::Shape ({}), {-1 });
1499
+ auto Multiply_56850 =
1500
+ makeOP<opset1::Multiply>({aten_mul_Multiply_2, Constant_56849}, {{" auto_broadcast" , " numpy" }});
1501
+ auto aten_sub_Subtract =
1502
+ makeOP<opset1::Add>({aten_mul_Multiply_1, Multiply_56850}, {{" auto_broadcast" , " numpy" }});
1503
+ auto Unsqueeze_81695 =
1504
+ makeOP<opset1::Reshape>({aten_sub_Subtract, {-1 , 32 , 1 , 32 , 1 }}, {{" special_zero" , false }});
1505
+ auto aten_mul_Multiply_3 =
1506
+ makeOP<opset1::Multiply>({aten_select_Gather_2, aten_select_Gather_1}, {{" auto_broadcast" , " numpy" }});
1507
+ auto aten_mul_Multiply_4 =
1508
+ makeOP<opset1::Multiply>({aten_select_Gather, aten_select_Gather_3}, {{" auto_broadcast" , " numpy" }});
1509
+ auto aten_add_Add =
1510
+ makeOP<opset1::Add>({aten_mul_Multiply_3, aten_mul_Multiply_4}, {{" auto_broadcast" , " numpy" }});
1511
+ auto Unsqueeze_81696 = makeOP<opset1::Reshape>({aten_add_Add, {-1 , 32 , 1 , 32 , 1 }}, {{" special_zero" , false }});
1512
+ auto aten_stack = makeOP<opset1::Concat>({Unsqueeze_81695, Unsqueeze_81696}, {{" axis" , -1 }});
1513
+ auto aten_flatten_Reshape = makeOP<opset1::Reshape>({aten_stack, {0 , 32 , 0 , 64 }}, {{" special_zero" , true }});
1514
+ auto aten_slice_Slice_3 =
1515
+ makeOP<opset1::StridedSlice>({aten_transpose,
1516
+ {0 , 0 , 0 , 64 },
1517
+ {0 , 0 , 0 , INT_MAX},
1518
+ {1 , 1 , 1 , 1 }},
1519
+ {{" begin_mask" , {1 , 1 , 1 , 0 }},
1520
+ {" end_mask" , {1 , 1 , 1 , 0 }},
1521
+ {" new_axis_mask" , {}},
1522
+ {" shrink_axis_mask" , {}},
1523
+ {" ellipsis_mask" , {}}});
1524
+ auto aten_cat_Concat = makeOP<opset1::Concat>({aten_flatten_Reshape, aten_slice_Slice_3}, {{" axis" , -1 }});
1525
+
1526
+ model = std::make_shared<ov::Model>(ov::NodeVector{aten_cat_Concat}, ov::ParameterVector{input0, input1});
1527
+ }
1528
+ manager.register_pass <ov::pass::RoPEFusion>(true );
1529
+ {
1530
+ auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1 , 4608 });
1531
+ auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::PartialShape{DYN, 1 , 32 , 2 });
1532
+
1533
+ auto rope = makeOP<ov::op::internal::RoPE>({input, input1, input1},
1534
+ {{" config.slice_start" , 0 },
1535
+ {" config.slice_stop" , 4096 },
1536
+ {" config.input_trans0213" , false },
1537
+ {" config.output_trans0213" , false },
1538
+ {" config.is_interleaved" , false },
1539
+ {" config.rotary_ndims" , 64 },
1540
+ {" config.is_chatglm" , true },
1541
+ {" config.support_2d_rope" , true },
1542
+ {" config.is_qwen" , false },
1543
+ {" config.head_cnt" , 32 },
1544
+ {" config.head_size" , 128 },
1545
+ {" config.gather_position_arg_id" , 0 }});
1546
+
1547
+ model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, input1});
1548
+ }
1549
+ }
0 commit comments