Skip to content

Commit

Permalink
fix_qwen_optional_convert
Browse files Browse the repository at this point in the history
  • Loading branch information
CuriousPanCake committed Jan 16, 2025
1 parent 2b0f127 commit f8611e9
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,14 @@ ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output<Node>& p

auto p_neg_const = wrap_type<v0::Constant>();
auto p_neg_mul = wrap_type<v1::Multiply>({p_current_len, p_neg_const});

// So far this input has always been a Constant, but that may change in the future.
// If the model is in FP16, there will be a decompression subgraph,
// i.e. Constant -> Convert -> Slice.
// It hasn't been observed yet, but theoretically there can also be a
// dequantization subgraph, so any_input() is used here.
auto p_rotary_emb_sincos = pattern::any_input();
// the rotary_emb_cos/rotary_emb_sin are sliced by the total length [1,..4096,1,128]
auto p_rotary_emb_sincos = wrap_type<v0::Constant>();
auto p_slice_1 = wrap_type<v8::Slice>({p_rotary_emb_sincos, _const(), p_opt_reshape, _const(), _const()});
auto p_slice_2 = wrap_type<v8::Slice>({p_slice_1, p_neg_mul, _const(), _const(), _const()});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
#include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp"
#include "transformations/utils/gen_pattern.hpp"
#include "transformations/utils/print_model.hpp"
#include "openvino/pass/visualize_tree.hpp"
#include "transformations/convert_precision.hpp"

using namespace ov;
using namespace std;
Expand Down Expand Up @@ -184,17 +186,25 @@ class Qwen7bChatSDPA {

// Builds the mock rotary-embedding "sin" branch used by the SDPA-side test graph.
//
// A {1, 4096, 1, 128} constant filled with MOCK_VALUE is created in `model_precision`;
// when that precision is not f32, a Convert (dest_type_f32) is appended so the rest of
// the subgraph consumes the converted value. The table is first passed through a Slice
// that takes {0} and `total_seq_len` as its bound inputs, and the returned node is a
// second Slice over that result that takes `neg_mul` and {LLONG_MAX} as its bound inputs
// (presumably start/stop — confirm against the Slice-8 input order).
//
// `head_size` is an output parameter: it is overwritten with a Gather (index {3}) over
// the i64 ShapeOf of the first slice.
//
// NOTE(review): this listing is a diff rendered without +/- markers. The first
// `head_size` parameter line and the `element::f32` makeConst line appear to be the
// pre-change versions of the lines that immediately follow them — confirm against the
// actual file before relying on this text as compilable code.
static std::shared_ptr<Node> gen_rope_emb_sin(const std::shared_ptr<Node>& total_seq_len,
const std::shared_ptr<Node>& neg_mul,
std::shared_ptr<Node>& head_size) {
auto sin = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE);
std::shared_ptr<Node>& head_size,
element::Type model_precision) {
auto sin = makeConst(model_precision, {1, 4096, 1, 128}, MOCK_VALUE);
// Non-f32 tables get an explicit Convert, mimicking a decompression subgraph.
if (model_precision != element::f32) {
sin = makeOP<v0::Convert>({sin}, {dest_type_f32});
}
auto sliced_sin_by_total = makeOP<v8::Slice>({sin, {0}, total_seq_len, {1}, {1}});
auto rotary_emb_sin_shape = makeOP<v3::ShapeOf>({sliced_sin_by_total}, {{"output_type", "i64"}});
head_size = makeOP<v8::Gather>({rotary_emb_sin_shape, {3}, 0}, {{"batch_dims", 0}});
return makeOP<v8::Slice>({sliced_sin_by_total, neg_mul, {LLONG_MAX}, {1}, {1}});
}

