Skip to content

Commit

Permalink
Add a choice of how to end streaming from callback: STOP or CANCEL
Browse files Browse the repository at this point in the history
  • Loading branch information
sbalandi committed Jan 16, 2025
1 parent bba7b87 commit 488b83b
Show file tree
Hide file tree
Showing 28 changed files with 245 additions and 84 deletions.
6 changes: 4 additions & 2 deletions samples/cpp/text_generation/chat_sample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ int main(int argc, char* argv[]) try {

ov::genai::GenerationConfig config;
config.max_new_tokens = 100;
std::function<bool(std::string)> streamer = [](std::string word) {

std::function<ov::genai::StreamerRunningStatus(std::string)> streamer = [](std::string word) {
std::cout << word << std::flush;
// Return flag corresponds whether generation should be stopped.
// false means continue generation.
return false;

return ov::genai::StreamerRunningStatus::RUNNING;
};

pipe.start_chat();
Expand Down
6 changes: 2 additions & 4 deletions samples/python/text_generation/chat_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
import openvino_genai


def streamer(subword):
def streamer(subword) -> openvino_genai.StreamerRunningStatus:
print(subword, end='', flush=True)
# Return flag corresponds whether generation should be stopped.
# False means continue generation.
return False

return openvino_genai.StreamerRunningStatus.RUNNING

def main():
parser = argparse.ArgumentParser()
Expand Down
5 changes: 5 additions & 0 deletions src/cpp/include/openvino/genai/generation_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <unordered_map>

#include "openvino/genai/generation_config.hpp"
#include "openvino/genai/streamer_base.hpp"
#include "openvino/genai/visibility.hpp"
#include "openvino/genai/perf_metrics.hpp"

Expand Down Expand Up @@ -34,6 +35,10 @@ struct EncodedGenerationResult {

// PerfMetrics but with empty tokenization/detokenization durations.
PerfMetrics perf_metrics;


// Status of streaming
StreamerRunningStatus m_streaming_status = ov::genai::StreamerRunningStatus::UNDEF;
};

enum class GenerationFinishReason {
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/llm_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace ov {
namespace genai {

// The returned flag indicates whether generation should be stopped: false means continue generation, true means stop.
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::function<StreamerRunningStatus(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using OptionalGenerationConfig = std::optional<GenerationConfig>;
using EncodedInputs = std::variant<ov::Tensor, TokenizedInputs>;
using StringInputs = std::variant<std::string, std::vector<std::string>>;
Expand Down
16 changes: 16 additions & 0 deletions src/cpp/include/openvino/genai/streamer_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,28 @@
#pragma once

#include "openvino/genai/tokenizer.hpp"
#include <variant>

namespace ov {
namespace genai {

/// Status reported by a streamer, describing how streaming should proceed or how it ended.
enum class StreamerRunningStatus {
    /// Streaming has not been run.
    UNDEF = 0,
    /// Inference continues to run.
    RUNNING = 1,
    /// Stop generation and keep the history as is; the KV cache includes the last request and the generated tokens.
    STOP = 2,
    /// Stop generation and drop the last prompt and all generated tokens from the history;
    /// the KV cache includes the history except for the last step.
    CANCEL = 3,
};

using CallbackTypeVariant = std::variant<bool, StreamerRunningStatus>;

/**
* @brief Base class for streamers. To use it, inherit from this class and implement the put and end methods.
*
* @param m_tokenizer tokenizer
*/
class OPENVINO_GENAI_EXPORTS StreamerBase {
protected:
StreamerRunningStatus streaming_finish_status = StreamerRunningStatus::UNDEF;
public:
/// @brief put is called every time new token is decoded,
/// @return bool flag to indicate whether generation should be stopped, if return true generation stops
Expand All @@ -22,6 +34,10 @@ class OPENVINO_GENAI_EXPORTS StreamerBase {
/// @brief end is called at the end of generation. It can be used to flush cache if your own streamer has one
virtual void end() = 0;

virtual StreamerRunningStatus get_finish_streaming_reason() {
return streaming_finish_status;
}

virtual ~StreamerBase();
};

Expand Down
6 changes: 6 additions & 0 deletions src/cpp/src/continuous_batching_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
},
[this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
},
[this](const std::function<ov::genai::StreamerRunningStatus(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
}
}, streamer);

Expand Down Expand Up @@ -354,6 +357,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
result.m_generation_ids.resize(num_outputs);
result.m_scores.resize(num_outputs);

if (streamer_ptr)
result.m_streaming_status = streamer_ptr->get_finish_streaming_reason();

for (size_t i = 0; i < num_outputs; ++i) {
const auto & sequence = sequences[i];
const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_prob();
Expand Down
6 changes: 5 additions & 1 deletion src/cpp/src/icontinuous_batching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
const auto decode_start = std::chrono::steady_clock::now();
generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx)));
raw_counters.detokenization_durations.emplace_back(std::chrono::steady_clock::now() - decode_start);
if (m_is_chat_conversation && 0 == idx) {
if (m_is_chat_conversation && 0 == idx && res.m_streaming_status != ov::genai::StreamerRunningStatus::CANCEL) {
m_history.push_back({{"role", "assistant"}, {"content", generated.back()}});
}
}
Expand All @@ -98,6 +98,10 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
});
}

// if streaming was canceled, prompt/answer of current step shouldn't be presented in history, so let's remove prompt from history
if (m_is_chat_conversation && !encoded.empty() && encoded[0].m_streaming_status == ov::genai::StreamerRunningStatus::CANCEL)
m_history.pop_back();

return decoded;
}
}
2 changes: 2 additions & 0 deletions src/cpp/src/llm_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ std::pair<ov::AnyMap, ov::genai::static_llm::ModelConfigDesc> split_model_descr(
std::pair<std::string, Any> streamer(StreamerVariant func) {
if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&func)) {
return {utils::STREAMER_ARG_NAME, Any::make<std::shared_ptr<StreamerBase>>(*streamer_obj)};
} else if (auto streamer_obj = std::get_if<std::function<StreamerRunningStatus(std::string)>>(&func)) {
return {utils::STREAMER_ARG_NAME, Any::make<std::function<StreamerRunningStatus(std::string)>>(*streamer_obj)};
} else {
auto callback = std::get<std::function<bool(std::string)>>(func);
return {utils::STREAMER_ARG_NAME, Any::make<std::function<bool(std::string)>>(callback)};
Expand Down
74 changes: 49 additions & 25 deletions src/cpp/src/llm_pipeline_stateful.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ StatefulLLMPipeline::StatefulLLMPipeline(
const ov::genai::GenerationConfig& generation_config)
: LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) {
utils::apply_slice_before_matmul_transformation(model);
m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);
m_kv_history_manager.kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);

