@@ -176,3 +176,53 @@ TEST_F(TransformationTestsF, IncreasePositionIdsMatmulWithoutUnsqueeze) {
176
176
}
177
177
comparator.enable (FunctionsComparator::CmpValues::ATTRIBUTES);
178
178
}
179
+
180
+ TEST_F (TransformationTestsF, IncreasePositionIdsReshapeAfterMatmul) {
181
+ {
182
+ auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1 , -1 });
183
+ auto input_convert = std::make_shared<ov::op::v0::Convert>(input, ov::element::i32);
184
+ auto input_unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(input_convert, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1 }, std::vector<int64_t >{1 }));
185
+ auto input_convert_fp = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f16);
186
+ auto rotary_embd_const = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{1 , 64 , 1 });
187
+ auto reshape_dims = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{3 });
188
+
189
+ auto matmul = std::make_shared<ov::op::v0::MatMul>(input_convert_fp, rotary_embd_const);
190
+ auto reshape = std::make_shared<ov::op::v1::Reshape>(matmul, reshape_dims, true );
191
+ auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{reshape, reshape}, 2 );
192
+
193
+ auto cos = std::make_shared<ov::op::v0::Cos>(concat);
194
+ auto sin = std::make_shared<ov::op::v0::Sin>(concat);
195
+
196
+ auto rope_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape::dynamic (4 ));
197
+ auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{rope_input, cos , sin }, ov::op::internal::RoPE::Config ());
198
+
199
+ model = std::make_shared<ov::Model>(ov::NodeVector{ rope }, ov::ParameterVector{ input, rope_input, reshape_dims });
200
+ manager.register_pass <IncreasePositionIdsPrecision>();
201
+ }
202
+ {
203
+ auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1 , -1 });
204
+ auto input_convert = std::make_shared<ov::op::v0::Convert>(input, ov::element::i32);
205
+ auto input_unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(input_convert, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1 }, std::vector<int64_t >{1 }));
206
+ auto rotary_embd_const = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{1 , 64 , 1 });
207
+ auto reshape_dims = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{3 });
208
+
209
+ auto input_convert_f32 = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f32);
210
+ auto rotary_embd_const_convert_f32 = std::make_shared<ov::op::v0::Convert>(rotary_embd_const, ov::element::f32);
211
+
212
+ auto matmul = std::make_shared<ov::op::v0::MatMul>(input_convert_f32, rotary_embd_const_convert_f32);
213
+ auto reshape = std::make_shared<ov::op::v1::Reshape>(matmul, reshape_dims, true );
214
+ auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{reshape, reshape}, 2 );
215
+
216
+ auto cos = std::make_shared<ov::op::v0::Cos>(concat);
217
+ auto sin = std::make_shared<ov::op::v0::Sin>(concat);
218
+
219
+ auto cos_convert = std::make_shared<ov::op::v0::Convert>(cos , ov::element::f16);
220
+ auto sin_convert = std::make_shared<ov::op::v0::Convert>(sin , ov::element::f16);
221
+
222
+ auto rope_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape::dynamic (4 ));
223
+ auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{rope_input, cos_convert, sin_convert}, ov::op::internal::RoPE::Config ());
224
+
225
+ model_ref = std::make_shared<ov::Model>(ov::NodeVector{ rope }, ov::ParameterVector{ input, rope_input, reshape_dims });
226
+ }
227
+ comparator.enable (FunctionsComparator::CmpValues::ATTRIBUTES);
228
+ }
0 commit comments