Skip to content

Commit 2a9af43

Browse files
authored
Stateful to Stateless Transformation for LLMs (#25150)
A transformation that undoes `make_stateful` from optimum-intel. ### How to use in Python ```python import openvino as ov from openvino._offline_transformations import stateful_to_stateless_transformation core = ov.Core() model = core.read_model('your_chatty_stateful_model_right_from_vanilla_optimum_intel.xml') stateful_to_stateless_transformation(model) # use `model` ``` ### How to use in C++ ```c++ #include <openvino/openvino.hpp> #include <openvino/pass/stateful_to_stateless.hpp> int main() { auto core = ov::Core(); auto model = core.read_model("your_chatty_stateful_model_right_from_vanilla_optimum_intel.xml"); ov::pass::StatefulToStateless().run_on_model(model); // use `model` } ``` ### TODO - [x] Restore the original order of inputs/outputs (now they are not globally ordered, but kv inputs correspond to kv outputs by indices with a proper offset). - [x] Restore the original names of inputs and outputs based on optimum-intel conventions in make_stateful.
1 parent c901a26 commit 2a9af43

File tree

7 files changed

+279
-0
lines changed

7 files changed

+279
-0
lines changed

.github/workflows/job_pytorch_models_tests.yml

+10
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,16 @@ jobs:
153153
USE_SYSTEM_CACHE: False
154154
OP_REPORT_FILE: ${{ env.INSTALL_TEST_DIR }}/TEST-torch_unsupported_ops.log
155155

156+
- name: StatefulToStateless Test
157+
if: always()
158+
run: |
159+
export PYTHONPATH=${MODEL_HUB_TESTS_INSTALL_DIR}:$PYTHONPATH
160+
python3 -m pytest ${MODEL_HUB_TESTS_INSTALL_DIR}/pytorch/test_stateful_to_stateless_transformation.py -m ${TYPE} --html=${INSTALL_TEST_DIR}/TEST-torch_stateful_to_stateless_tests.html --self-contained-html -v --tb=short
161+
env:
162+
TYPE: ${{ inputs.event == 'schedule' && 'nightly' || 'precommit'}}
163+
TEST_DEVICE: CPU
164+
USE_SYSTEM_CACHE: False
165+
156166
- name: Reformat unsupported ops file
157167
if: '!cancelled()'
158168
run: |

src/bindings/python/src/openvino/_offline_transformations/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@
1818
from openvino._pyopenvino._offline_transformations import compress_quantize_weights_transformation
1919
from openvino._pyopenvino._offline_transformations import convert_sequence_to_tensor_iterator_transformation
2020
from openvino._pyopenvino._offline_transformations import paged_attention_transformation
21+
from openvino._pyopenvino._offline_transformations import stateful_to_stateless_transformation

src/bindings/python/src/pyopenvino/core/offline_transformations.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <openvino/pass/make_stateful.hpp>
1111
#include <openvino/pass/sdpa_to_paged_attention.hpp>
1212
#include <openvino/pass/serialize.hpp>
13+
#include <openvino/pass/stateful_to_stateless.hpp>
1314
#include <pruning.hpp>
1415
#include <transformations/common_optimizations/compress_float_constants.hpp>
1516
#include <transformations/common_optimizations/fused_names_cleanup.hpp>
@@ -137,4 +138,13 @@ void regmodule_offline_transformations(py::module m) {
137138
manager.run_passes(model);
138139
},
139140
py::arg("model"));
141+
142+
m_offline_transformations.def(
143+
"stateful_to_stateless_transformation",
144+
[](std::shared_ptr<ov::Model> model) {
145+
ov::pass::Manager manager;
146+
manager.register_pass<ov::pass::StatefulToStateless>();
147+
manager.run_passes(model);
148+
},
149+
py::arg("model"));
140150
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "openvino/pass/pass.hpp"
8+
9+
namespace ov {
10+
namespace pass {
11+
/**
12+
* @brief The transformation converts KV cache state back to stateless form.
*
* Each state variable that is read through a Gather driven by the `beam_idx` parameter is replaced by a new
* model Parameter, and the value written by the matching Assign sink is exposed as a new model Result.
* The `beam_idx` parameter, the Assign sinks and the variables are removed from the model.
13+
* \ingroup ov_pass_cpp_api
14+
*/
15+
class OPENVINO_API StatefulToStateless : public ModelPass {
16+
public:
17+
OPENVINO_RTTI("StatefulToStateless");
18+
19+
/**
* @brief Applies the transformation to `model` in place.
* @param model Stateful model to convert; it must have an input with tensor name `beam_idx`, otherwise an
*              exception is thrown.
* @return true when the pass completes without throwing.
*/
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
20+
};
21+
} // namespace pass
22+
} // namespace ov
+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "openvino/pass/stateful_to_stateless.hpp"
6+
7+
#include <algorithm>
#include <limits>
#include <memory>
#include <regex>
#include <string>
#include <unordered_map>
#include <vector>
9+
10+
#include "openvino/cc/pass/itt.hpp"
11+
#include "openvino/op/assign.hpp"
12+
#include "openvino/op/gather.hpp"
13+
#include "openvino/op/read_value.hpp"
14+
#include "openvino/pass/manager.hpp"
15+
#include "transformations/utils/utils.hpp"
16+
17+
using namespace ov::op;
18+
19+
namespace {
20+
21+
// Gives `node` the friendly name `name` and makes `name` the one and only name of its output tensor; any names
// the tensor carried before are discarded. The node is required to have exactly one output (the friendly name is
// applied before that requirement is checked). Returns the same node so calls can be chained.
std::shared_ptr<ov::Node> set_name(std::shared_ptr<ov::Node> node, const std::string& name) {
    node->set_friendly_name(name);
    OPENVINO_ASSERT(node->get_output_size() == 1);
    auto& output_tensor = node->get_output_tensor(0);
    output_tensor.set_names({name});
    return node;
}
29+
30+
// Templated method that has the same effect as not templated `set_name` but saves Op type for convenient calls chaining
31+
template <typename T>
32+
inline std::shared_ptr<T> set_name(std::shared_ptr<T> node, const std::string& name) {
33+
set_name(std::dynamic_pointer_cast<ov::Node>(node), name);
34+
return node;
35+
}
36+
37+
// Finds the model Parameter whose output tensor has `name` among its tensor names.
// Returns nullptr when there is no such parameter; the nullptr result and the Parameter return type are the only
// differences from ov::Model::input(name).
std::shared_ptr<v0::Parameter> get_parameter_by_tensor_name(const std::shared_ptr<ov::Model>& model,
                                                            const std::string& name) {
    for (const auto& candidate : model->get_parameters()) {
        const auto& tensor_names = candidate->get_output_tensor(0).get_names();
        if (tensor_names.find(name) != tensor_names.end()) {
            return candidate;
        }
    }
    return nullptr;
}
45+
46+
// Describes one state variable of the stateful model together with the input/output names and the ordering key
// to use for it in the stateless model.
struct Variable {
    struct Context {
        // to hold compiled once regex for all Variable instances
        const std::regex naming_convention =
            std::regex(R"((past_key_values\.(\d+)\.(key|value))(present\.(\d+)\.(key|value)))");
    };