ov::CompiledModel compiled_model;
if (auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters)) {
Expand Down Expand Up @@ -86,6 +86,9 @@ DecodedResults StatefulLLMPipeline::generate(

TokenizedInputs encoded_input;

std::string prev_templated_chat_history(m_templated_chat_history);
std::vector<int64_t> prev_tokenized_chat_history(m_tokenized_chat_history);

if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
OPENVINO_ASSERT(!is_chat_conversation, "Can't chat with multiple prompts");
encoded_input = m_tokenizer.encode(*input_vector);
Expand All @@ -104,7 +107,7 @@ DecodedResults StatefulLLMPipeline::generate(

m_history.push_back({{"role", "user"}, {"content", prompt}});
constexpr bool add_generation_prompt = true;
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
// Do not add special tokens in chat scenario to be aligned with HF.
auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false));
auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
Expand All @@ -116,21 +119,24 @@ DecodedResults StatefulLLMPipeline::generate(
if (!m_tokenized_chat_history.empty()) {
std::set<int64_t> stop_tokens = config.stop_token_ids;
trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
m_trust_encoded_history = trusted_history_length == SIZE_MAX;
}

if (m_tokenized_chat_history.empty()) {
encoded_input = new_chat_tokens;
} else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_kv_cache_need_to_update()) {
// does_kv_cache_need_to_update will be true here if beam search is activated
} else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_history_cache_need_to_update()) {
// does_history_cache_need_to_update will be true here if beam search is activated
// in beam search mode we want to remove all history about last model answer from kv cache and add the best answer directly
// if we have difference in model answer and decoded answer it anyway will be less then entire history, so let's use data from m_kv_history_manager
if (m_kv_history_manager.does_kv_cache_need_to_update()) {
if (m_kv_history_manager.does_history_cache_need_to_update()) {
trusted_history_length = m_kv_history_manager.trusted_history_length;
} else {
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length;
size_t num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length;
// if prev generation was finished because of max len was reached, kv cache is missed one last token, let's keep it
m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0;
num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0;

// if streaming was used and canceled on prev step, num_tokens_to_remove_from_kv_cache could be already set and it will be bigger as include answer + prompt
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = num_tokens_to_remove_from_kv_cache > m_kv_history_manager.num_tokens_to_remove_from_kv_cache ?
num_tokens_to_remove_from_kv_cache : m_kv_history_manager.num_tokens_to_remove_from_kv_cache;
}

ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
Expand Down Expand Up @@ -169,11 +175,19 @@ DecodedResults StatefulLLMPipeline::generate(
auto decode_stop_time = std::chrono::steady_clock::now();

if (is_chat_conversation) {
// Tail of chat template is missing in KV cache.
// Find the tail to concatenate it with the next input prompt.
auto answer = decoded_results.texts[0];
m_templated_chat_history.append(answer);
m_history.push_back({{"role", "assistant"}, {"content", answer}});
if (m_chat_generation_finish_status == ov::genai::StreamerRunningStatus::CANCEL) {
// If chat generation process was canceled by user, let's rollback to previous state of history
m_history.pop_back();
m_kv_history_manager.num_tokens_to_remove_from_kv_cache += m_tokenized_chat_history.size() - prev_tokenized_chat_history.size();
m_templated_chat_history = prev_templated_chat_history;
m_tokenized_chat_history = prev_tokenized_chat_history;
} else {
// Tail of chat template is missing in KV cache.
// Find the tail to concatenate it with the next input prompt.
auto answer = decoded_results.texts[0];
m_templated_chat_history.append(answer);
m_history.push_back({{"role", "assistant"}, {"content", answer}});
}
}

// generate_durations
Expand Down Expand Up @@ -218,6 +232,8 @@ EncodedResults StatefulLLMPipeline::generate(
if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));

size_t real_input_ids_size = input_ids.get_shape().at(1);

// Tail of previous output in chat mode is missing in KV cache.
if (m_last_disappeared_token.has_value()) {
attention_mask = ov::genai::utils::push_front_inputs(attention_mask, 1);
Expand All @@ -239,7 +255,9 @@ EncodedResults StatefulLLMPipeline::generate(
streamer_ptr = nullptr;
} else if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&streamer)) {
streamer_ptr = *streamer_obj;
} else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
} else if (auto callback = std::get_if<std::function<ov::genai::StreamerRunningStatus(std::string)>>(&streamer)) {
streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
} else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
}

