|
31 | 31 | using namespace ov;
|
32 | 32 | using namespace testing;
|
33 | 33 |
|
34 |
| -const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition( |
35 |
| - std::shared_ptr<ov::Node> query, |
36 |
| - std::shared_ptr<ov::Node> key, |
37 |
| - std::shared_ptr<ov::Node> value, |
38 |
| - std::shared_ptr<ov::Node> attention_mask, |
39 |
| - std::shared_ptr<ov::Node> scale, |
40 |
| - bool casual); |
| 34 | +const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition(std::shared_ptr<ov::Node> query, |
| 35 | + std::shared_ptr<ov::Node> key, |
| 36 | + std::shared_ptr<ov::Node> value, |
| 37 | + std::shared_ptr<ov::Node> attention_mask, |
| 38 | + std::shared_ptr<ov::Node> scale, |
| 39 | + bool casual); |
41 | 40 |
|
42 | 41 | TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionStaticBasic) {
|
43 | 42 | const PartialShape query_shape{1, 32, 32};
|
@@ -187,13 +186,12 @@ TEST_F(TransformationTestsF, ScaledDotProductAttentionDecompositionDynamic) {
|
187 | 186 | }
|
188 | 187 | }
|
189 | 188 |
|
190 |
| -const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition( |
191 |
| - std::shared_ptr<ov::Node> query, |
192 |
| - std::shared_ptr<ov::Node> key, |
193 |
| - std::shared_ptr<ov::Node> value, |
194 |
| - std::shared_ptr<ov::Node> attention_mask, |
195 |
| - std::shared_ptr<ov::Node> scale, |
196 |
| - bool casual) { |
| 189 | +const std::shared_ptr<ov::Node> scaled_dot_product_attention_decomposition(std::shared_ptr<ov::Node> query, |
| 190 | + std::shared_ptr<ov::Node> key, |
| 191 | + std::shared_ptr<ov::Node> value, |
| 192 | + std::shared_ptr<ov::Node> attention_mask, |
| 193 | + std::shared_ptr<ov::Node> scale, |
| 194 | + bool casual) { |
197 | 195 | const auto q_shape = std::make_shared<ov::op::v3::ShapeOf>(query, element::i32);
|
198 | 196 | const auto k_shape = std::make_shared<ov::op::v3::ShapeOf>(key, element::i32);
|
199 | 197 | const auto minus_one = ov::op::v0::Constant::create(element::i32, Shape{}, {-1});
|
|
0 commit comments