// Builds the mock rotary-embedding "cos" branch used by the SDPA-side test graph.
//
// A {1, 4096, 1, 128} constant filled with MOCK_VALUE is created in `model_precision`;
// when that precision is not f32, a Convert (dest_type_f32) is appended. The table is
// passed through a Slice that takes {0} and `total_seq_len` as its bound inputs, and the
// returned node is a second Slice over that result taking `neg_mul` and {LLONG_MAX}
// (presumably start/stop — confirm against the Slice-8 input order).
//
// NOTE(review): this listing is a diff rendered without +/- markers. The first
// `neg_mul) {` parameter line and the `element::f32` makeConst line appear to be the
// pre-change versions of the lines that immediately follow them — confirm against the
// actual file.
static std::shared_ptr<Node> gen_rope_emb_cos(const std::shared_ptr<Node>& total_seq_len,
const std::shared_ptr<Node>& neg_mul) {
auto cos = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE);
const std::shared_ptr<Node>& neg_mul,
element::Type model_precision) {
auto cos = makeConst(model_precision, {1, 4096, 1, 128}, MOCK_VALUE);
// Non-f32 tables get an explicit Convert, mimicking a decompression subgraph.
if (model_precision != element::f32) {
cos = makeOP<v0::Convert>({cos}, {dest_type_f32});
}
auto sliced_cos_by_total = makeOP<v8::Slice>({cos, {0}, total_seq_len, {1}, {1}});
return makeOP<v8::Slice>({sliced_cos_by_total, neg_mul, {LLONG_MAX}, {1}, {1}});
}
Expand Down Expand Up @@ -341,8 +351,12 @@ class Qwen7bChatPA {

static std::shared_ptr<Node> gen_rope_emb_sin(const std::shared_ptr<Node>& max_context_len,
const std::shared_ptr<Node>& position_ids,
std::shared_ptr<Node>& head_size) {
auto sin = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE);
std::shared_ptr<Node>& head_size,
element::Type model_precision) {
auto sin = makeConst(model_precision, {1, 4096, 1, 128}, MOCK_VALUE);
if (model_precision != element::f32) {
sin = makeOP<v0::Convert>({sin}, {dest_type_f32});
}
auto slice_sin = makeOP<v8::Gather>({sin, position_ids, 1}, {{"batch_dims", 0}});

auto slice = makeOP<v8::Slice>({sin, {0}, max_context_len, {1}, {1}});
Expand All @@ -353,8 +367,12 @@ class Qwen7bChatPA {
}

// Builds the mock rotary-embedding "cos" branch used by the PagedAttention-side
// (expected) test graph.
//
// A {1, 4096, 1, 128} constant filled with MOCK_VALUE is created in `model_precision`;
// when that precision is not f32, a Convert (dest_type_f32) is appended. The table is
// then indexed with a Gather driven by `position_ids` (third Gather input is 1) and the
// result is reshaped with the target shape {-1, 1, 1, 128}.
//
// `max_context_len` is not referenced anywhere in the visible body.
//
// NOTE(review): this listing is a diff rendered without +/- markers. The first
// `position_ids) {` parameter line and the `element::f32` makeConst line appear to be
// the pre-change versions of the lines that immediately follow them — confirm against
// the actual file.
static std::shared_ptr<Node> gen_rope_emb_cos(const std::shared_ptr<Node>& max_context_len,
const std::shared_ptr<Node>& position_ids) {
auto cos = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE);
const std::shared_ptr<Node>& position_ids,
element::Type model_precision) {
auto cos = makeConst(model_precision, {1, 4096, 1, 128}, MOCK_VALUE);
// Non-f32 tables get an explicit Convert, mimicking a decompression subgraph.
if (model_precision != element::f32) {
cos = makeOP<v0::Convert>({cos}, {dest_type_f32});
}
auto slice = makeOP<v8::Gather>({cos, position_ids, 1}, {{"batch_dims", 0}});
return makeOP<v1::Reshape>({slice, {-1, 1, 1, 128}}, {{"special_zero", false}});
}
Expand Down Expand Up @@ -423,8 +441,12 @@ class Qwen7bChatPA {

} // namespace