Expand All @@ -254,7 +272,8 @@ EncodedResults StatefulLLMPipeline::generate(
"(input_ids, attention_mask, position_ids, beam_idx) "
"but you have '" + std::to_string(num_inputs) + "' inputs");

ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache, m_kv_cache_seq_length_axis, m_adapter_controller);
ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache,
m_kv_history_manager.kv_cache_seq_length_axis, m_adapter_controller);

size_t kv_cache_len = 0;
ov::Tensor concatenated_attention_mask;
Expand Down Expand Up @@ -292,8 +311,7 @@ EncodedResults StatefulLLMPipeline::generate(
m_adapter_controller->apply(m_model_runner, config.adapters);
}

if (is_chat_conversation && !m_trust_encoded_history) {
m_trust_encoded_history = true;
if (is_chat_conversation) {
m_kv_history_manager.reset();
}

Expand Down Expand Up @@ -321,9 +339,11 @@ EncodedResults StatefulLLMPipeline::generate(
m_sampler.set_seed(config.rng_seed);
}

ov::genai::EncodedResults result;
std::tie(result, m_last_disappeared_token) = get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask,
ov::genai::utils::GenerationFinishInfo finish_info = get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask,
streamer_ptr, m_sampler, requests, position_ids, std::nullopt);
ov::genai::EncodedResults result = finish_info.results;
m_last_disappeared_token = finish_info.probably_disappeared_token;
m_chat_generation_finish_status = finish_info.streaming_finish_status;

if (is_chat_conversation) {
// force remove from kv_cache last answer
Expand All @@ -332,15 +352,21 @@ EncodedResults StatefulLLMPipeline::generate(
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
}

std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
if (m_chat_generation_finish_status == ov::genai::StreamerRunningStatus::CANCEL) {
m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;

if (m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
m_tokenized_chat_history.resize(m_tokenized_chat_history.size() - real_input_ids_size);
m_kv_history_manager.num_tokens_to_remove_from_kv_cache += real_input_ids_size;
}
} else {
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
}
} else {
reset_kv_state();
m_last_disappeared_token = std::nullopt;
}

if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));

auto stop_time = std::chrono::steady_clock::now();

// If is called without tokenization then that stat will not be reported.
Expand All @@ -354,7 +380,6 @@ EncodedResults StatefulLLMPipeline::generate(

void StatefulLLMPipeline::start_chat(const std::string& system_message) {
is_chat_conversation = true;
m_trust_encoded_history = true;
m_kv_history_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
Expand Down Expand Up @@ -387,7 +412,6 @@ void StatefulLLMPipeline::reset_kv_state() {

void StatefulLLMPipeline::finish_chat() {
is_chat_conversation = false;
m_trust_encoded_history = true;
m_kv_history_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
m_last_disappeared_token = std::nullopt;
Expand Down
5 changes: 3 additions & 2 deletions src/cpp/src/llm_pipeline_stateful.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
// If sequence contains some symbols, which could be ambiguously encoded by tokenizer, we need to trim kv cache
// If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history
// so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history
ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0};
size_t m_kv_cache_seq_length_axis = 2;
ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0, 2};
// Finish reason of last generation for chat scenario
ov::genai::StreamerRunningStatus m_chat_generation_finish_status = ov::genai::StreamerRunningStatus::UNDEF;

void reset_kv_state();
public:
Expand Down
Loading

0 comments on commit 488b83b

Please sign in to comment.