[CB] Split token streaming and generation to different threads for all CB based pipelines #1544

Open
wants to merge 11 commits into base: master
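This PR replaces the single loop that interleaved step() calls with streamer callbacks by two cooperating threads: the main thread keeps generating tokens via step(), while a dedicated streaming thread (`t_stream`) waits on a condition variable and forwards freshly produced tokens to the streamer. A minimal, self-contained sketch of that producer/consumer split follows; the queue, flag, and integer tokens are illustrative stand-ins, not the GenAI classes used in the diff below.

// Sketch of the generation/streaming split: the main thread produces tokens,
// the streaming thread consumes them, and a condition variable wakes the
// consumer when tokens arrive or generation ends.
#include <atomic>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

int main() {
    std::queue<int> tokens;                    // stands in for the generation output queue
    std::mutex mutex;
    std::condition_variable cv;
    std::atomic<bool> has_active_request{true};

    // Streaming thread: sleeps until new tokens arrive or generation finishes.
    std::thread t_stream([&] {
        std::unique_lock<std::mutex> lock(mutex);
        while (has_active_request || !tokens.empty()) {
            cv.wait(lock, [&] { return !tokens.empty() || !has_active_request; });
            while (!tokens.empty()) {
                std::cout << "token " << tokens.front() << '\n';  // streamer_ptr->put(...) in the PR
                tokens.pop();
            }
        }
    });

    // Main thread: keeps producing tokens (the role of step() in the pipelines).
    for (int step = 0; step < 5; ++step) {
        {
            std::lock_guard<std::mutex> guard(mutex);
            tokens.push(step);
        }
        cv.notify_all();
    }
    {
        std::lock_guard<std::mutex> guard(mutex);
        has_active_request = false;            // no more work; wake the streamer so it can exit
    }
    cv.notify_all();
    t_stream.join();
}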
2 changes: 1 addition & 1 deletion samples/cpp/text_generation/speculative_decoding_lm.cpp
@@ -21,7 +21,7 @@ int main(int argc, char* argv[]) try {
std::string main_model_path = argv[1];
std::string draft_model_path = argv[2];
std::string prompt = argv[3];

// User can run main and draft model on different devices.
// Please set the device for the main model in the `LLMPipeline` constructor and for the draft model in `ov::genai::draft_model`.
std::string main_device = "CPU", draft_device = "CPU";
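For context, a hedged sketch of how this sample can place the two models on different devices, assuming the `LLMPipeline` constructor and `ov::genai::draft_model` overloads referenced in the comment above; device names and the scheduler setup of the full sample are omitted.

// Illustrative only: run the main model on GPU and the draft model on CPU.
std::string main_device = "GPU", draft_device = "CPU";
ov::genai::LLMPipeline pipe(
    main_model_path,
    main_device,                                              // device for the main model
    ov::genai::draft_model(draft_model_path, draft_device));  // device for the draft model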
3 changes: 1 addition & 2 deletions src/cpp/include/openvino/genai/generation_handle.hpp
@@ -73,8 +73,6 @@ class GenerationStream;
class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {
std::shared_ptr<GenerationStream> m_generation_stream;
ov::genai::GenerationConfig m_sampling_params;

bool is_dropped();

public:
GenerationHandleImpl(std::shared_ptr<GenerationStream> generation_stream, const ov::genai::GenerationConfig& sampling_params) :
@@ -90,6 +88,7 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {
GenerationStatus get_status();

bool can_read();
bool is_dropped();

void drop();

62 changes: 48 additions & 14 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -1,6 +1,9 @@
// Copyright (C) 2023-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <atomic>
#include <thread>

#include "text_callback_streamer.hpp"
#include "continuous_batching_impl.hpp"
#include "utils.hpp"
@@ -261,6 +264,9 @@ std::vector<EncodedGenerationResult>
ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<ov::Tensor>& input_ids,
const std::vector<GenerationConfig>& sampling_params,
const StreamerVariant& streamer) {
ManualTimer generate_timer("generate()");
generate_timer.start();

OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request");
OPENVINO_ASSERT(input_ids.size() == sampling_params.size());

@@ -299,8 +305,38 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
}
auto all_requests = m_awaiting_requests; // we need to store all requests to get results from them once generation has finished

bool continue_generation = true;
while (has_non_finished_requests() && continue_generation) {
std::atomic<bool> has_active_request = has_non_finished_requests();
GenerationHandle& generation = generations.at(0);

// synchronization primitives for thread-safe streaming
std::mutex mutex;
std::unique_lock lock(mutex);
std::condition_variable cv;

// token-streaming lambda executed by the `t_stream` thread
auto stream_tokens = [&generation, &streamer_ptr, &has_active_request, &cv, &lock]() {
while (!generation->is_dropped() && (has_active_request || streamer_ptr && generation->can_read())) {
// waiting for any tokens or request finishing
cv.wait(lock, [&generation, &has_active_request]{ return generation->can_read() || !has_active_request; });
if (streamer_ptr && generation->can_read()) {
std::unordered_map<uint64_t, GenerationOutput> token = generation->back();
for (const auto& gen_token : token.begin()->second.generated_ids) {
if (streamer_ptr->put(gen_token)) {
generation->drop();
cv.notify_all();
break;
}
}
}
};
};

// start the streaming thread
std::thread t_stream([&stream_tokens] {
stream_tokens();
});

while (!generation->is_dropped() && has_active_request) {
try {
const auto infer_start = std::chrono::steady_clock::now();
step();
@@ -314,27 +350,23 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
}
} catch (...) {
drop_requests(); // remove all requests from pipeline state in case of exception
has_active_request = false;
cv.notify_all();
throw;
}
has_active_request = has_non_finished_requests();
cv.notify_all();
}

GenerationHandle & generation = generations.at(0);
if (streamer_ptr && generation->can_read()) {
std::unordered_map<uint64_t, GenerationOutput> token = generation->back();
for (const auto& gen_token : token.begin()->second.generated_ids) {
continue_generation = !streamer_ptr->put(gen_token);
if (!continue_generation) {
generation->drop();
break;
}
}
}
if (t_stream.joinable()) {
t_stream.join();
}

if (streamer_ptr) { // push streamer's cache
streamer_ptr->end();
}

if (!continue_generation) {
if (generation->is_dropped()) {
drop_requests();
} else {
OPENVINO_ASSERT(m_requests.empty(), "Internal error: current request is supposed to be dropped within step() function as completed");
@@ -378,6 +410,8 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
}

OPENVINO_ASSERT(results.size() == input_ids.size());

generate_timer.end();
return results;
}

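The hunk above also introduces a cooperative stop path: when the streamer's put() returns true, the generation handle is dropped, and both the streaming loop and the step() loop observe the dropped state and exit before the thread is joined and the streamer's cache is flushed. A small sketch of that cancellation flow, using illustrative names and a latest-token flag rather than the pipeline's classes and output queue:

// Sketch of cooperative cancellation: the consumer asks to stop, both loops exit,
// and the streaming thread is joined before the program continues.
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

int main() {
    std::atomic<bool> dropped{false};            // generation->drop() / is_dropped() in the PR
    std::atomic<bool> has_active_request{true};
    std::atomic<int>  last_token{-1};            // latest produced token (stand-in for the output queue)
    std::mutex mutex;
    std::condition_variable cv;

    // Stands in for streamer_ptr->put(): request a stop once the third token arrives.
    auto put = [](int token) { return token >= 2; };

    std::thread t_stream([&] {
        std::unique_lock<std::mutex> lock(mutex);
        int streamed = -1;
        while (!dropped && (has_active_request || last_token > streamed)) {
            cv.wait(lock, [&] { return last_token > streamed || !has_active_request; });
            if (last_token > streamed) {
                streamed = last_token;
                std::cout << "streamed token " << streamed << '\n';
                if (put(streamed)) {
                    dropped = true;              // ask the generation loop to stop
                }
            }
        }
    });

    for (int step = 0; !dropped && step < 10; ++step) {   // the role of step()
        {
            std::lock_guard<std::mutex> guard(mutex);
            last_token = step;
        }
        cv.notify_all();
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
    {
        std::lock_guard<std::mutex> guard(mutex);
        has_active_request = false;
    }
    cv.notify_all();

    t_stream.join();                             // join before flushing the streamer's cache
    std::cout << (dropped ? "stopped by streamer\n" : "generation finished\n");
}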
2 changes: 1 addition & 1 deletion src/cpp/src/generation_stream.hpp
@@ -38,7 +38,7 @@ class GenerationStream {
}

bool can_read() {
return !m_output_queue.empty();
return !m_output_queue.empty() && !m_output_queue.full();
}

void set_generation_status(GenerationStatus status) {
66 changes: 46 additions & 20 deletions src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp
@@ -1,6 +1,8 @@
// Copyright (C) 2023-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <thread>

#include "prompt_lookup_impl.hpp"
#include "text_callback_streamer.hpp"

@@ -109,38 +111,61 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vector<ov::Ten
}
auto all_requests = m_pipeline->get_awaiting_requests();

bool continue_generation = true;
while (has_non_finished_requests() && continue_generation) {
std::atomic<bool> continue_streaming = true, has_active_request = has_non_finished_requests();
auto& generation = generations.at(0);

// synchronization primitives for thread-safe streaming
std::mutex mutex;
std::unique_lock lock(mutex);
std::condition_variable cv;

// token-streaming lambda executed by the `t_stream` thread
auto stream_tokens = [&generation, &streamer_ptr, &has_active_request, &cv, &lock]() {
while (!generation->is_dropped() && (has_active_request || streamer_ptr && generation->can_read())) {
// waiting for any tokens or request finishing
cv.wait(lock, [&generation, &has_active_request]{ return generation->can_read() || !has_active_request; });

if (streamer_ptr && generation->can_read()) {
std::unordered_map<uint64_t, GenerationOutput> token = generation->back();
for (const auto& gen_token : token.begin()->second.generated_ids) {
if (streamer_ptr->put(gen_token)) {
generation->drop();
cv.notify_all();
break;
}
}
}
};
};

// start the streaming thread
std::thread t_stream([&stream_tokens] {
stream_tokens();
});

while (continue_streaming && has_active_request) {
try {
const auto infer_start = std::chrono::steady_clock::now();
step();
} catch (...) {
drop_requests(); // remove all requests from pipeline state in case of exception
has_active_request = false;
cv.notify_all();
throw;
}
if (streamer_ptr) {
// not generated tokens like several prompt phase
auto& generation = generations.at(0);
if (!generation->can_read()) {
continue;
}
std::unordered_map<uint64_t, GenerationOutput> token = generation->back();
OPENVINO_ASSERT(1 <= token.size());
OPENVINO_ASSERT(1 <= token.begin()->second.generated_ids.size());
for (const auto& gen_token : token.begin()->second.generated_ids) {
continue_generation = !streamer_ptr->put(gen_token);
if (!continue_generation) {
generation->drop();
break;
}
}
}
has_active_request = has_non_finished_requests();
cv.notify_all();
}

if (t_stream.joinable()) {
t_stream.join();
}

if (streamer_ptr) { // push streamer's cache
streamer_ptr->end();
}

if (!continue_generation) {
if (generation->is_dropped()) {
drop_requests();
} else {
OPENVINO_ASSERT(m_pipeline->is_requests_empty(), "Internal error: current request is supposed to be dropped within step() function as completed");
@@ -177,6 +202,7 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vector<ov::Ten
}

OPENVINO_ASSERT(results.size() == input_ids.size());
generate_timer.end();
return results;
}

65 changes: 47 additions & 18 deletions src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -1,6 +1,8 @@
// Copyright (C) 2023-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <thread>

#include "text_callback_streamer.hpp"
#include "speculative_decoding_impl.hpp"
#include "utils.hpp"
@@ -235,36 +237,62 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
}
auto all_requests = get_awaiting_requests();

bool continue_generation = true;
while (has_non_finished_requests() && continue_generation) {
std::atomic<bool> has_active_request = has_non_finished_requests();
auto& generation = main_generations.at(0);

// synchronization primitives for thread-safe streaming
std::mutex mutex;
std::unique_lock lock(mutex);
std::condition_variable cv;

// token-streaming lambda executed by the `t_stream` thread
auto stream_tokens = [&generation, &streamer_ptr, &has_active_request, &cv, &lock]() {
while (!generation->is_dropped() && (has_active_request || streamer_ptr && generation->can_read())) {
// waiting for any tokens or request finishing
cv.wait(lock, [&generation, &has_active_request]{ return generation->can_read() || !has_active_request; });

if (streamer_ptr && generation->can_read()) {
std::unordered_map<uint64_t, GenerationOutput> token = generation->back();
for (const auto& gen_token : token.begin()->second.generated_ids) {
if (streamer_ptr->put(gen_token)) {
generation->drop();
cv.notify_all();
break;
}
}
}
};
};

// start the streaming thread
std::thread t_stream([&stream_tokens] {
stream_tokens();
});

while (!generation->is_dropped() && has_active_request) {
try {
const auto infer_start = std::chrono::steady_clock::now();
step();
} catch (...) {
drop_requests(); // remove all requests from pipeline state in case of exception
has_active_request = false;
cv.notify_all();
throw;
}
if (streamer_ptr) {
auto& main_generation = main_generations.at(0);
// not generated tokens like several prompt phase
if (!main_generation->can_read()) {
continue;
}
std::unordered_map<uint64_t, GenerationOutput> token = main_generation->back();
for (const auto& gen_token : token.begin()->second.generated_ids) {
continue_generation = !streamer_ptr->put(gen_token);
if (!continue_generation) {
main_generation->drop();
break;
}
}
}
has_active_request = has_non_finished_requests();
cv.notify_all();
}

// wait for streaming to complete
if (t_stream.joinable()) {
t_stream.join();
}

if (streamer_ptr) { // push streamer's cache
streamer_ptr->end();
}

if (!continue_generation) {
if (generation->is_dropped()) {
drop_requests();
} else {
OPENVINO_ASSERT(is_requests_empty(), "Internal error: current request is supposed to be dropped within step() function as completed");
@@ -301,6 +329,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
}

OPENVINO_ASSERT(results.size() == input_ids.size());
generate_timer.end();
return results;
}

9 changes: 8 additions & 1 deletion src/cpp/src/synchronized_queue.hpp
@@ -13,6 +13,7 @@ class SynchronizedQueue
std::queue<T> m_queue;
std::mutex m_mutex;
std::condition_variable m_cv;
size_t m_taken_element_cnt = 0;

public:
SynchronizedQueue() = default;
@@ -22,7 +23,8 @@

T back() {
std::unique_lock<std::mutex> lock(m_mutex);
m_cv.wait(lock, [this]{return !m_queue.empty();});
m_cv.wait(lock, [this]{return !m_queue.empty(); });
m_taken_element_cnt = m_queue.size();
return m_queue.back();
}

@@ -44,4 +46,9 @@
std::unique_lock<std::mutex> lock(m_mutex);
return m_queue.empty();
}

bool full() {
std::unique_lock<std::mutex> lock(m_mutex);
return m_taken_element_cnt == m_queue.size();
}
};
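The new `m_taken_element_cnt` / `full()` pair is what lets the updated `GenerationStream::can_read()` report true only for outputs the streaming thread has not consumed yet: `back()` records how many queued elements it has observed, and `full()` stays true until something new is pushed. A simplified, self-contained model of that counter logic follows (a sketch, not the actual SynchronizedQueue):

// Minimal model of the "taken element" counter: back() marks everything queued
// so far as seen; full() is true while nothing new has arrived since then.
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <queue>

template <typename T>
class CountingQueue {
    std::queue<T> m_queue;
    std::mutex m_mutex;
    std::condition_variable m_cv;
    size_t m_taken_element_cnt = 0;

public:
    void push(T value) {
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            m_queue.push(std::move(value));
        }
        m_cv.notify_one();
    }

    T back() {
        std::unique_lock<std::mutex> lock(m_mutex);
        m_cv.wait(lock, [this] { return !m_queue.empty(); });
        m_taken_element_cnt = m_queue.size();  // everything queued so far is now "seen"
        return m_queue.back();
    }

    bool empty() {
        std::lock_guard<std::mutex> lock(m_mutex);
        return m_queue.empty();
    }

    bool full() {
        std::lock_guard<std::mutex> lock(m_mutex);
        return m_taken_element_cnt == m_queue.size();
    }
};

int main() {
    CountingQueue<int> q;
    q.push(1);
    assert(!q.empty() && !q.full());  // a fresh token: can_read() would be true
    q.back();
    assert(q.full());                 // nothing new since the last back(): can_read() false
    q.push(2);
    assert(!q.full());                // a new token arrived, readable again
}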