TEST_F(TransformationTestsF, SDPAToPA_Qwen) {
// Value-parameterized fixture: combines the TransformationTestsF base with a
// GoogleTest parameter interface carrying an element::Type, which tests read via GetParam().
class SDPAToPATest : public TransformationTestsF, public ::testing::WithParamInterface<element::Type> {};

TEST_P(SDPAToPATest, SDPAToPA_Qwen) {
const auto model_precision = GetParam();
{

// Inputs to SDPA transformer:
auto beam_idx = makeOP<v0::Parameter>({}, {{"shape", PartialShape{DYN}}, el_type_i64});
auto position_ids = makeOP<v0::Parameter>({}, {{"shape", PartialShape{DYN, DYN}}, el_type_i64});
Expand Down Expand Up @@ -453,8 +475,8 @@ TEST_F(TransformationTestsF, SDPAToPA_Qwen) {
// RoPE emb sin/cos init:
auto neg_cur_seq_len = Qwen7bChatSDPA::neg_mul(current_seq_len);
auto head_size = shared_ptr<Node>();
auto rope_emb_sin = Qwen7bChatSDPA::gen_rope_emb_sin(total_seq_len, neg_cur_seq_len, head_size);
auto rope_emb_cos = Qwen7bChatSDPA::gen_rope_emb_cos(total_seq_len, neg_cur_seq_len);
auto rope_emb_sin = Qwen7bChatSDPA::gen_rope_emb_sin(total_seq_len, neg_cur_seq_len, head_size, model_precision);
auto rope_emb_cos = Qwen7bChatSDPA::gen_rope_emb_cos(total_seq_len, neg_cur_seq_len, model_precision);

// RoPE for Q,K inputs:
auto rope_q = Qwen7bChatSDPA::gen_rope(QKV::Q, qkv_proj, head_size, rope_emb_sin, rope_emb_cos);
Expand All @@ -477,6 +499,7 @@ TEST_F(TransformationTestsF, SDPAToPA_Qwen) {
auto res = makeOP<v0::Result>({sdpa});

model = std::make_shared<ov::Model>(OutputVector{res}, params);
manager.register_pass<ov::pass::VisualizeTree>("check.svg");
manager.register_pass<ov::pass::SDPAToPagedAttention>();
}

Expand Down Expand Up @@ -513,8 +536,8 @@ TEST_F(TransformationTestsF, SDPAToPA_Qwen) {

// RoPE emb sin/cos init:
auto head_size = shared_ptr<Node>();
auto rope_emb_sin = Qwen7bChatPA::gen_rope_emb_sin(max_context_len_aligned, position_ids_aligned, head_size);
auto rope_emb_cos = Qwen7bChatPA::gen_rope_emb_cos(max_context_len_aligned, position_ids_aligned);
auto rope_emb_sin = Qwen7bChatPA::gen_rope_emb_sin(max_context_len_aligned, position_ids_aligned, head_size, model_precision);
auto rope_emb_cos = Qwen7bChatPA::gen_rope_emb_cos(max_context_len_aligned, position_ids_aligned, model_precision);

// rope Q, K:
auto rope_Q = Qwen7bChatPA::gen_rope(QKV::Q, qkv_proj, head_size, rope_emb_sin, rope_emb_cos);
Expand Down Expand Up @@ -616,3 +639,9 @@ TEST_F(TransformationTestsF, SDPAToPA_TotalSequenceLengthPatternQwen) {
disable_result_friendly_names_check();
disable_rt_info_check();
}

// Element types fed to the SDPAToPATest suite as its test parameter.
const std::vector<ov::element::Type> element_types{element::f16, element::f32};

// One instantiation of every SDPAToPATest case per entry in element_types.
INSTANTIATE_TEST_SUITE_P(SDPAToPATest_Conversion, SDPAToPATest, testing::ValuesIn(element_types));

0 comments on commit f8611e9

Please sign in to comment.