    // Decodes `variable_name` with the naming convention held by `context`.
    // When the name matches, the input and output names are the two halves of the variable name; otherwise
    // placeholder names derived from the variable name are used and the ordering key stays unknown.
    Variable(const Context& context, const std::string& variable_name) : variable_name(variable_name) {
        std::smatch match;
        if (!std::regex_match(variable_name, match, context.naming_convention)) {
            // The variable name doesn't follow the expected naming convention. That doesn't prevent forming
            // a correct stateless model, but the original names and the inputs/outputs ordering cannot be
            // restored accurately.
            input_name = "input_restored." + variable_name;
            output_name = "output_restored." + variable_name;
            return;
        }

        input_name = match[1].str();
        output_name = match[4].str();

        const std::string input_layer = match[2].str();
        const std::string output_layer = match[5].str();
        // Only parse layer numbers that are short enough to fit into int.
        const bool fits_in_int = input_layer.length() <= std::numeric_limits<int>::digits10;
        if (input_layer == output_layer && fits_in_int) {
            const int value_offset = match[3].str() == "value" ? 1 : 0;  // order key before value
            index = std::stoi(input_layer) * 2 + value_offset;
        }
    }

    int index = -1;             // ordering key: layer * 2 (+1 for value), -1 means the index isn't known
    std::string variable_name;  // original variable_id
    std::string input_name;     // restored name of input
    std::string output_name;    // restored name of output
};
81+
82+
typedef std::vector<Variable> Variables;
83+
84+
// Reorders `variables` in place (stable sort) to restore the KV cache order known from the optimum-intel naming
// convention:
//  - variables with a known index (index >= 0) come first, ordered by that index;
//  - variables without a known index follow, ordered by the position stored for their variable_id in
//    `var_index_by_var_id` (the lookup uses `at`, so such a variable must be present in the map).
void restore_kv_cache_order(Variables& variables, const std::unordered_map<std::string, size_t>& var_index_by_var_id) {
    const auto precedes = [&](const Variable& lhs, const Variable& rhs) {
        const bool lhs_known = lhs.index >= 0;
        const bool rhs_known = rhs.index >= 0;
        if (lhs_known != rhs_known) {
            // Exactly one side follows the naming convention: that side goes first.
            return lhs_known;
        }
        if (lhs_known) {
            return lhs.index < rhs.index;
        }
        // Neither side follows the naming convention: fall back to the externally supplied order.
        return var_index_by_var_id.at(lhs.variable_name) < var_index_by_var_id.at(rhs.variable_name);
    };
    std::stable_sort(variables.begin(), variables.end(), precedes);
}
101+
102+
} // namespace
103+
104+
// Converts a stateful model back to the stateless form, in place.
//
// For every Gather that consumes the `beam_idx` parameter and reads a state variable at its first input:
//  - the Gather output is replaced by a new Parameter with the same element type and partial shape;
//  - the value written by the Assign sink of the same variable is exposed through a new Result;
//  - the Assign sink and the variable are removed from the model.
// The `beam_idx` parameter itself is removed. New inputs/outputs are named and ordered as decoded by `Variable`
// and `restore_kv_cache_order`.
//
// Throws when the model has no `beam_idx` input, when a Gather by `beam_idx` is not fed by a ReadValue, when
// several such Gathers read the same variable, or when a variable has no Assign sink. All of these checks run
// before the first modification, so a rejected model is left untouched.
//
// Returns true whenever the pass completes.
bool ov::pass::StatefulToStateless::run_on_model(const std::shared_ptr<ov::Model>& model) {
    RUN_ON_MODEL_SCOPE(StatefulToStateless);

    auto beam_idx = get_parameter_by_tensor_name(model, "beam_idx");
    if (!beam_idx) {
        OPENVINO_THROW(
            "Stateful models without `beam_idx` input are not supported in StatefulToStateless transformation");
    }

    Variable::Context context;
    Variables variables;  // to collect variables corresponding to future_params
    variables.reserve(model->get_sinks().size());
    std::unordered_map<std::string, std::shared_ptr<ov::Node>>
        future_params;  // to collect nodes, each with a single output that will be replaced by new parameters

    for (const ov::Input<ov::Node>& input : beam_idx->get_output_target_inputs(0)) {
        auto gather = std::dynamic_pointer_cast<op::util::GatherBase>(input.get_node()->shared_from_this());
        if (!gather) {
            continue;
        }
        auto read_value = std::dynamic_pointer_cast<op::util::ReadValueBase>(gather->get_input_node_shared_ptr(0));
        OPENVINO_ASSERT(read_value,
                        "Unexpected model topology in StatefulToStateless: no ReadValue is found at the first "
                        "input of Gather by `beam_idx` parameter");
        const auto variable_name = read_value->get_variable_id();
        // A second Gather for the same variable would silently overwrite the first one and later produce
        // duplicated inputs/outputs, so reject such a topology explicitly.
        const bool is_new_variable = future_params.emplace(variable_name, gather).second;
        OPENVINO_ASSERT(is_new_variable,
                        "Unexpected model topology in StatefulToStateless: more than one Gather by `beam_idx` "
                        "parameter reads variable ",
                        variable_name);
        variables.emplace_back(context, variable_name);
    }

    using PAssign = std::shared_ptr<op::util::AssignBase>;
    std::unordered_map<std::string, PAssign> assigns_by_var_id;
    std::unordered_map<std::string, size_t> assign_index_by_var_id;
    const auto& sinks = model->get_sinks();
    for (size_t i = 0; i < sinks.size(); ++i) {
        if (auto assign = std::dynamic_pointer_cast<op::util::AssignBase>(sinks[i])) {
            const auto& var_id = assign->get_variable_id();
            assigns_by_var_id[var_id] = assign;
            assign_index_by_var_id[var_id] = i;
        }
    }

    // Every collected variable needs its Assign both for ordering and for building the Result. Check it up front
    // instead of dereferencing a null pointer (or failing with a bare std::out_of_range) later.
    for (const auto& variable : variables) {
        OPENVINO_ASSERT(assigns_by_var_id.count(variable.variable_name) != 0,
                        "Unexpected model topology in StatefulToStateless: no Assign is found for variable ",
                        variable.variable_name);
    }

    // Validation is complete; only from this point on the model is modified.
    model->remove_parameter(beam_idx);

    restore_kv_cache_order(variables, assign_index_by_var_id);

    ov::ParameterVector new_parameters;
    ov::ResultVector new_results;
    new_parameters.reserve(variables.size());
    new_results.reserve(variables.size());

    for (const auto& variable : variables) {
        const auto& future_param = future_params.at(variable.variable_name);
        auto parameter = ::set_name(std::make_shared<v0::Parameter>(future_param->get_output_element_type(0),
                                                                    future_param->get_output_partial_shape(0)),
                                    variable.input_name);

        replace_node(future_param, parameter);

        const auto& assign = assigns_by_var_id.at(variable.variable_name);
        auto result = ::set_name(std::make_shared<v0::Result>(assign->input_value(0)), variable.output_name);

        model->remove_sink(assign);  // Don't do replace_node(assign, result)! It will lead to silently incorrect model.
        model->remove_variable(model->get_variable_by_id(variable.variable_name));
        new_parameters.push_back(parameter);
        new_results.push_back(result);
    }

    model->add_parameters(new_parameters);
    model->add_results(new_results);

    return true;
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
hf-internal-testing/tiny-random-LlamaForCausalLM,https://huggingface.co/trl-internal-testing/tiny-random-LlamaForCausalLM
2+
hf-internal-testing/tiny-random-StableLmForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-StableLmForCausalLM
3+
hf-internal-testing/tiny-random-PhiForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-PhiForCausalLM
4+
hf-internal-testing/tiny-random-CodeGenForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-CodeGenForCausalLM
5+
hf-internal-testing/tiny-random-Starcoder2ForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-Starcoder2ForCausalLM
6+
hf-internal-testing/tiny-random-OPTForCausalLM,https://huggingface.co/hf-internal-testing/tiny-random-OPTForCausalLM
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (C) 2018-2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import openvino as ov
5+
from openvino._offline_transformations import stateful_to_stateless_transformation
6+
from optimum.intel import OVModelForCausalLM
7+
import models_hub_common.utils as utils
8+
import pytest
9+
import os
10+
11+
def get_read_value_ops(model: ov.Model):
    """Return the operations of `model` whose type name is 'ReadValue', in `model.get_ops()` order."""
    read_values = []
    for node in model.get_ops():
        if node.get_type_name() == 'ReadValue':
            read_values.append(node)
    return read_values
13+
14+
def check_desc_tensors(tensors1, tensors2):
    """Assert that two collections of tensor descriptors describe the same tensors.

    The order of the tensors may differ. Tensors are paired by their name sets (`names`), and each pair
    must agree on names, partial shape and element type. Raises AssertionError on any mismatch.
    """
    # order of tensors may not match, comparing by the total amount and names
    assert len(tensors1) == len(tensors2)
    # Compare name sets as frozensets: converting a set to a tuple depends on set iteration order,
    # so two equal name sets could produce different tuples and fail the check spuriously.
    assert {frozenset(t.names) for t in tensors1} == {frozenset(t.names) for t in tensors2}
    for t1 in tensors1:
        # The counterpart is the only tensor on the other side sharing at least one name with t1.
        t2_candidates = [t for t in tensors2 if t1.names & t.names]
        assert len(t2_candidates) == 1
        t2 = t2_candidates[0]
        assert t1.names == t2.names
        assert t1.get_partial_shape() == t2.get_partial_shape()
        assert t1.get_element_type() == t2.get_element_type()
25+
26+
def run_stateful_to_stateless_in_runtime(tmp_path, model_id, model_link):
    """Check the StatefulToStateless transformation on the model `model_id`.

    The model is exported in stateful form, transformed, and then compared by input/output descriptors
    with a fresh stateless export of the same model. Finally the transformed model is compiled on CPU.
    `tmp_path` and `model_link` are accepted but not used here.
    """
    model = OVModelForCausalLM.from_pretrained(model_id, export=True, stateful=True, compile=False)

    # Sanity check: the exported model really is stateful before the transformation.
    assert len(model.model.get_sinks()) > 0, \
        "Input model is not in the expected stateful form because it doesn't have any sinks."
    assert len(get_read_value_ops(model.model)) > 0, \
        "Input model is not in the expected stateful form because it doesn't have any ReadValue operations."

    stateful_to_stateless_transformation(model.model)

    # No state-related operations may remain afterwards.
    sink_ops = model.model.get_sinks()
    read_value_ops = get_read_value_ops(model.model)
    assert len(sink_ops) == 0, f"Expected stateless model, but there are sinks found: {sink_ops}"
    assert len(read_value_ops) == 0, f"Expected stateless model, but there are ReadValue operations found: {read_value_ops}"

    # Reference: the same model exported directly in stateless form.
    stateless_model = OVModelForCausalLM.from_pretrained(model_id, export=True, stateful=False, compile=False)

    print(model.model)
    print(stateless_model.model)
    for transformed, reference in (
        (model.model.inputs, stateless_model.model.inputs),
        (model.model.outputs, stateless_model.model.outputs),
    ):
        check_desc_tensors(transformed, reference)

    # The transformed model must also be compilable.
    core = ov.Core()
    core.compile_model(model.model, 'CPU')
47+
48+
49+
@pytest.mark.precommit
@pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "tiny-set-stateful-models-precommit")))
def test_stateful_to_stateless_precommit(tmp_path, model_name, model_link, mark, reason, ie_device):
    """Precommit check of StatefulToStateless for every model in the precommit list.

    `mark` selects how the case is handled: 'skip' and 'xfail' are passed to pytest with `reason`,
    None runs the check; any other value is rejected as an incorrect test case.
    """
    assert mark in (None, 'skip', 'xfail'), \
        "Incorrect test case: {}, {}".format(model_name, model_link)
    if mark == 'skip':
        pytest.skip(reason)
    if mark == 'xfail':
        pytest.xfail(reason)
    run_stateful_to_stateless_in_runtime(tmp_path, model_name, model_link)

0 commit comments

Comments
 (0)