diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp index 397746c75bb84d..bdd1bee0671bb8 100644 --- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp @@ -72,8 +72,14 @@ ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output& p auto p_neg_const = wrap_type(); auto p_neg_mul = wrap_type({p_current_len, p_neg_const}); + + // For now, it has always been a constant, but this may change in the future. + // In case of model being in FP16, there will be a decompressing subgraph: + // i.e. Constant -> Convert -> Slice + // It hasn't been observed yet, but, theoretically, there can also be a + // dequantizing subgraph, so it's going to be any_input() here. + auto p_rotary_emb_sincos = pattern::any_input(); // the rotary_emb_cos/rotary_emb_sin are sliced by the total length [1,..4096,1,128] - auto p_rotary_emb_sincos = wrap_type(); auto p_slice_1 = wrap_type({p_rotary_emb_sincos, _const(), p_opt_reshape, _const(), _const()}); auto p_slice_2 = wrap_type({p_slice_1, p_neg_mul, _const(), _const(), _const()}); diff --git a/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp b/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp index 840309993c939a..7075d7e2a5e31c 100644 --- a/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp +++ b/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp @@ -33,6 +33,8 @@ #include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp" #include "transformations/utils/gen_pattern.hpp" #include "transformations/utils/print_model.hpp" +#include "openvino/pass/visualize_tree.hpp" +#include 
"transformations/convert_precision.hpp" using namespace ov; using namespace std; @@ -184,8 +186,12 @@ class Qwen7bChatSDPA { static std::shared_ptr gen_rope_emb_sin(const std::shared_ptr& total_seq_len, const std::shared_ptr& neg_mul, - std::shared_ptr& head_size) { - auto sin = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE); + std::shared_ptr& head_size, + element::Type model_precision) { + auto sin = makeConst(model_precision, {1, 4096, 1, 128}, MOCK_VALUE); + if (model_precision != element::f32) { + sin = makeOP({sin}, {dest_type_f32}); + } auto sliced_sin_by_total = makeOP({sin, {0}, total_seq_len, {1}, {1}}); auto rotary_emb_sin_shape = makeOP({sliced_sin_by_total}, {{"output_type", "i64"}}); head_size = makeOP({rotary_emb_sin_shape, {3}, 0}, {{"batch_dims", 0}}); @@ -193,8 +199,12 @@ class Qwen7bChatSDPA { } static std::shared_ptr gen_rope_emb_cos(const std::shared_ptr& total_seq_len, - const std::shared_ptr& neg_mul) { - auto cos = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE); + const std::shared_ptr& neg_mul, + element::Type model_precision) { + auto cos = makeConst(model_precision, {1, 4096, 1, 128}, MOCK_VALUE); + if (model_precision != element::f32) { + cos = makeOP({cos}, {dest_type_f32}); + } auto sliced_cos_by_total = makeOP({cos, {0}, total_seq_len, {1}, {1}}); return makeOP({sliced_cos_by_total, neg_mul, {LLONG_MAX}, {1}, {1}}); } @@ -341,8 +351,12 @@ class Qwen7bChatPA { static std::shared_ptr gen_rope_emb_sin(const std::shared_ptr& max_context_len, const std::shared_ptr& position_ids, - std::shared_ptr& head_size) { - auto sin = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE); + std::shared_ptr& head_size, + element::Type model_precision) { + auto sin = makeConst(model_precision, {1, 4096, 1, 128}, MOCK_VALUE); + if (model_precision != element::f32) { + sin = makeOP({sin}, {dest_type_f32}); + } auto slice_sin = makeOP({sin, position_ids, 1}, {{"batch_dims", 0}}); auto slice = makeOP({sin, {0}, max_context_len, {1}, {1}}); 
@@ -353,8 +367,12 @@ class Qwen7bChatPA { } static std::shared_ptr gen_rope_emb_cos(const std::shared_ptr& max_context_len, - const std::shared_ptr& position_ids) { - auto cos = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE); + const std::shared_ptr& position_ids, + element::Type model_precision) { + auto cos = makeConst(model_precision, {1, 4096, 1, 128}, MOCK_VALUE); + if (model_precision != element::f32) { + cos = makeOP({cos}, {dest_type_f32}); + } auto slice = makeOP({cos, position_ids, 1}, {{"batch_dims", 0}}); return makeOP({slice, {-1, 1, 1, 128}}, {{"special_zero", false}}); } @@ -423,8 +441,12 @@ class Qwen7bChatPA { } // namespace -TEST_F(TransformationTestsF, SDPAToPA_Qwen) { +class SDPAToPATest : public TransformationTestsF, public ::testing::WithParamInterface {}; + +TEST_P(SDPAToPATest, SDPAToPA_Qwen) { + const auto model_precision = GetParam(); { + // Inputs to SDPA transformer: auto beam_idx = makeOP({}, {{"shape", PartialShape{DYN}}, el_type_i64}); auto position_ids = makeOP({}, {{"shape", PartialShape{DYN, DYN}}, el_type_i64}); @@ -453,8 +475,8 @@ TEST_F(TransformationTestsF, SDPAToPA_Qwen) { // RoPE emb sin/cos init: auto neg_cur_seq_len = Qwen7bChatSDPA::neg_mul(current_seq_len); auto head_size = shared_ptr(); - auto rope_emb_sin = Qwen7bChatSDPA::gen_rope_emb_sin(total_seq_len, neg_cur_seq_len, head_size); - auto rope_emb_cos = Qwen7bChatSDPA::gen_rope_emb_cos(total_seq_len, neg_cur_seq_len); + auto rope_emb_sin = Qwen7bChatSDPA::gen_rope_emb_sin(total_seq_len, neg_cur_seq_len, head_size, model_precision); + auto rope_emb_cos = Qwen7bChatSDPA::gen_rope_emb_cos(total_seq_len, neg_cur_seq_len, model_precision); // RoPE for Q,K inputs: auto rope_q = Qwen7bChatSDPA::gen_rope(QKV::Q, qkv_proj, head_size, rope_emb_sin, rope_emb_cos); @@ -477,6 +499,7 @@ TEST_F(TransformationTestsF, SDPAToPA_Qwen) { auto res = makeOP({sdpa}); model = std::make_shared(OutputVector{res}, params); + manager.register_pass("check.svg"); manager.register_pass(); } 
@@ -513,8 +536,8 @@ TEST_F(TransformationTestsF, SDPAToPA_Qwen) { // RoPE emb sin/cos init: auto head_size = shared_ptr(); - auto rope_emb_sin = Qwen7bChatPA::gen_rope_emb_sin(max_context_len_aligned, position_ids_aligned, head_size); - auto rope_emb_cos = Qwen7bChatPA::gen_rope_emb_cos(max_context_len_aligned, position_ids_aligned); + auto rope_emb_sin = Qwen7bChatPA::gen_rope_emb_sin(max_context_len_aligned, position_ids_aligned, head_size, model_precision); + auto rope_emb_cos = Qwen7bChatPA::gen_rope_emb_cos(max_context_len_aligned, position_ids_aligned, model_precision); // rope Q, K: auto rope_Q = Qwen7bChatPA::gen_rope(QKV::Q, qkv_proj, head_size, rope_emb_sin, rope_emb_cos); @@ -616,3 +639,9 @@ TEST_F(TransformationTestsF, SDPAToPA_TotalSequenceLengthPatternQwen) { disable_result_friendly_names_check(); disable_rt_info_check(); } + +const std::vector element_types = {element::f16, element::f32}; + +INSTANTIATE_TEST_SUITE_P(SDPAToPATest_Conversion, + SDPAToPATest, + testing::ValuesIn(element_types)); \ No newline at end of file