Commit

Update
sbalandi committed Jan 16, 2025
1 parent d86d142 commit 68cd6b4
Showing 33 changed files with 316 additions and 216 deletions.
5 changes: 2 additions & 3 deletions samples/cpp/text_generation/chat_sample.cpp
@@ -16,12 +16,11 @@ int main(int argc, char* argv[]) try {
ov::genai::GenerationConfig config;
config.max_new_tokens = 100;

- std::function<ov::genai::StreamerRunningStatus(std::string)> streamer = [](std::string word) {
+ auto streamer = [](std::string word) {
std::cout << word << std::flush;
// Return flag corresponds whether generation should be stopped.
// false means continue generation.

- return ov::genai::StreamerRunningStatus::RUNNING;
+ return ov::genai::GenerationStatus::RUNNING;
};

pipe.start_chat();
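For reference, a minimal sketch of how the updated callback-style streamer from this sample might be used end to end. The model directory and prompt are placeholders; the generate/start_chat calls follow the sample above.

```cpp
#include <iostream>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Placeholder model directory; any OpenVINO GenAI compatible LLM folder would do.
    ov::genai::LLMPipeline pipe("TinyLlama-1.1B-Chat-v1.0", "CPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;

    // The callback now reports a GenerationStatus instead of a bare bool:
    // RUNNING continues generation, STOP ends it and keeps the result,
    // CANCEL ends it and (per the pipeline changes below) rolls the turn
    // back out of the chat history.
    auto streamer = [](std::string word) {
        std::cout << word << std::flush;
        return ov::genai::GenerationStatus::RUNNING;
    };

    pipe.start_chat();
    pipe.generate("What is OpenVINO?", config, streamer);
    pipe.finish_chat();
    return 0;
}
```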
2 changes: 1 addition & 1 deletion samples/python/text_generation/chat_sample.py
@@ -10,7 +10,7 @@ def streamer(subword):
print(subword, end='', flush=True)
# Return flag corresponds whether generation should be stopped.
# False means continue generation.
- return False
+ return openvino_genai.GenerationStatus.RUNNING

def main():
parser = argparse.ArgumentParser()
17 changes: 9 additions & 8 deletions src/cpp/include/openvino/genai/generation_handle.hpp
@@ -7,7 +7,6 @@
#include <unordered_map>

#include "openvino/genai/generation_config.hpp"
#include "openvino/genai/streamer_base.hpp"
#include "openvino/genai/visibility.hpp"
#include "openvino/genai/perf_metrics.hpp"

@@ -16,8 +15,8 @@ enum class GenerationStatus {
RUNNING = 0, // Default status for ongoing generation
FINISHED = 1, // Status set when generation has been finished
IGNORED = 2, // Status set when generation run into out-of-memory condition and could not be continued
- DROPPED_BY_PIPELINE = 3, // Currently not used, TODO: implement abort functionality
- DROPPED_BY_HANDLE = 4 // Status set when generation handle is dropped
+ CANCEL = 3, // Status set when generation handle is canceled
+ STOP = 4 // Status set when generation handle is stopped
};

struct EncodedGenerationResult {
@@ -35,10 +34,6 @@ struct EncodedGenerationResult {

// PerfMetrics but with empty tokenization/detokenization durations.
PerfMetrics perf_metrics;


- // Status of streaming
- StreamerRunningStatus m_streaming_status = ov::genai::StreamerRunningStatus::UNDEF;
};

enum class GenerationFinishReason {
@@ -79,7 +74,9 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {
std::shared_ptr<GenerationStream> m_generation_stream;
ov::genai::GenerationConfig m_sampling_params;

- bool is_dropped();
+ bool is_stopped();

+ bool is_canceled();

public:
GenerationHandleImpl(std::shared_ptr<GenerationStream> generation_stream, const ov::genai::GenerationConfig& sampling_params) :
@@ -98,6 +95,10 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {

void drop();

+ void stop();

+ void cancel();

GenerationOutputs back();
// Reads result of a generation for single iteration
GenerationOutputs read();
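As a rough illustration of the new stop/cancel split on the handle API (the handle is assumed to come from a continuous-batching request, and the drain loop is simplified):

```cpp
#include "openvino/genai/generation_handle.hpp"

// Ends generation early. stop() keeps whatever was generated so far,
// while cancel() marks the request as canceled so callers (e.g. chat
// pipelines) can discard the last prompt/answer as well.
void finish_early(ov::genai::GenerationHandle& handle, bool keep_partial_result) {
    while (handle->can_read()) {
        auto outputs = handle->read();  // outputs generated since the last read
        (void)outputs;                  // stream or accumulate them as needed
    }
    if (keep_partial_result) {
        handle->stop();    // stream status becomes GenerationStatus::STOP
    } else {
        handle->cancel();  // stream status becomes GenerationStatus::CANCEL
    }
}
```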
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -19,7 +19,7 @@ namespace ov {
namespace genai {

// Return flag corresponds whether generation should be stopped: false means continue generation, true means stop.
- using StreamerVariant = std::variant<std::function<bool(std::string)>, std::function<StreamerRunningStatus(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
+ using StreamerVariant = std::variant<std::function<bool(std::string)>, std::function<UintCallbackStreamerResult(std::string)>, std::function<GenerationStatus(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using OptionalGenerationConfig = std::optional<GenerationConfig>;
using EncodedInputs = std::variant<ov::Tensor, TokenizedInputs>;
using StringInputs = std::variant<std::string, std::vector<std::string>>;
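For reference, a sketch of the three callback shapes this variant now accepts; the lambdas are trivial placeholders:

```cpp
#include <functional>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // 1) Legacy bool callback: returning true requests a stop.
    std::function<bool(std::string)> legacy_cb =
        [](std::string) { return false; };

    // 2) Status-based callback: RUNNING / STOP / CANCEL.
    std::function<ov::genai::GenerationStatus(std::string)> status_cb =
        [](std::string) { return ov::genai::GenerationStatus::RUNNING; };

    // 3) uint16_t wrapper, used by the Python binding path.
    std::function<ov::genai::UintCallbackStreamerResult(std::string)> uint_cb =
        [](std::string) { return ov::genai::UintCallbackStreamerResult{0}; };

    // Any of these can be stored in a StreamerVariant.
    ov::genai::StreamerVariant v1 = legacy_cb;
    ov::genai::StreamerVariant v2 = status_cb;
    ov::genai::StreamerVariant v3 = uint_cb;
    (void)v1; (void)v2; (void)v3;
    return 0;
}
```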
20 changes: 8 additions & 12 deletions src/cpp/include/openvino/genai/streamer_base.hpp
@@ -4,38 +4,34 @@
#pragma once

#include "openvino/genai/tokenizer.hpp"
#include "openvino/genai/generation_handle.hpp"
#include <variant>

namespace ov {
namespace genai {

- enum class StreamerRunningStatus {
- UNDEF = 0, // Streaming is not run
- RUNNING = 1, // Continue to run of inference
- STOP = 2, // Stop generation, keep history as is, KV cache includes last request and generated tokens
- CANCEL = 3 // Stop generate, drop last prompt and all generated tokens from history, KV cache include history but last step
+ // uint16_t for Python API here
+ struct UintCallbackStreamerResult {
+ uint16_t result;
};

- using CallbackTypeVariant = std::variant<bool, StreamerRunningStatus>;
+ using CallbackTypeVariant = std::variant<bool, UintCallbackStreamerResult, ov::genai::GenerationStatus, std::monostate>;

/**
* @brief base class for streamers. In order to use inherit from from this class and implement put, and methods
*
* @param m_tokenizer tokenizer
*/
class OPENVINO_GENAI_EXPORTS StreamerBase {
- protected:
- StreamerRunningStatus streaming_finish_status = StreamerRunningStatus::UNDEF;
public:
+ ov::genai::GenerationStatus m_streaming_finish_status = ov::genai::GenerationStatus::RUNNING;
/// @brief put is called every time new token is decoded,
/// @return bool flag to indicate whether generation should be stopped, if return true generation stops
virtual bool put(int64_t token) = 0;

/// @brief end is called at the end of generation. It can be used to flush cache if your own streamer has one
virtual void end() = 0;

- virtual StreamerRunningStatus get_finish_streaming_reason() {
- return streaming_finish_status;
+ ov::genai::GenerationStatus get_finish_reason() {
+ return m_streaming_finish_status;
}

virtual ~StreamerBase();
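To illustrate the updated base class contract, a small custom streamer sketch (token decoding is omitted and the token budget is arbitrary; this is not code from the commit):

```cpp
#include <cstddef>
#include <cstdint>

#include "openvino/genai/streamer_base.hpp"

// Stops generation after a fixed number of tokens and records the reason
// through the new public m_streaming_finish_status / get_finish_reason() pair.
class TokenBudgetStreamer : public ov::genai::StreamerBase {
public:
    explicit TokenBudgetStreamer(size_t budget) : m_budget(budget) {}

    bool put(int64_t /*token*/) override {
        if (++m_seen >= m_budget) {
            // STOP keeps the generated text; CANCEL would additionally ask a
            // chat pipeline to drop the current prompt/answer from history.
            m_streaming_finish_status = ov::genai::GenerationStatus::STOP;
            return true;   // true means: stop generation
        }
        return false;      // false means: keep generating
    }

    void end() override {
        // Nothing buffered in this sketch; a text streamer would flush here.
    }

private:
    size_t m_seen = 0;
    size_t m_budget;
};
```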
4 changes: 2 additions & 2 deletions src/cpp/src/continuous_batching_adapter.hpp
@@ -97,7 +97,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
std::vector<std::string> plain_replies;
std::vector<float> plain_scores;
for (GenerationResult& res : generated) {
- OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus");
+ OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::STOP || res.m_status == GenerationStatus::CANCEL, "Got unfinished GenerationStatus");
std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_replies));
std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
}
@@ -189,7 +189,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
std::vector<std::vector<int64_t>> plain_tokens;
std::vector<float> plain_scores;
for (EncodedGenerationResult& res : generated) {
- OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus");
+ OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::STOP, "Got unfinished GenerationStatus");
std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_tokens));
std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
}
27 changes: 6 additions & 21 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -276,20 +276,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
}
set_adapters(sampling_params[0].adapters);

- const std::shared_ptr<StreamerBase>& streamer_ptr = std::visit(overloaded{
- [](std::monostate) -> std::shared_ptr<StreamerBase> {
- return nullptr;
- },
- [](const std::shared_ptr<StreamerBase>& streamer) {
- return streamer;
- },
- [this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
- return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
- },
- [this](const std::function<ov::genai::StreamerRunningStatus(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
- return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
- }
- }, streamer);
+ const std::shared_ptr<StreamerBase>& streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer);

OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && sampling_params[0].num_return_sequences == 1 &&
(sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()),
@@ -320,13 +307,13 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
throw;
}

- GenerationHandle & generation = generations.at(0);
+ GenerationHandle& generation = generations.at(0);
if (streamer_ptr && generation->can_read()) {
std::unordered_map<uint64_t, GenerationOutput> token = generation->back();
for (const auto& gen_token : token.begin()->second.generated_ids) {
continue_generation = !streamer_ptr->put(gen_token);
if (!continue_generation) {
- generation->drop();
+ streamer_ptr->get_finish_reason() == GenerationStatus::CANCEL ? generation->cancel() : generation->stop();
break;
}
}
@@ -356,9 +343,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
result.m_request_id = request_id;
result.m_generation_ids.resize(num_outputs);
result.m_scores.resize(num_outputs);

- if (streamer_ptr)
- result.m_streaming_status = streamer_ptr->get_finish_streaming_reason();
+ result.m_status = request->get_generation_stream()->get_status();

for (size_t i = 0; i < num_outputs; ++i) {
const auto & sequence = sequences[i];
@@ -391,7 +376,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_free_non_running_reque
std::vector<SequenceGroup::Ptr>::iterator requests_iterator = m_requests.begin();
while (requests_iterator != m_requests.end()) {
const auto& request = *requests_iterator;
- if(request->has_finished() || request->out_of_memory() || request->handle_dropped()) {
+ if(request->has_finished() || request->out_of_memory() || request->handle_stopped() || request->handle_canceled()) {
for (const auto& sequence: request->get_sequences()) {
if (m_scheduler->has_block_table(sequence->get_id())) {
m_scheduler->free_sequence(sequence->get_id());
@@ -409,7 +394,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_notify_requests_droppe
// Notify the last time by pushing empty output
// This causes read() to unblock by adding anything to the queue
for (SequenceGroup::Ptr& request : m_requests) {
- if (request->handle_dropped())
+ if (request->handle_stopped() || request->handle_canceled())
request->push_empty_outputs();
}
}
28 changes: 20 additions & 8 deletions src/cpp/src/generation_handle.cpp
@@ -9,32 +9,44 @@
using namespace ov::genai;

GenerationHandleImpl::~GenerationHandleImpl() {
- drop();
+ stop();
}

GenerationStatus GenerationHandleImpl::get_status() {
return m_generation_stream->get_status();
}

bool GenerationHandleImpl::can_read() {
- return !is_dropped() && m_generation_stream->can_read();
+ return !is_canceled() && !is_stopped() && m_generation_stream->can_read();
}

- bool GenerationHandleImpl::is_dropped() {
- return get_status() == GenerationStatus::DROPPED_BY_HANDLE;
+ bool GenerationHandleImpl::is_stopped() {
+ return get_status() == GenerationStatus::STOP;
}

+ bool GenerationHandleImpl::is_canceled() {
+ return get_status() == GenerationStatus::CANCEL;
+ }

void GenerationHandleImpl::drop() {
- m_generation_stream->drop();
+ m_generation_stream->stop();
}

+ void GenerationHandleImpl::stop() {
+ m_generation_stream->stop();
+ }

+ void GenerationHandleImpl::cancel() {
+ m_generation_stream->cancel();
+ }

std::unordered_map<uint64_t, GenerationOutput> GenerationHandleImpl::back() {
- OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
+ OPENVINO_ASSERT(!is_stopped(), "GenerationHandle cannot be used after it is stopped.");
return m_generation_stream->back();
}

std::unordered_map<uint64_t, GenerationOutput> GenerationHandleImpl::read() {
- OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
+ OPENVINO_ASSERT(!is_stopped(), "GenerationHandle cannot be used after it is stopped.");
return m_generation_stream->read();
}

@@ -57,7 +69,7 @@ void add_partial_result(std::unordered_map<uint64_t, GenerationOutput>& partial_
}

std::vector<GenerationOutput> GenerationHandleImpl::read_all() {
- OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
+ OPENVINO_ASSERT(!is_stopped(), "GenerationHandle cannot be used after it is stopped.");
std::vector<GenerationOutput> results;
std::unordered_map<uint64_t, GenerationOutput> partial_results;
// We iterate until generation is running or there are tokens we haven't read yet
9 changes: 7 additions & 2 deletions src/cpp/src/generation_stream.hpp
@@ -51,9 +51,14 @@ class GenerationStream {
return m_status;
}

- void drop() {
+ void stop() {
std::lock_guard<std::mutex> lock(m_mutex);
- m_status = GenerationStatus::DROPPED_BY_HANDLE;
+ m_status = GenerationStatus::STOP;
}

+ void cancel() {
+ std::lock_guard<std::mutex> lock(m_mutex);
+ m_status = GenerationStatus::CANCEL;
+ }
};
}
4 changes: 2 additions & 2 deletions src/cpp/src/icontinuous_batching.cpp
@@ -77,7 +77,7 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
const auto decode_start = std::chrono::steady_clock::now();
generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx)));
raw_counters.detokenization_durations.emplace_back(std::chrono::steady_clock::now() - decode_start);
- if (m_is_chat_conversation && 0 == idx && res.m_streaming_status != ov::genai::StreamerRunningStatus::CANCEL) {
+ if (m_is_chat_conversation && 0 == idx && res.m_status != ov::genai::GenerationStatus::CANCEL) {
m_history.push_back({{"role", "assistant"}, {"content", generated.back()}});
}
}
@@ -99,7 +99,7 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
}

// if streaming was canceled, prompt/answer of current step shouldn't be presented in history, so let's remove prompt from history
- if (m_is_chat_conversation && !encoded.empty() && encoded[0].m_streaming_status == ov::genai::StreamerRunningStatus::CANCEL)
+ if (m_is_chat_conversation && !encoded.empty() && encoded[0].m_status == ov::genai::GenerationStatus::CANCEL)
m_history.pop_back();

return decoded;
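As a usage note, a sketch of a chat turn whose streamer returns CANCEL; with the change above, the canceled prompt and partial answer should not be kept in the chat history (model path, prompts, and the cutoff are placeholders):

```cpp
#include <cstddef>
#include <iostream>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Placeholder model directory.
    ov::genai::LLMPipeline pipe("TinyLlama-1.1B-Chat-v1.0", "CPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;

    size_t chunks = 0;
    auto cancelling_streamer = [&chunks](std::string word) {
        std::cout << word << std::flush;
        // Arbitrary cutoff: give up after 5 streamed chunks and ask the
        // pipeline to roll this turn back out of the chat history.
        return ++chunks >= 5 ? ov::genai::GenerationStatus::CANCEL
                             : ov::genai::GenerationStatus::RUNNING;
    };

    pipe.start_chat();
    pipe.generate("Tell me a very long story", config, cancelling_streamer);
    // The canceled turn above is not part of the history for this follow-up.
    pipe.generate("What did I just ask you?", config);
    pipe.finish_chat();
    return 0;
}
```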
8 changes: 5 additions & 3 deletions src/cpp/src/llm_pipeline.cpp
@@ -47,9 +47,11 @@ std::pair<ov::AnyMap, ov::genai::static_llm::ModelConfigDesc> split_model_descr(
std::pair<std::string, Any> streamer(StreamerVariant func) {
if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&func)) {
return {utils::STREAMER_ARG_NAME, Any::make<std::shared_ptr<StreamerBase>>(*streamer_obj)};
- } else if (auto streamer_obj = std::get_if<std::function<StreamerRunningStatus(std::string)>>(&func)) {
- return {utils::STREAMER_ARG_NAME, Any::make<std::function<StreamerRunningStatus(std::string)>>(*streamer_obj)};
- } else {
+ } else if (auto streamer_obj = std::get_if<std::function<UintCallbackStreamerResult(std::string)>>(&func)) {
+ return {utils::STREAMER_ARG_NAME, Any::make<std::function<UintCallbackStreamerResult(std::string)>>(*streamer_obj)};
+ } else if (auto streamer_obj = std::get_if<std::function<GenerationStatus(std::string)>>(&func)) {
+ return {utils::STREAMER_ARG_NAME, Any::make<std::function<GenerationStatus(std::string)>>(*streamer_obj)};
+ } else {
auto callback = std::get<std::function<bool(std::string)>>(func);
return {utils::STREAMER_ARG_NAME, Any::make<std::function<bool(std::string)>>(callback)};
}
15 changes: 3 additions & 12 deletions src/cpp/src/llm_pipeline_stateful.cpp
@@ -175,7 +175,7 @@ DecodedResults StatefulLLMPipeline::generate(
auto decode_stop_time = std::chrono::steady_clock::now();

if (is_chat_conversation) {
- if (m_chat_generation_finish_status == ov::genai::StreamerRunningStatus::CANCEL) {
+ if (m_chat_generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
// If chat generation process was canceled by user, let's rollback to previous state of history
m_history.pop_back();
m_kv_history_manager.num_tokens_to_remove_from_kv_cache += m_tokenized_chat_history.size() - prev_tokenized_chat_history.size();
@@ -250,16 +250,7 @@ EncodedResults StatefulLLMPipeline::generate(
// Stateful pipeline does not provide logprobs for prompt tokens
OPENVINO_ASSERT(config.echo == false, "Echo is not supported in the stateful pipeline");

- std::shared_ptr<StreamerBase> streamer_ptr;
- if (auto streamer_obj = std::get_if<std::monostate>(&streamer)) {
- streamer_ptr = nullptr;
- } else if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&streamer)) {
- streamer_ptr = *streamer_obj;
- } else if (auto callback = std::get_if<std::function<ov::genai::StreamerRunningStatus(std::string)>>(&streamer)) {
- streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
- } else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
- streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
- }
+ std::shared_ptr<StreamerBase> streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer);

auto batch_size = input_ids.get_shape().at(0);
OPENVINO_ASSERT(streamer_ptr == nullptr || batch_size == 1 && config.num_return_sequences == 1 &&
@@ -352,7 +343,7 @@ EncodedResults StatefulLLMPipeline::generate(
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
}

- if (m_chat_generation_finish_status == ov::genai::StreamerRunningStatus::CANCEL) {
+ if (m_chat_generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;

if (m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
2 changes: 1 addition & 1 deletion src/cpp/src/llm_pipeline_stateful.hpp
@@ -26,7 +26,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
// so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history
ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0, 2};
// Finish reason of last generation for chat scenario
- ov::genai::StreamerRunningStatus m_chat_generation_finish_status = ov::genai::StreamerRunningStatus::UNDEF;
+ ov::genai::GenerationStatus m_chat_generation_finish_status = ov::genai::GenerationStatus::RUNNING;

void reset_kv_state();
public: