diff --git a/samples/cpp/text_generation/speculative_decoding_lm.cpp b/samples/cpp/text_generation/speculative_decoding_lm.cpp
index 8a34bbf8a9..94de7de4a4 100644
--- a/samples/cpp/text_generation/speculative_decoding_lm.cpp
+++ b/samples/cpp/text_generation/speculative_decoding_lm.cpp
@@ -21,7 +21,7 @@ int main(int argc, char* argv[]) try {
     std::string main_model_path = argv[1];
    std::string draft_model_path = argv[2];
     std::string prompt = argv[3];
-    
+
     // User can run main and draft model on different devices.
     // Please, set device for main model in `LLMPipeline` constructor and in `ov::genai::draft_model` for draft.
     std::string main_device = "CPU", draft_device = "CPU";
diff --git a/src/cpp/include/openvino/genai/generation_handle.hpp b/src/cpp/include/openvino/genai/generation_handle.hpp
index 6619e3e012..71ec815e61 100644
--- a/src/cpp/include/openvino/genai/generation_handle.hpp
+++ b/src/cpp/include/openvino/genai/generation_handle.hpp
@@ -73,8 +73,6 @@ class GenerationStream;
 class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {
     std::shared_ptr<GenerationStream> m_generation_stream;
     ov::genai::GenerationConfig m_sampling_params;
-
-    bool is_dropped();
 
 public:
     GenerationHandleImpl(std::shared_ptr<GenerationStream> generation_stream, const ov::genai::GenerationConfig& sampling_params) :
@@ -90,6 +88,7 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {
     GenerationStatus get_status();
 
     bool can_read();
+    bool is_dropped();
 
     void drop();
diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp
index 788da2b015..45c1d0f630 100644
--- a/src/cpp/src/continuous_batching_impl.cpp
+++ b/src/cpp/src/continuous_batching_impl.cpp
@@ -1,6 +1,9 @@
 // Copyright (C) 2023-2025 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 
+#include <atomic>
+#include <thread>
+
 #include "text_callback_streamer.hpp"
 #include "continuous_batching_impl.hpp"
 #include "utils.hpp"
@@ -261,6 +264,9 @@
 std::vector<EncodedGenerationResult>
 ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<ov::Tensor>& input_ids,
                                                              const std::vector<GenerationConfig>& sampling_params,
                                                              const StreamerVariant& streamer) {
+    ManualTimer generate_timer("generate()");
+    generate_timer.start();
+
     OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request");
     OPENVINO_ASSERT(input_ids.size() == sampling_params.size());
@@ -299,8 +305,38 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<ov::Tensor>& input_ids,
-    bool continue_generation = true;
-    while (has_non_finished_requests() && continue_generation) {
+    std::atomic<bool> has_active_request = has_non_finished_requests();
+    GenerationHandle& generation = generations.at(0);
+
+    // create variables to make optimal thread-safe streaming
+    std::mutex mutex;
+    std::unique_lock<std::mutex> lock(mutex);
+    std::condition_variable cv;
+
+    // define stream token lambda to use in `t_stream`
+    auto stream_tokens = [&generation, &streamer_ptr, &has_active_request, &cv, &lock]() {
+        while (!generation->is_dropped() && (has_active_request || streamer_ptr && generation->can_read())) {
+            // waiting for any tokens or request finishing
+            cv.wait(lock, [&generation, &has_active_request]{ return generation->can_read() || !has_active_request; });
+            if (streamer_ptr && generation->can_read()) {
+                std::unordered_map<uint64_t, GenerationOutput> token = generation->back();
+                for (const auto& gen_token : token.begin()->second.generated_ids) {
+                    if (streamer_ptr->put(gen_token)) {
+                        generation->drop();
+                        cv.notify_all();
+                        break;
+                    }
+                }
+            }
+        };
+    };
+
+    // to define streaming thread
+    std::thread t_stream([&stream_tokens] {
+        stream_tokens();
+    });
+
+    while (!generation->is_dropped() && has_active_request) {
         try {
             const auto infer_start = std::chrono::steady_clock::now();
             step();
@@ -314,27 +350,23 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<ov::Tensor>& input_ids,
-        if (streamer_ptr && generation->can_read()) {
-            std::unordered_map<uint64_t, GenerationOutput> token = generation->back();
-            for (const auto& gen_token : token.begin()->second.generated_ids) {
-                continue_generation = !streamer_ptr->put(gen_token);
-                if (!continue_generation) {
-                    generation->drop();
-                    break;
-                }
-            }
-        }
+    if (t_stream.joinable()) {
+        t_stream.join();
     }
 
     if (streamer_ptr) { // push streamer's cache
         streamer_ptr->end();
     }
 
-    if (!continue_generation) {
+    if (generation->is_dropped()) {
         drop_requests();
     } else {
         OPENVINO_ASSERT(m_requests.empty(), "Internal error: current request is supposed to be dropped within step() function as completed");
@@ -378,6 +410,8 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<ov::Tensor>& input_ids,
     }
     OPENVINO_ASSERT(results.size() == input_ids.size());
+
+    generate_timer.end();
     return results;
 }
diff --git a/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp b/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp
--- a/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp
+++ b/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp
@@ -1,5 +1,7 @@
 // Copyright (C) 2023-2025 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 
+#include <thread>
+
 #include "prompt_lookup_impl.hpp"
 #include "text_callback_streamer.hpp"
@@ -109,38 +111,61 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vector<ov::Tensor>& input_ids,
     auto all_requests = m_pipeline->get_awaiting_requests();
 
-    bool continue_generation = true;
-    while (has_non_finished_requests() && continue_generation) {
+    std::atomic<bool> continue_streaming = true, has_active_request = has_non_finished_requests();
+    auto& generation = generations.at(0);
+
+    // create variables to make optimal thread-safe streaming
+    std::mutex mutex;
+    std::unique_lock<std::mutex> lock(mutex);
+    std::condition_variable cv;
+
+    // define stream token lambda to use in `t_stream`
+    auto stream_tokens = [&generation, &streamer_ptr, &has_active_request, &cv, &lock]() {
+        while (!generation->is_dropped() && (has_active_request || streamer_ptr && generation->can_read())) {
+            // waiting for any tokens or request finishing
+            cv.wait(lock, [&generation, &has_active_request]{ return generation->can_read() || !has_active_request; });
+
+            if (streamer_ptr && generation->can_read()) {
+                std::unordered_map<uint64_t, GenerationOutput> token = generation->back();
+                for (const auto& gen_token : token.begin()->second.generated_ids) {
+                    if (streamer_ptr->put(gen_token)) {
+                        generation->drop();
+                        cv.notify_all();
+                        break;
+                    }
+                }
+            }
+        };
+    };
+
+    // to define streaming thread
+    std::thread t_stream([&stream_tokens] {
+        stream_tokens();
+    });
+
+    while (continue_streaming && has_active_request) {
         try {
+            const auto infer_start = std::chrono::steady_clock::now();
             step();
         } catch (...) {
             drop_requests(); // remove all requests from pipeline state in case of exception
+            has_active_request = false;
+            cv.notify_all();
             throw;
         }
 
-        if (streamer_ptr) {
-            // not generated tokens like several prompt phase
-            auto& generation = generations.at(0);
-            if (!generation->can_read()) {
-                continue;
-            }
-            std::unordered_map<uint64_t, GenerationOutput> token = generation->back();
-            OPENVINO_ASSERT(1 <= token.size());
-            OPENVINO_ASSERT(1 <= token.begin()->second.generated_ids.size());
-            for (const auto& gen_token : token.begin()->second.generated_ids) {
-                continue_generation = !streamer_ptr->put(gen_token);
-                if (!continue_generation) {
-                    generation->drop();
-                    break;
-                }
-            }
-        }
+        has_active_request = has_non_finished_requests();
+        cv.notify_all();
+    }
+
+    if (t_stream.joinable()) {
+        t_stream.join();
     }
 
     if (streamer_ptr) { // push streamer's cache
         streamer_ptr->end();
     }
 
-    if (!continue_generation) {
+    if (generation->is_dropped()) {
         drop_requests();
     } else {
         OPENVINO_ASSERT(m_pipeline->is_requests_empty(), "Internal error: current request is supposed to be dropped within step() function as completed");
@@ -177,6 +202,7 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vector<ov::Tensor>& input_ids,
     }
     OPENVINO_ASSERT(results.size() == input_ids.size());
+    generate_timer.end();
     return results;
 }
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
--- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -1,6 +1,8 @@
 // Copyright (C) 2023-2025 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 
+#include <thread>
+
 #include "text_callback_streamer.hpp"
 #include "speculative_decoding_impl.hpp"
 #include "utils.hpp"
@@ -235,36 +237,62 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<ov::Tensor>& input_ids,
     }
     auto all_requests = get_awaiting_requests();
 
-    bool continue_generation = true;
-    while (has_non_finished_requests() && continue_generation) {
+    std::atomic<bool> has_active_request = has_non_finished_requests();
+    auto& generation = main_generations.at(0);
+
+    // create variables to make optimal thread-safe streaming
+    std::mutex mutex;
+    std::unique_lock<std::mutex> lock(mutex);
+    std::condition_variable cv;
+
+    // define stream token lambda to use in `t_stream`
+    auto stream_tokens = [&generation, &streamer_ptr, &has_active_request, &cv, &lock]() {
+        while (!generation->is_dropped() && (has_active_request || streamer_ptr && generation->can_read())) {
+            // waiting for any tokens or request finishing
+            cv.wait(lock, [&generation, &has_active_request]{ return generation->can_read() || !has_active_request; });
+
+            if (streamer_ptr && generation->can_read()) {
+                std::unordered_map<uint64_t, GenerationOutput> token = generation->back();
+                for (const auto& gen_token : token.begin()->second.generated_ids) {
+                    if (streamer_ptr->put(gen_token)) {
+                        generation->drop();
+                        cv.notify_all();
+                        break;
+                    }
+                }
+            }
+        };
+    };
+
+    // to define streaming thread
+    std::thread t_stream([&stream_tokens] {
+        stream_tokens();
+    });
+
+    while (!generation->is_dropped() && has_active_request) {
         try {
+            const auto infer_start = std::chrono::steady_clock::now();
             step();
         } catch (...) {
             drop_requests(); // remove all requests from pipeline state in case of exception
+            has_active_request = false;
+            cv.notify_all();
             throw;
         }
 
-        if (streamer_ptr) {
-            auto& main_generation = main_generations.at(0);
-            // not generated tokens like several prompt phase
-            if (!main_generation->can_read()) {
-                continue;
-            }
-            std::unordered_map<uint64_t, GenerationOutput> token = main_generation->back();
-            for (const auto& gen_token : token.begin()->second.generated_ids) {
-                continue_generation = !streamer_ptr->put(gen_token);
-                if (!continue_generation) {
-                    main_generation->drop();
-                    break;
-                }
-            }
-        }
+        has_active_request = has_non_finished_requests();
+        cv.notify_all();
+    }
+
+    // waiting for completion of streaming
+    if (t_stream.joinable()) {
+        t_stream.join();
     }
 
     if (streamer_ptr) { // push streamer's cache
         streamer_ptr->end();
     }
 
-    if (!continue_generation) {
+    if (generation->is_dropped()) {
         drop_requests();
     } else {
         OPENVINO_ASSERT(is_requests_empty(), "Internal error: current request is supposed to be dropped within step() function as completed");
@@ -301,6 +329,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<ov::Tensor>& input_ids,
     }
     OPENVINO_ASSERT(results.size() == input_ids.size());
+    generate_timer.end();
     return results;
 }
diff --git a/src/cpp/src/synchronized_queue.hpp b/src/cpp/src/synchronized_queue.hpp
index 70883bcae7..005bca4eb8 100644
--- a/src/cpp/src/synchronized_queue.hpp
+++ b/src/cpp/src/synchronized_queue.hpp
@@ -13,6 +13,7 @@ class SynchronizedQueue
     std::queue<T> m_queue;
     std::mutex m_mutex;
     std::condition_variable m_cv;
+    size_t m_taken_element_cnt = 0;
 
 public:
     SynchronizedQueue() = default;
@@ -22,7 +23,8 @@ class SynchronizedQueue
 
     T back() {
         std::unique_lock<std::mutex> lock(m_mutex);
-        m_cv.wait(lock, [this]{return !m_queue.empty();});
+        m_cv.wait(lock, [this]{return !m_queue.empty(); });
+        m_taken_element_cnt = m_queue.size();
         return m_queue.back();
     }
 
@@ -44,4 +46,9 @@ class SynchronizedQueue
         std::unique_lock<std::mutex> lock(m_mutex);
         return m_queue.empty();
     }
+
+    bool full() {
+        std::unique_lock<std::mutex> lock(m_mutex);
+        return m_taken_element_cnt == m_queue.size();
+    }
 };
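Note on the common pattern: all three pipeline implementations (continuous batching, prompt lookup, speculative decoding) get the same restructuring. The `step()` loop acts as a producer that updates `has_active_request` and notifies a condition variable after every iteration, while a dedicated `t_stream` thread acts as the consumer that reads generated tokens from the generation handle and forwards them to the streamer, so streaming callbacks no longer block scheduling. The stand-alone sketch below only illustrates that producer/consumer shape under simplified assumptions; the `std::queue<int>` standing in for `GenerationHandle::back()`, the printed "streamed token" output, and the variable names are illustrative and are not part of the PR.

```cpp
#include <atomic>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

int main() {
    std::queue<int> tokens;               // stand-in for tokens exposed by the generation handle
    std::mutex mutex;
    std::condition_variable cv;
    std::atomic<bool> has_active_request{true};

    // Consumer: waits until a token is available or generation has finished,
    // then forwards tokens (here: prints them) like streamer_ptr->put() would.
    std::thread t_stream([&] {
        std::unique_lock<std::mutex> lock(mutex);
        while (has_active_request || !tokens.empty()) {
            cv.wait(lock, [&] { return !tokens.empty() || !has_active_request; });
            while (!tokens.empty()) {
                std::cout << "streamed token " << tokens.front() << "\n";
                tokens.pop();
            }
        }
    });

    // Producer: models the step() loop, publishing one token per iteration
    // and waking the streaming thread after each step.
    for (int i = 0; i < 5; ++i) {
        {
            std::lock_guard<std::mutex> guard(mutex);
            tokens.push(i);
        }
        cv.notify_all();
    }

    // Signal completion and wake the consumer so it can observe it.
    has_active_request = false;
    cv.notify_all();

    // Equivalent of the new t_stream.join() that runs before streamer_ptr->end().
    t_stream.join();
    return 0;
}
```

In the PR itself the consumer is the `stream_tokens` lambda running on `t_stream`, termination is signalled through `has_active_request` and `generation->is_dropped()`, and `t_stream.join()` is performed before `streamer_ptr->end()` so the streamer's cache is flushed only after the last token has been forwarded.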