diff --git a/.github/workflows/causal_lm_cpp.yml b/.github/workflows/causal_lm_cpp.yml
index 5fc4617f2c..b967ed9e1a 100644
--- a/.github/workflows/causal_lm_cpp.yml
+++ b/.github/workflows/causal_lm_cpp.yml
@@ -836,6 +836,36 @@ jobs:
<<< $'Who drew this painting?\nWhen did the painter live?'
timeout-minutes: 4
+ visual_language_chat_sample-ubuntu-qwen2vl:
+ runs-on: ubuntu-22.04-16-cores
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ submodules: recursive
+ - uses: actions/setup-python@v4
+ with:
+ python-version: 3.11
+ - uses: ./.github/actions/install_openvino
+ with:
+ ov_link: ${{ env.l_u22_ov_link }}
+ - uses: ./.github/actions/build_app
+ with:
+ build_target: 'visual_language_chat py_openvino_genai'
+ - uses: ./.github/actions/install_python_deps
+ - name: Download and convert Qwen2VL model
+ run: |
+ source ./ov/setupvars.sh
+ optimum-cli export openvino --model Qwen/Qwen2-VL-2B-Instruct ./qwen2_vl_2b_ov/ --trust-remote-code
+ - name: Download images
+ run: |
+ wget https://llava-vl.github.io/static/images/monalisa.jpg
+ - name: Run visual_language_chat C++ sample - Qwen2VL
+ run: >
+ source ./ov/setupvars.sh
+ && ./build/samples/cpp/visual_language_chat/visual_language_chat ./qwen2_vl_2b_ov/ monalisa.jpg
+ <<< $'Who drew this painting?\nWhen did the painter live?'
+ timeout-minutes: 4
+
cpp-continuous-batching-ubuntu:
runs-on: ubuntu-20.04-8-cores
defaults:
diff --git a/SUPPORTED_MODELS.md b/SUPPORTED_MODELS.md
index 9487c715d9..f79234489d 100644
--- a/SUPPORTED_MODELS.md
+++ b/SUPPORTED_MODELS.md
@@ -362,6 +362,17 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
+ <tr>
+ <td><code>Qwen2-VL</code></td>
+ <td>Qwen2-VL</td>
+ <td>Not supported</td>
+ <td>
+ <ul>
+ <li><a href="https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct"><code>Qwen/Qwen2-VL-2B-Instruct</code></a></li>
+ </ul>
+ </td>
+ </tr>
diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp
index 424349f5aa..22fb0b1c02 100644
--- a/src/cpp/src/lm_encoding.cpp
+++ b/src/cpp/src/lm_encoding.cpp
@@ -29,6 +29,23 @@ void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask) {
}
}
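+// Qwen2-VL style 3D (multimodal rotary) position ids of shape [3, batch, 1]:
+// during generation every new token gets the same value, (sequence_length - 1) + rope_delta,
+// replicated across the temporal/height/width dimensions.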
+void update_3d_position_ids(ov::Tensor&& position_ids, const ov::Tensor& attention_mask, const int64_t rope_delta) {
+ const size_t batch_size = attention_mask.get_shape().at(0);
+ const size_t sequence_length = attention_mask.get_shape().at(1);
+ const size_t thw_dim_size = 3;
+
+ position_ids.set_shape({thw_dim_size, batch_size, 1});
+ int64_t* position_ids_data = position_ids.data<int64_t>();
+
+ int64_t pos_id = static_cast<int64_t>(sequence_length) - 1 + rope_delta;
+
+ for (size_t batch = 0; batch < batch_size; batch++) {
+ for (size_t dim = 0; dim < thw_dim_size; ++dim) {
+ position_ids_data[dim * batch_size + batch] = pos_id;
+ }
+ }
+}
+
void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector<int32_t> next_beams) {
ov::Tensor original_mask{ov::element::i64, attention_mask.get_shape()};
ov::Shape original_shape = original_mask.get_shape();
@@ -58,7 +75,8 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
Sampler& sampler,
std::vector<SequenceGroup::Ptr> sequence_groups,
std::optional<ov::Tensor> position_ids,
- std::optional<EmbeddingsModel> m_embedding
+ std::optional<EmbeddingsModel> m_embedding,
+ std::optional<int64_t> rope_delta
) {
std::vector<GenerationHandle> generations;
for (SequenceGroup::Ptr sequence_group : sequence_groups) {
@@ -196,7 +214,11 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
update_attention_mask_with_beams(m_llm.get_tensor("attention_mask"), next_beams);
if (position_ids.has_value()) {
- update_position_ids(m_llm.get_tensor("position_ids"), m_llm.get_tensor("attention_mask"));
+ if (position_ids->get_shape().size() == 3 && rope_delta.has_value()) {
+ update_3d_position_ids(m_llm.get_tensor("position_ids"), m_llm.get_tensor("attention_mask"), rope_delta.value());
+ } else {
+ update_position_ids(m_llm.get_tensor("position_ids"), m_llm.get_tensor("attention_mask"));
+ }
}
m_llm.set_tensor("beam_idx", ov::Tensor{ov::element::i32, {total_num_tokens}, next_beams.data()});
diff --git a/src/cpp/src/lm_encoding.hpp b/src/cpp/src/lm_encoding.hpp
index c31cffb9bc..56f6db5227 100644
--- a/src/cpp/src/lm_encoding.hpp
+++ b/src/cpp/src/lm_encoding.hpp
@@ -10,7 +10,7 @@ namespace genai {
std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask,
const std::shared_ptr<StreamerBase>& streamer_ptr, Sampler& sampler, std::vector<SequenceGroup::Ptr> sequence_groups,
- std::optional<ov::Tensor> position_ids, std::optional<EmbeddingsModel> m_embedding);
+ std::optional<ov::Tensor> position_ids, std::optional<EmbeddingsModel> m_embedding, std::optional<int64_t> rope_delta = std::nullopt);
}
}
diff --git a/src/cpp/src/tokenizer.cpp b/src/cpp/src/tokenizer.cpp
index 2f18e87839..b22f103577 100644
--- a/src/cpp/src/tokenizer.cpp
+++ b/src/cpp/src/tokenizer.cpp
@@ -575,13 +575,18 @@ class Tokenizer::TokenizerImpl {
{"slice", slice_callable},
};
+ std::string result;
try {
- return tpl.RenderAsString(params).value();
+ result = tpl.RenderAsString(params).value();
} catch (const std::exception& error) {
OPENVINO_THROW("Chat template for the current model is not supported by Jinja2Cpp. "
"Please apply template manually to your prompt before calling generate. "
"For example: user{user_prompt}model");
}
+ OPENVINO_ASSERT(!result.empty(), "Applied chat template resulted in an empty string. "
+ "Please check the chat template or apply template manually to your prompt before calling generate. "
+ "For example: <start_of_turn>user{user_prompt}<end_of_turn><start_of_turn>model");
+ return result;
}
void set_chat_template(const std::string& chat_template) {
diff --git a/src/cpp/src/visual_language/inputs_embedder.cpp b/src/cpp/src/visual_language/inputs_embedder.cpp
index 9fecd037c6..4f3812862c 100644
--- a/src/cpp/src/visual_language/inputs_embedder.cpp
+++ b/src/cpp/src/visual_language/inputs_embedder.cpp
@@ -52,6 +52,12 @@ class InputsEmbedder::IInputsEmbedder {
public:
virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) = 0;
+ virtual std::pair<ov::Tensor, std::optional<int64_t>> get_position_ids(const size_t inputs_embeds_size, const size_t history_size) {
+ ov::Tensor position_ids = ov::Tensor{ov::element::i64, { 1, inputs_embeds_size }};
+ std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), history_size);
+ return {position_ids, std::nullopt};
+ }
+
EmbeddingsModel get_embedding_model() const {
return m_embedding;
}
@@ -1157,6 +1163,408 @@ class InputsEmbedderInternVLChat : public InputsEmbedder::IInputsEmbedder {
}
};
+class InputsEmbedderQwen2VL : public InputsEmbedder::IInputsEmbedder {
+ // A model for merging image embeddings (hidden states), rotary_pos_emb and attention_mask.
+ // Inputs:
+ // - hidden_states: [N, embed_dim]
+ // - rotary_pos_emb: [?, 40]
+ // - attention_mask: [1, ?, ?]
+ // Output: [N, hidden_size]
+ ov::InferRequest m_vision_embeddings_merger;
+
+ ov::Tensor m_position_ids;
+ int64_t m_rope_delta = 0;
+
+public:
+ InputsEmbedderQwen2VL(
+ const VLMConfig& vlm_config,
+ const std::filesystem::path& model_dir,
+ const std::string& device,
+ const ov::AnyMap device_config) :
+ IInputsEmbedder(vlm_config, model_dir, device, device_config) {
+ auto compiled_model = utils::singleton_core().compile_model(model_dir / "openvino_vision_embeddings_merger_model.xml", device, device_config);
+ ov::genai::utils::print_compiled_model_properties(compiled_model, "VLM vision embeddings merger model");
+ m_vision_embeddings_merger = compiled_model.create_infer_request();
+ }
+
+ InputsEmbedderQwen2VL(
+ const VLMConfig& vlm_config,
+ const ModelsMap& models_map,
+ const Tokenizer& tokenizer,
+ const std::filesystem::path& config_dir_path,
+ const std::string& device,
+ const ov::AnyMap device_config) :
+ IInputsEmbedder(vlm_config, models_map, tokenizer, config_dir_path, device, device_config) {
+ m_vision_embeddings_merger = utils::singleton_core().compile_model(
+ get_model_weights_pair(models_map, "vision_embeddings_merger").first,
+ get_model_weights_pair(models_map, "vision_embeddings_merger").second,
+ device,
+ device_config
+ ).create_infer_request();
+ }
+
+ virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) override {
+ std::string formatted_prompt;
+
+ std::vector<ov::Tensor> single_images = to_single_image_tensors(images);
+ std::vector<ov::Tensor> image_embeds;
+ std::vector<std::array<size_t, 3>> images_grid_thw;
+ image_embeds.reserve(single_images.size());
+ images_grid_thw.reserve(single_images.size());
+
+ for (const auto& image : single_images) {
+ EncodedImage encoded_image = m_vision_encoder.encode(image);
+ ov::Tensor single_image_embeds = encoded_image.resized_source;
+ image_embeds.push_back(std::move(single_image_embeds));
+
+ size_t grid_t = 1;
+ size_t grid_h = encoded_image.resized_source_size.height;
+ size_t grid_w = encoded_image.resized_source_size.width;
+ images_grid_thw.push_back({grid_t, grid_h, grid_w});
+
+ size_t merge_length = std::pow(m_vision_encoder.m_processor_config.merge_size, 2);
+ size_t num_image_pad_tokens = grid_t * grid_h * grid_w / merge_length;
+
+ formatted_prompt += m_vlm_config.vision_start_token;
+ for (int i = 0; i < num_image_pad_tokens; i++) {
+ formatted_prompt += m_vlm_config.image_pad_token;
+ }
+ formatted_prompt += m_vlm_config.vision_end_token;
+ }
+ formatted_prompt += prompt;
+
+ // Adapted from Qwen/Qwen2-7B-Instruct
+ std::string chat_template_fallback = "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";
+ ov::Tensor input_ids = get_encoded_input_ids(formatted_prompt, metrics, chat_template_fallback);
+ ov::Tensor text_embeds = m_embedding.infer(input_ids);
+
+ if (images.empty()) {
+ return text_embeds;
+ }
+
+ auto start_tokenizer_time = std::chrono::steady_clock::now();
+ ov::Tensor encoded_vision_start_token = m_tokenizer.encode(m_vlm_config.vision_start_token, ov::genai::add_special_tokens(false)).input_ids;
+ ov::Tensor encoded_image_pad_token = m_tokenizer.encode(m_vlm_config.image_pad_token, ov::genai::add_special_tokens(false)).input_ids;
+ auto end_tokenizer_time = std::chrono::steady_clock::now();
+ OPENVINO_ASSERT(metrics.raw_metrics.tokenization_durations.size() > 0);
+ metrics.raw_metrics.tokenization_durations[metrics.raw_metrics.tokenization_durations.size() - 1] += ov::genai::MicroSeconds(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time));
+ int64_t vision_start_token_id = encoded_vision_start_token.data<int64_t>()[encoded_vision_start_token.get_size() - 1];
+ int64_t image_pad_token_id = encoded_image_pad_token.data<int64_t>()[encoded_image_pad_token.get_size() - 1];
+
+ m_position_ids = create_position_ids(input_ids, images_grid_thw, vision_start_token_id);
+
+ int64_t position_ids_max_element = *std::max_element(m_position_ids.data<int64_t>(), m_position_ids.data<int64_t>() + m_position_ids.get_size());
+ m_rope_delta = position_ids_max_element + 1 - static_cast<int64_t>(input_ids.get_shape().at(1));
+
+ return merge_text_and_image_embeddings_qwen2vl(input_ids, text_embeds, image_embeds, images_grid_thw, image_pad_token_id);
+ }
+
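+ // Prefill (history_size == 0) returns the 3D position ids precomputed in get_inputs_embeds;
+ // later chat turns continue the sequence from history_size + m_rope_delta.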
+ virtual std::pair<ov::Tensor, std::optional<int64_t>> get_position_ids(const size_t inputs_embeds_size, const size_t history_size) override {
+ if (history_size != 0) {
+ ov::Tensor position_ids{ov::element::i64, {3, 1, inputs_embeds_size}};
+ int64_t new_pos_id = static_cast<int64_t>(history_size + m_rope_delta);
+ for (size_t dim = 0; dim < 3; ++dim) {
+ int64_t* pos_data = position_ids.data<int64_t>() + dim * inputs_embeds_size;
+ std::iota(pos_data, pos_data + inputs_embeds_size, new_pos_id);
+ }
+ return {position_ids, m_rope_delta};
+ }
+ return {m_position_ids, m_rope_delta};
+ }
+
+ virtual void start_chat(const std::string& system_message) override {
+ IInputsEmbedder::start_chat(system_message);
+ m_position_ids = ov::Tensor();
+ m_rope_delta = 0;
+ }
+
+ virtual void finish_chat() override {
+ IInputsEmbedder::finish_chat();
+ m_position_ids = ov::Tensor();
+ m_rope_delta = 0;
+ }
+protected:
+ ov::Tensor merge_text_and_image_embeddings_qwen2vl(
+ const ov::Tensor& input_ids,
+ const ov::Tensor& text_embeds,
+ const std::vector<ov::Tensor>& image_embeds,
+ const std::vector<std::array<size_t, 3>> images_grid_thw,
+ const int64_t image_pad_token_id
+ ) {
+ // Calculate cumulative sequence lengths for attention mask
+ std::vector<int32_t> cu_seqlens;
+ cu_seqlens.push_back(0);
+ int32_t cumsum = 0;
+ for (const auto& grid_thw : images_grid_thw) {
+ size_t slice_len = grid_thw.at(1) * grid_thw.at(2);
+ for (size_t t = 0; t < grid_thw.at(0); ++t) {
+ cumsum += slice_len;
+ cu_seqlens.push_back(cumsum);
+ }
+ }
+
+ // Create attention mask for vision embeddings merger model
+ size_t hidden_states_size = cumsum;
+ ov::Tensor attention_mask{ov::element::f32, {1, hidden_states_size, hidden_states_size}};
+ float* attention_mask_data = attention_mask.data<float>();
+ std::fill_n(attention_mask_data, attention_mask.get_size(), -std::numeric_limits<float>::infinity());
+
+ for (size_t i = 1; i < cu_seqlens.size(); ++i) {
+ size_t start = cu_seqlens[i-1];
+ size_t end = cu_seqlens[i];
+ for (size_t row = start; row < end; ++row) {
+ for (size_t col = start; col < end; ++col) {
+ attention_mask_data[row * hidden_states_size + col] = 0.0f;
+ }
+ }
+ }
+
+ // Concatenate image embeddings
+ ov::Tensor concatenated_images;
+ if (image_embeds.size() == 1) {
+ concatenated_images = image_embeds.at(0);
+ } else {
+ size_t total_length = 0;
+ for (const auto& embed : image_embeds) {
+ total_length += embed.get_shape().at(0);
+ }
+ size_t hidden_dim = image_embeds.at(0).get_shape().at(1);
+
+ concatenated_images = ov::Tensor(image_embeds.at(0).get_element_type(), {total_length, hidden_dim});
+ float* concat_data = concatenated_images.data<float>();
+
+ size_t offset = 0;
+ for (const auto& embed : image_embeds) {
+ size_t embed_size = embed.get_shape().at(0) * embed.get_shape().at(1);
+ std::memcpy(concat_data + offset, embed.data(), embed.get_byte_size());
+ offset += embed_size;
+ }
+ }
+
+ ov::Tensor rotary_pos_emb = get_rotary_pos_emb(images_grid_thw);
+
+ m_vision_embeddings_merger.set_tensor("hidden_states", concatenated_images);
+ m_vision_embeddings_merger.set_tensor("attention_mask", attention_mask);
+ m_vision_embeddings_merger.set_tensor("rotary_pos_emb", rotary_pos_emb);
+ m_vision_embeddings_merger.infer();
+ ov::Tensor processed_vision_embeds = m_vision_embeddings_merger.get_output_tensor();
+
+ ov::Tensor merged_embeds(text_embeds.get_element_type(), text_embeds.get_shape());
+ std::memcpy(merged_embeds.data(), text_embeds.data(), text_embeds.get_byte_size());
+
+ auto text_embeds_shape = text_embeds.get_shape();
+ size_t batch_size = text_embeds_shape.at(0);
+ size_t seq_length = text_embeds_shape.at(1);
+ size_t hidden_size = text_embeds_shape.at(2);
+
+ const int64_t* input_ids_data = input_ids.data<int64_t>();
+ float* merged_embeds_data = merged_embeds.data<float>();
+ const float* vision_embeds_data = processed_vision_embeds.data<float>();
+
+ size_t vision_embed_idx = 0;
+ for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
+ for (size_t seq_idx = 0; seq_idx < seq_length; ++seq_idx) {
+ size_t flat_idx = batch_idx * seq_length + seq_idx;
+ if (input_ids_data[flat_idx] == image_pad_token_id) {
+ std::copy_n(
+ vision_embeds_data + vision_embed_idx * hidden_size,
+ hidden_size,
+ merged_embeds_data + flat_idx * hidden_size
+ );
+ ++vision_embed_idx;
+ }
+ }
+ }
+ return merged_embeds;
+ }
+
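+ // Builds the 2D rotary position embedding table for the vision transformer:
+ // each patch's (height, width) index is mapped to inverse-frequency angles,
+ // and the height and width halves are concatenated along the last dimension.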
+ ov::Tensor get_rotary_pos_emb(const std::vector<std::array<size_t, 3>>& grids_thw) {
+ const size_t spatial_merge_size = m_vision_encoder.m_processor_config.merge_size;
+
+ std::vector<std::array<size_t, 2>> all_pos_ids;
+ size_t total_positions = 0;
+ size_t max_grid_size = 0;
+
+ for (const auto& grid_thw : grids_thw) {
+ size_t t = grid_thw.at(0);
+ size_t h = grid_thw.at(1);
+ size_t w = grid_thw.at(2);
+
+ total_positions += t * h * w;
+ max_grid_size = std::max({max_grid_size, h, w});
+
+ // Create height position IDs
+ std::vector<size_t> hpos_ids(h * w);
+ for (size_t hi = 0; hi < h; ++hi) {
+ for (size_t wi = 0; wi < w; ++wi) {
+ size_t idx = hi * w + wi;
+ hpos_ids[idx] = hi;
+ }
+ }
+
+ // Reshape hpos_ids according to spatial merge size
+ std::vector<size_t> reshaped_hpos;
+ size_t h_blocks = h / spatial_merge_size;
+ size_t w_blocks = w / spatial_merge_size;
+ reshaped_hpos.reserve(h * w);
+
+ for (size_t hb = 0; hb < h_blocks; ++hb) {
+ for (size_t wb = 0; wb < w_blocks; ++wb) {
+ for (size_t hs = 0; hs < spatial_merge_size; ++hs) {
+ for (size_t ws = 0; ws < spatial_merge_size; ++ws) {
+ reshaped_hpos.push_back(hb * spatial_merge_size + hs);
+ }
+ }
+ }
+ }
+
+ // Create width position IDs
+ std::vector<size_t> wpos_ids(h * w);
+ for (size_t hi = 0; hi < h; ++hi) {
+ for (size_t wi = 0; wi < w; ++wi) {
+ size_t idx = hi * w + wi;
+ wpos_ids[idx] = wi;
+ }
+ }
+
+ // Reshape wpos_ids according to spatial merge size
+ std::vector<size_t> reshaped_wpos;
+ reshaped_wpos.reserve(h * w);
+
+ for (size_t hb = 0; hb < h_blocks; ++hb) {
+ for (size_t wb = 0; wb < w_blocks; ++wb) {
+ for (size_t hs = 0; hs < spatial_merge_size; ++hs) {
+ for (size_t ws = 0; ws < spatial_merge_size; ++ws) {
+ reshaped_wpos.push_back(wb * spatial_merge_size + ws);
+ }
+ }
+ }
+ }
+
+ // Stack and repeat for each t
+ for (size_t i = 0; i < t; ++i) {
+ for (size_t j = 0; j < reshaped_hpos.size(); ++j) {
+ all_pos_ids.push_back({reshaped_hpos[j], reshaped_wpos[j]});
+ }
+ }
+ }
+
+ // Calculate rotary embeddings for max_grid_size
+ const size_t dim = 1280 / 16 / 2; // config.vision_config.embed_dim / self.config.vision_config.num_heads / 2
+ const float theta = 10000.0f;
+
+ std::vector<float> inv_freq(dim / 2);
+ for (size_t i = 0; i < dim / 2; ++i) {
+ inv_freq[i] = 1.0f / std::pow(theta, static_cast<float>(i) / static_cast<float>(dim / 2));
+ }
+
+ std::vector<std::vector<float>> freqs(max_grid_size);
+ for (size_t i = 0; i < max_grid_size; ++i) {
+ freqs[i].resize(dim / 2);
+ for (size_t j = 0; j < dim / 2; ++j) {
+ freqs[i][j] = static_cast<float>(i) * inv_freq[j];
+ }
+ }
+
+ ov::Tensor rotary_pos_emb(ov::element::f32, {all_pos_ids.size(), dim});
+ float* output_data = rotary_pos_emb.data<float>();
+
+ for (size_t i = 0; i < all_pos_ids.size(); ++i) {
+ const auto& pos = all_pos_ids.at(i);
+ size_t h_idx = pos.at(0);
+ size_t w_idx = pos.at(1);
+ std::copy_n(freqs[h_idx].begin(), dim / 2, output_data + i * dim);
+ std::copy_n(freqs[w_idx].begin(), dim / 2, output_data + i * dim + dim / 2);
+ }
+
+ return rotary_pos_emb;
+ }
+
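+ // Builds M-ROPE position ids [3, batch, seq]: text tokens share the same value across
+ // the temporal/height/width dimensions, while image pad tokens get a constant temporal
+ // index and row/column indices over the spatially merged vision grid.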
+ ov::Tensor create_position_ids(
+ const ov::Tensor& input_ids_tensor,
+ const std::vector<std::array<size_t, 3>>& images_grid_thw,
+ const int64_t vision_start_token_id
+ ) {
+ const size_t spatial_merge_size = m_vision_encoder.m_processor_config.merge_size;
+
+ const int64_t* input_ids = input_ids_tensor.data<int64_t>();
+ size_t batch_size = input_ids_tensor.get_shape().at(0);
+ size_t seq_len = input_ids_tensor.get_shape().at(1);
+
+ std::vector<size_t> vision_start_indices;
+ for (size_t i = 0; i < seq_len; ++i) {
+ if (input_ids[i] == vision_start_token_id) {
+ vision_start_indices.push_back(i);
+ }
+ }
+
+ ov::Tensor position_ids{ov::element::i64, {3, batch_size, seq_len}};
+ int64_t* pos_data = position_ids.data<int64_t>();
+
+ size_t st = 0;
+ int64_t next_pos = 0;
+ size_t grid_idx = 0;
+
+ for (size_t i = 0; i < vision_start_indices.size(); ++i) {
+ size_t ed = vision_start_indices.at(i);
+
+ // Process text tokens before image
+ if (st < ed) {
+ for (size_t pos = st; pos < ed; ++pos) {
+ pos_data[pos] = next_pos; // temporal
+ pos_data[seq_len + pos] = next_pos; // height
+ pos_data[2 * seq_len + pos] = next_pos; // width
+ next_pos++;
+ }
+ }
+
+ // Process image start token
+ pos_data[ed] = next_pos; // temporal
+ pos_data[seq_len + ed] = next_pos; // height
+ pos_data[2 * seq_len + ed] = next_pos; // width
+ next_pos++;
+ ed++;
+
+ // Process image token with grid
+ if (grid_idx < images_grid_thw.size()) {
+ const auto& grid = images_grid_thw.at(grid_idx);
+ size_t llm_grid_h = grid.at(1) / spatial_merge_size;
+ size_t llm_grid_w = grid.at(2) / spatial_merge_size;
+ size_t ed_image = ed + llm_grid_h * llm_grid_w;
+
+ // Fill temporal dimension
+ std::fill_n(pos_data + ed, llm_grid_h * llm_grid_w, next_pos);
+
+ // Fill height and width dimensions
+ int64_t* height_data = pos_data + seq_len + ed;
+ int64_t* width_data = pos_data + 2 * seq_len + ed;
+ for (size_t h = 0; h < llm_grid_h; ++h) {
+ std::fill_n(height_data + h * llm_grid_w, llm_grid_w, next_pos + h);
+ for (size_t w = 0; w < llm_grid_w; ++w) {
+ width_data[h * llm_grid_w + w] = next_pos + w;
+ }
+ }
+
+ next_pos += std::max(llm_grid_h, llm_grid_w);
+ st = ed_image;
+ grid_idx++;
+ }
+ }
+
+ // Process remaining text tokens
+ if (st < seq_len) {
+ for (size_t pos = st; pos < seq_len; ++pos) {
+ pos_data[pos] = next_pos; // temporal
+ pos_data[seq_len + pos] = next_pos; // height
+ pos_data[2 * seq_len + pos] = next_pos; // width
+ next_pos++;
+ }
+ }
+
+ return position_ids;
+ }
+};
+
InputsEmbedder::InputsEmbedder(const VLMConfig& vlm_config,
const std::filesystem::path& model_dir,
const std::string& device,
@@ -1169,6 +1577,8 @@ InputsEmbedder::InputsEmbedder(const VLMConfig& vlm_config,
m_impl = std::make_shared(vlm_config, model_dir, device, device_config);
} else if (vlm_config.model_type == VLMModelType::INTERNVL_CHAT) {
m_impl = std::make_shared(vlm_config, model_dir, device, device_config);
+ } else if (vlm_config.model_type == VLMModelType::QWEN2_VL) {
+ m_impl = std::make_shared(vlm_config, model_dir, device, device_config);
} else {
OPENVINO_THROW("Unsupported model type in VLM InputsEmbedder class. Please, create feature request on new model support");
}
@@ -1188,6 +1598,8 @@ InputsEmbedder::InputsEmbedder(const VLMConfig& vlm_config,
m_impl = std::make_shared(vlm_config, models_map, tokenizer, config_dir_path, device, device_config);
} else if (vlm_config.model_type == VLMModelType::INTERNVL_CHAT) {
m_impl = std::make_shared(vlm_config, models_map, tokenizer, config_dir_path, device, device_config);
+ } else if (vlm_config.model_type == VLMModelType::QWEN2_VL) {
+ m_impl = std::make_shared(vlm_config, models_map, tokenizer, config_dir_path, device, device_config);
} else {
OPENVINO_THROW("Unsupported model type in VLM InputsEmbedder class. Please, create feature request on new model support");
}
@@ -1197,6 +1609,10 @@ ov::Tensor InputsEmbedder::get_inputs_embeds(const std::string& prompt, const st
return m_impl->get_inputs_embeds(prompt, images, metrics);
}
+std::pair<ov::Tensor, std::optional<int64_t>> InputsEmbedder::get_position_ids(const size_t inputs_embeds_size, const size_t history_size) {
+ return m_impl->get_position_ids(inputs_embeds_size, history_size);
+}
+
EmbeddingsModel InputsEmbedder::get_embedding_model() const {
return m_impl->get_embedding_model();
}
diff --git a/src/cpp/src/visual_language/inputs_embedder.hpp b/src/cpp/src/visual_language/inputs_embedder.hpp
index 56fa488465..223d090b22 100644
--- a/src/cpp/src/visual_language/inputs_embedder.hpp
+++ b/src/cpp/src/visual_language/inputs_embedder.hpp
@@ -34,6 +34,9 @@ class InputsEmbedder {
// compute input embedding for prompt and multiple images
ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics);
+ // compute position ids for language model input
+ std::pair<ov::Tensor, std::optional<int64_t>> get_position_ids(const size_t inputs_embeds_size, const size_t history_size);
+
// returns embedding model which converts token_id(s) to embedding vectors
EmbeddingsModel get_embedding_model() const;
@@ -65,6 +68,7 @@ class InputsEmbedder {
friend class InputsEmbedderLLaVA;
friend class InputsEmbedderLLaVANext;
friend class InputsEmbedderInternVLChat;
+ friend class InputsEmbedderQwen2VL;
};
} // namespace ov::genai
diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp
index cf2f033e29..95e3064548 100644
--- a/src/cpp/src/visual_language/pipeline.cpp
+++ b/src/cpp/src/visual_language/pipeline.cpp
@@ -208,8 +208,9 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, { 1, history_size + inputs_embeds_size }};
std::fill_n(new_atten_mask.data<int64_t>(), new_atten_mask.get_size(), 1);
- ov::Tensor position_ids = ov::Tensor{ov::element::i64, { 1, inputs_embeds_size }};
- std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), history_size);
+ ov::Tensor position_ids;
+ std::optional<int64_t> rope_delta;
+ std::tie(position_ids, rope_delta) = m_inputs_embedder->get_position_ids(inputs_embeds_size, history_size);
if (m_sampler.get_seed() != generation_config.rng_seed) {
m_sampler.set_seed(generation_config.rng_seed);
@@ -218,7 +219,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
ov::genai::EncodedResults encoded_result;
std::optional<int64_t> last_disappeared_token;
std::tie(encoded_result, last_disappeared_token) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests,
- position_ids, m_embedding);
+ position_ids, m_embedding, rope_delta);
auto decode_start_time = std::chrono::steady_clock::now();
VLMDecodedResults decoded;
diff --git a/src/cpp/src/visual_language/processor_config.cpp b/src/cpp/src/visual_language/processor_config.cpp
index e88c580fa0..f790c58912 100644
--- a/src/cpp/src/visual_language/processor_config.cpp
+++ b/src/cpp/src/visual_language/processor_config.cpp
@@ -33,7 +33,7 @@ ov::genai::ProcessorConfig::ProcessorConfig(const std::filesystem::path& json_pa
crop_size_height = parsed.at("crop_size").at("height");
crop_size_width = parsed.at("crop_size").at("width");
}
- if (parsed.contains("size")) {
+ if (parsed.contains("size") && parsed.at("size").contains("shortest_edge")) {
size_shortest_edge = parsed.at("size").at("shortest_edge");
}
@@ -41,4 +41,10 @@ ov::genai::ProcessorConfig::ProcessorConfig(const std::filesystem::path& json_pa
if (parsed.contains("image_grid_pinpoints")) {
image_grid_pinpoints = parsed.at("image_grid_pinpoints").get<std::vector<std::pair<int, int>>>();
}
+
+ // Setting qwen2vl config params
+ read_json_param(parsed, "min_pixels", min_pixels);
+ read_json_param(parsed, "max_pixels", max_pixels);
+ read_json_param(parsed, "temporal_patch_size", temporal_patch_size);
+ read_json_param(parsed, "merge_size", merge_size);
}
diff --git a/src/cpp/src/visual_language/processor_config.hpp b/src/cpp/src/visual_language/processor_config.hpp
index 43959dffeb..1d40e091a9 100644
--- a/src/cpp/src/visual_language/processor_config.hpp
+++ b/src/cpp/src/visual_language/processor_config.hpp
@@ -45,6 +45,12 @@ class ProcessorConfig {
// llava-next specific config params
std::vector<std::pair<int, int>> image_grid_pinpoints{{336, 672}, {672, 336}, {672, 672}, {1008, 336}, {336, 1008}};
+ // qwen2vl specific params
+ size_t min_pixels = 3136;
+ size_t max_pixels = 12845056;
+ size_t temporal_patch_size = 2;
+ size_t merge_size = 2;
+
/// @brief Default constructor
ProcessorConfig() = default;
/// @brief Construct ProcessorConfig from values in json_path.
diff --git a/src/cpp/src/visual_language/vision_encoder.cpp b/src/cpp/src/visual_language/vision_encoder.cpp
index 88ab5eac53..4a5179fdd0 100644
--- a/src/cpp/src/visual_language/vision_encoder.cpp
+++ b/src/cpp/src/visual_language/vision_encoder.cpp
@@ -644,6 +644,158 @@ ov::Tensor get_pixel_values_internvl(const ov::Tensor& image, const ProcessorCon
}
return output_tensor;
}
+
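+// Rescale the image so both sides become multiples of `factor` while the total pixel
+// count stays within [min_pixels, max_pixels], approximately preserving the aspect ratio.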
+ImageSize smart_resize_qwen2vl(size_t height, size_t width, size_t factor, size_t min_pixels, size_t max_pixels) {
+ if (height < factor || width < factor) {
+ OPENVINO_THROW("Height or width must be larger than factor");
+ }
+ if (std::max(height, width) / std::min(height, width) > 200) {
+ OPENVINO_THROW("Absolute aspect ratio must be smaller than 200");
+ }
+
+ size_t h_bar = std::round(static_cast<double>(height) / factor) * factor;
+ size_t w_bar = std::round(static_cast<double>(width) / factor) * factor;
+
+ if (h_bar * w_bar > max_pixels) {
+ double beta = std::sqrt((height * width) / static_cast<double>(max_pixels));
+ h_bar = std::floor(height / beta / factor) * factor;
+ w_bar = std::floor(width / beta / factor) * factor;
+ } else if (h_bar * w_bar < min_pixels) {
+ double beta = std::sqrt(min_pixels / static_cast<double>(height * width));
+ h_bar = std::ceil(height * beta / factor) * factor;
+ w_bar = std::ceil(width * beta / factor) * factor;
+ }
+
+ return ImageSize{h_bar, w_bar};
+}
+
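+// Rearranges normalized image patches into the 9D layout
+// [grid_t, temporal_patch_size, channel, grid_h/merge, merge, patch, grid_w/merge, merge, patch]
+// expected by the transpose step below.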
+ov::Tensor reshape_image_patches_qwen2vl(
+ const ov::Tensor& patches,
+ const size_t grid_t,
+ const size_t grid_h,
+ const size_t grid_w,
+ const size_t channel,
+ const size_t temporal_patch_size,
+ const size_t patch_size,
+ const size_t spatial_merge_size
+) {
+ ov::Shape output_shape{
+ grid_t,
+ temporal_patch_size,
+ channel,
+ grid_h / spatial_merge_size,
+ spatial_merge_size,
+ patch_size,
+ grid_w / spatial_merge_size,
+ spatial_merge_size,
+ patch_size
+ };
+
+ ov::Tensor reshaped_patches(patches.get_element_type(), output_shape);
+
+ const float* input_data = patches.data<float>();
+ float* output_data = reshaped_patches.data<float>();
+
+ size_t input_idx = 0;
+
+ for (size_t gt = 0; gt < output_shape.at(0); ++gt) {
+ for (size_t tp = 0; tp < output_shape.at(1); ++tp) {
+ for (size_t c = 0; c < output_shape.at(2); ++c) {
+ for (size_t gh = 0; gh < output_shape.at(3); ++gh) {
+ for (size_t ms1 = 0; ms1 < output_shape.at(4); ++ms1) {
+ for (size_t p1 = 0; p1 < output_shape.at(5); ++p1) {
+ for (size_t gw = 0; gw < output_shape.at(6); ++gw) {
+ for (size_t ms2 = 0; ms2 < output_shape.at(7); ++ms2) {
+ for (size_t p2 = 0; p2 < output_shape.at(8); ++p2) {
+ size_t output_idx = gt;
+ output_idx = output_idx * output_shape.at(1) + tp;
+ output_idx = output_idx * output_shape.at(2) + c;
+ output_idx = output_idx * output_shape.at(3) + gh;
+ output_idx = output_idx * output_shape.at(4) + ms1;
+ output_idx = output_idx * output_shape.at(5) + p1;
+ output_idx = output_idx * output_shape.at(6) + gw;
+ output_idx = output_idx * output_shape.at(7) + ms2;
+ output_idx = output_idx * output_shape.at(8) + p2;
+
+ output_data[output_idx] = input_data[input_idx];
+ input_idx++;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return reshaped_patches;
+}
+
+ov::Tensor transpose_image_patches_qwen2vl(const ov::Tensor& reshaped_patches) {
+ // Input dimensions order: [0,1,2,3,4,5,6,7,8]
+ // Output dimensions order: [0,3,6,4,7,2,1,5,8]
+ auto input_shape = reshaped_patches.get_shape();
+
+ ov::Shape output_shape = {
+ input_shape.at(0), // grid_t
+ input_shape.at(3), // grid_h / spatial_merge_size
+ input_shape.at(6), // grid_w / spatial_merge_size
+ input_shape.at(4), // spatial_merge_size
+ input_shape.at(7), // spatial_merge_size
+ input_shape.at(2), // channel
+ input_shape.at(1), // temporal_patch_size
+ input_shape.at(5), // patch_size
+ input_shape.at(8) // patch_size
+ };
+
+ ov::Tensor transposed_patches(reshaped_patches.get_element_type(), output_shape);
+
+ const float* src = reshaped_patches.data<float>();
+ float* dst = transposed_patches.data<float>();
+
+ size_t shape_size = input_shape.size();
+ std::vector<size_t> input_strides(shape_size);
+ std::vector<size_t> output_strides(shape_size);
+
+ input_strides[shape_size - 1] = 1;
+ output_strides[shape_size - 1] = 1;
+ for(int i = 7; i >= 0; i--) {
+ input_strides[i] = input_strides[i+1] * input_shape[i+1];
+ output_strides[i] = output_strides[i+1] * output_shape[i+1];
+ }
+
+ size_t total_elements = reshaped_patches.get_size();
+ for(size_t idx = 0; idx < total_elements; idx++) {
+ size_t remaining = idx;
+ std::vector<size_t> input_indices(shape_size);
+ for(int i = 0; i < shape_size; i++) {
+ input_indices[i] = remaining / input_strides[i];
+ remaining %= input_strides[i];
+ }
+
+ std::vector<size_t> output_indices = {
+ input_indices.at(0),
+ input_indices.at(3),
+ input_indices.at(6),
+ input_indices.at(4),
+ input_indices.at(7),
+ input_indices.at(2),
+ input_indices.at(1),
+ input_indices.at(5),
+ input_indices.at(8)
+ };
+
+ size_t dst_idx = 0;
+ for(int i = 0; i < shape_size; i++) {
+ dst_idx += output_indices[i] * output_strides[i];
+ }
+
+ dst[dst_idx] = src[idx];
+ }
+
+ return transposed_patches;
+}
}
VisionEncoder::VisionEncoder(const std::filesystem::path& model_dir, const VLMModelType model_type, const std::string& device, const ov::AnyMap device_config) :
@@ -678,8 +830,10 @@ EncodedImage VisionEncoder::encode(const ov::Tensor& image, const ProcessorConfi
return encode_llava(image, config);
} else if (model_type == VLMModelType::LLAVA_NEXT) {
return encode_llava_next(image, config);
- } else if (model_type == VLMModelType::INTERNVL_CHAT) {
+ } else if (model_type == VLMModelType::INTERNVL_CHAT) {
return encode_internvl(image, config);
+ } else if (model_type == VLMModelType::QWEN2_VL) {
+ return encode_qwen2vl(image, config);
} else {
OPENVINO_THROW("Unsupported type of VisionEncoder");
}
@@ -753,3 +907,74 @@ EncodedImage VisionEncoder::encode_internvl(const ov::Tensor& image, const Proce
return {std::move(image_features), resized_source_size};
}
+
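+// Qwen2-VL image encoding: smart-resize the image, normalize it, split it into
+// temporal/spatial patches, flatten them and run the vision encoder.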
+EncodedImage VisionEncoder::encode_qwen2vl(const ov::Tensor& image, const ProcessorConfig& config) {
+ ov::Shape image_shape = image.get_shape();
+ auto original_height = image_shape.at(1);
+ auto original_width = image_shape.at(2);
+
+ ImageSize target_image_size = smart_resize_qwen2vl(
+ original_height,
+ original_width,
+ config.patch_size * config.merge_size,
+ config.min_pixels,
+ config.max_pixels
+ );
+
+ clip_image_u8 input_image = tensor_to_clip_image_u8(image);
+ clip_image_u8 resized_image;
+ bicubic_resize(input_image, resized_image, target_image_size.width, target_image_size.height);
+
+ clip_ctx ctx;
+ std::copy(config.image_mean.begin(), config.image_mean.end(), ctx.image_mean);
+ std::copy(config.image_std.begin(), config.image_std.end(), ctx.image_std);
+ clip_image_f32 normalized_image = clip_image_preprocess(ctx, resized_image);
+
+ ov::Tensor patches = clip_image_f32_to_tensor(normalized_image);
+
+ // For single patch tile it to match temporal_patch_size
+ if (patches.get_shape().at(0) == 1) {
+ auto orig_shape = patches.get_shape();
+ ov::Tensor tiled_patches(patches.get_element_type(),
+ {config.temporal_patch_size, orig_shape.at(1), orig_shape.at(2), orig_shape.at(3)});
+
+ for (size_t i = 0; i < config.temporal_patch_size; i++) {
+ std::memcpy(
+ tiled_patches.data<float>() + i * patches.get_byte_size() / sizeof(float),
+ patches.data(),
+ patches.get_byte_size()
+ );
+ }
+ patches = std::move(tiled_patches);
+ }
+
+ auto patches_shape = patches.get_shape();
+ size_t channel = patches_shape.at(1);
+
+ size_t grid_t = patches_shape.at(0) / config.temporal_patch_size;
+ size_t grid_h = target_image_size.height / config.patch_size;
+ size_t grid_w = target_image_size.width / config.patch_size;
+
+ ov::Tensor reshaped_patches = reshape_image_patches_qwen2vl(
+ patches, grid_t, grid_h, grid_w, channel, config.temporal_patch_size, config.patch_size, config.merge_size
+ );
+ ov::Tensor transposed_patches = transpose_image_patches_qwen2vl(reshaped_patches);
+
+ ov::Shape flattened_patches_shape{
+ grid_t * grid_h * grid_w,
+ channel * config.temporal_patch_size * config.patch_size * config.patch_size
+ };
+ ov::Tensor flattened_patches(transposed_patches.get_element_type(), flattened_patches_shape);
+ std::memcpy(flattened_patches.data(), transposed_patches.data(), transposed_patches.get_byte_size());
+
+ m_vision_encoder.set_tensor("hidden_states", flattened_patches);
+ m_vision_encoder.infer();
+
+ const ov::Tensor& infer_output = m_vision_encoder.get_output_tensor();
+ ov::Tensor image_features(infer_output.get_element_type(), infer_output.get_shape());
+ std::memcpy(image_features.data(), infer_output.data(), infer_output.get_byte_size());
+
+ ImageSize resized_source_size{grid_h, grid_w};
+
+ return {std::move(image_features), resized_source_size};
+}
diff --git a/src/cpp/src/visual_language/vision_encoder.hpp b/src/cpp/src/visual_language/vision_encoder.hpp
index e25b875261..e725c06bf4 100644
--- a/src/cpp/src/visual_language/vision_encoder.hpp
+++ b/src/cpp/src/visual_language/vision_encoder.hpp
@@ -158,5 +158,9 @@ class VisionEncoder {
EncodedImage encode_internvl(
const ov::Tensor& image, const ProcessorConfig& config
);
+
+ EncodedImage encode_qwen2vl(
+ const ov::Tensor& image, const ProcessorConfig& config
+ );
};
}
diff --git a/src/cpp/src/visual_language/vlm_config.hpp b/src/cpp/src/visual_language/vlm_config.hpp
index ffaab49243..c70c757707 100644
--- a/src/cpp/src/visual_language/vlm_config.hpp
+++ b/src/cpp/src/visual_language/vlm_config.hpp
@@ -55,6 +55,13 @@ class VLMConfig {
/// @brief A string token denoting end of image embeddings for InternVL2 model.
std::string image_end_token = "</img>";
+ /// @brief A string token denoting start of vision embeddings for Qwen2VL model.
+ std::string vision_start_token = "<|vision_start|>";
+ /// @brief A placeholder for image embeddings in text for Qwen2VL model.
+ std::string image_pad_token = "<|image_pad|>";
+ /// @brief A string token denoting end of vision embeddings for Qwen2VL model.
+ std::string vision_end_token = "<|vision_end|>";
+
/// @brief Default constructor.
VLMConfig() = default;
/// @brief Construct VLMConfig from values in json_path.
diff --git a/src/cpp/src/visual_language/vlm_model_type.hpp b/src/cpp/src/visual_language/vlm_model_type.hpp
index f882acb612..6f554fbf98 100644
--- a/src/cpp/src/visual_language/vlm_model_type.hpp
+++ b/src/cpp/src/visual_language/vlm_model_type.hpp
@@ -16,6 +16,7 @@ enum class VLMModelType {
LLAVA,
LLAVA_NEXT,
INTERNVL_CHAT,
+ QWEN2_VL,
};
inline VLMModelType to_vlm_model_type(const std::string& value) {
@@ -23,7 +24,8 @@ inline VLMModelType to_vlm_model_type(const std::string& value) {
{"minicpmv", VLMModelType::MINICPM},
{"llava", VLMModelType::LLAVA},
{"llava_next", VLMModelType::LLAVA_NEXT},
- {"internvl_chat", VLMModelType::INTERNVL_CHAT}
+ {"internvl_chat", VLMModelType::INTERNVL_CHAT},
+ {"qwen2_vl", VLMModelType::QWEN2_VL}
};
auto it = model_types_map.find(value);
diff --git a/tests/python_tests/test_tokenizer.py b/tests/python_tests/test_tokenizer.py
index 7980c2152e..445b779c3e 100644
--- a/tests/python_tests/test_tokenizer.py
+++ b/tests/python_tests/test_tokenizer.py
@@ -181,6 +181,12 @@ def test_apply_chat_template(model_tmp_path, chat_config: Tuple[str, Dict]):
print(f'ov_genai out: {ov_full_history_str}')
assert ov_full_history_str == hf_full_history_str
+ # Test throwing exception for empty rendered chat template
+ # Example: Qwen2-VL chat template
+ chat_template_for_empty_output = "{% if messages is string %}{{ messages }}{% else %}{% for content in messages %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}"
+ with pytest.raises(Exception):
+ ov_tokenizer.apply_chat_template(conversation, chat_template=chat_template_for_empty_output)
+
@pytest.mark.precommit
@pytest.mark.nightly