|
| 1 | +// Copyright (C) 2018-2024 Intel Corporation |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | +// |
| 4 | + |
| 5 | +#include "openvino/pass/stateful_to_stateless.hpp" |
| 6 | + |
| 7 | +#include <regex> |
| 8 | +#include <string> |
| 9 | + |
| 10 | +#include "openvino/cc/pass/itt.hpp" |
| 11 | +#include "openvino/op/assign.hpp" |
| 12 | +#include "openvino/op/gather.hpp" |
| 13 | +#include "openvino/op/read_value.hpp" |
| 14 | +#include "openvino/pass/manager.hpp" |
| 15 | +#include "transformations/utils/utils.hpp" |
| 16 | + |
| 17 | +using namespace ov::op; |
| 18 | + |
| 19 | +namespace { |
| 20 | + |
| 21 | +std::shared_ptr<ov::Node> set_name(std::shared_ptr<ov::Node> node, const std::string& name) { |
| 22 | + // Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a |
| 23 | + // given single name) |
| 24 | + node->set_friendly_name(name); |
| 25 | + OPENVINO_ASSERT(node->get_output_size() == 1); |
| 26 | + node->get_output_tensor(0).set_names({name}); |
| 27 | + return node; |
| 28 | +} |
| 29 | + |
| 30 | +// Templated method that has the same effect as not templated `set_name` but saves Op type for convenient calls chaining |
| 31 | +template <typename T> |
| 32 | +inline std::shared_ptr<T> set_name(std::shared_ptr<T> node, const std::string& name) { |
| 33 | + set_name(std::dynamic_pointer_cast<ov::Node>(node), name); |
| 34 | + return node; |
| 35 | +} |
| 36 | + |
| 37 | +std::shared_ptr<v0::Parameter> get_parameter_by_tensor_name(const std::shared_ptr<ov::Model>& model, |
| 38 | + const std::string& name) { |
| 39 | + for (const auto& param : model->get_parameters()) { |
| 40 | + if (param->get_output_tensor(0).get_names().count(name)) |
| 41 | + return param; |
| 42 | + } |
| 43 | + return nullptr; // nullptr and return type are only difference from ov::Model::input(name) |
| 44 | +} |
| 45 | + |
| 46 | +struct Variable { |
| 47 | + struct Context { |
| 48 | + // to hold compiled once regex for all Variable instances |
| 49 | + const std::regex naming_convention = |
| 50 | + std::regex(R"((past_key_values\.(\d+)\.(key|value))(present\.(\d+)\.(key|value)))"); |
| 51 | + }; |
| 52 | + |
| 53 | + Variable(const Context& context, const std::string& variable_name) : variable_name(variable_name) { |
| 54 | + // Try to decode original naming of the corresponding input and output in the stateless model |
| 55 | + std::smatch match; |
| 56 | + if (std::regex_match(variable_name, match, context.naming_convention)) { |
| 57 | + input_name = match[1].str(); |
| 58 | + output_name = match[4].str(); |
| 59 | + auto input_index = match[2].str(); |
| 60 | + auto output_index = match[5].str(); |
| 61 | + if (input_index == output_index && input_index.length() <= std::numeric_limits<int>::digits10) { |
| 62 | + index = std::stoi(input_index) * 2 + int(match[3].str() == "value"); // order key before value |
| 63 | + } else { |
| 64 | + index = -1; |
| 65 | + } |
| 66 | + } else { |
| 67 | + // Variable name doesn't follow the expected naming convention. It doens't prevent forming |
| 68 | + // a correct stateless model but doesn't give a way to restore all names and inputs/outputs ordering |
| 69 | + // accurately. |
| 70 | + input_name = "input_restored." + variable_name; |
| 71 | + output_name = "output_restored." + variable_name; |
| 72 | + index = -1; |
| 73 | + } |
| 74 | + } |
| 75 | + |
| 76 | + int index; // layer index, -1 means the index isn't known |
| 77 | + std::string variable_name; // original variable_id |
| 78 | + std::string input_name; // restored name of input |
| 79 | + std::string output_name; // restored name of output |
| 80 | +}; |
| 81 | + |
| 82 | +typedef std::vector<Variable> Variables; |
| 83 | + |
| 84 | +void restore_kv_cache_order(Variables& variables, const std::unordered_map<std::string, size_t>& var_index_by_var_id) { |
| 85 | + // Try to restore variable order based on the known naming convention from optimum-intel. |
| 86 | + // If names are not satisfy the expected convention, fallback to use order based on var_index_by_var_id |
| 87 | + // Sort items that do satisfy the naming conventions before items that don't satisfy. |
| 88 | + |
| 89 | + std::stable_sort(variables.begin(), variables.end(), [&](const Variable& a, const Variable& b) { |
| 90 | + if (a.index >= 0 && b.index >= 0) { |
| 91 | + return a.index < b.index; |
| 92 | + } else if (a.index >= 0 && b.index < 0) { |
| 93 | + return true; |
| 94 | + } else if (a.index < 0 && b.index >= 0) { |
| 95 | + return false; |
| 96 | + } else { // a.index < 0 && b.index < 0 |
| 97 | + return var_index_by_var_id.at(a.variable_name) < var_index_by_var_id.at(b.variable_name); |
| 98 | + } |
| 99 | + }); |
| 100 | +} |
| 101 | + |
| 102 | +} // namespace |
| 103 | + |
| 104 | +bool ov::pass::StatefulToStateless::run_on_model(const std::shared_ptr<ov::Model>& model) { |
| 105 | + RUN_ON_MODEL_SCOPE(StatefulToStateless); |
| 106 | + |
| 107 | + auto beam_idx = get_parameter_by_tensor_name(model, "beam_idx"); |
| 108 | + Variables variables; // to collect variables corresponding to future_params |
| 109 | + variables.reserve(model->get_sinks().size()); |
| 110 | + Variable::Context context; |
| 111 | + std::unordered_map<std::string, std::shared_ptr<ov::Node>> |
| 112 | + future_params; // to collect nodes, each with a single output that will be replaced by new parameters |
| 113 | + if (beam_idx) { |
| 114 | + for (const ov::Input<ov::Node>& input : beam_idx->get_output_target_inputs(0)) { |
| 115 | + if (auto gather = std::dynamic_pointer_cast<op::util::GatherBase>(input.get_node()->shared_from_this())) { |
| 116 | + auto read_value = |
| 117 | + std::dynamic_pointer_cast<op::util::ReadValueBase>(gather->get_input_node_shared_ptr(0)); |
| 118 | + OPENVINO_ASSERT(read_value, |
| 119 | + "Unexpected model topology in StatefulToStateless: no ReadValue is found at the first " |
| 120 | + "input of Gather by `beam_idx` parameter"); |
| 121 | + auto variable_name = read_value->get_variable_id(); |
| 122 | + variables.push_back(Variable(context, variable_name)); |
| 123 | + future_params[variable_name] = gather; |
| 124 | + } |
| 125 | + } |
| 126 | + } else { |
| 127 | + OPENVINO_THROW( |
| 128 | + "Stateful models without `beam_idx` input are not supported in StatefulToStateless transformation"); |
| 129 | + } |
| 130 | + model->remove_parameter(beam_idx); |
| 131 | + |
| 132 | + typedef std::shared_ptr<op::util::AssignBase> PAssign; |
| 133 | + std::unordered_map<std::string, PAssign> assigns_by_var_id; |
| 134 | + std::unordered_map<std::string, size_t> assign_index_by_var_id; |
| 135 | + const auto& sinks = model->get_sinks(); |
| 136 | + for (size_t i = 0; i < sinks.size(); ++i) { |
| 137 | + if (auto assign = std::dynamic_pointer_cast<op::util::AssignBase>(sinks[i])) { |
| 138 | + const auto& var_id = assign->get_variable_id(); |
| 139 | + assigns_by_var_id[var_id] = assign; |
| 140 | + assign_index_by_var_id[var_id] = i; |
| 141 | + } |
| 142 | + } |
| 143 | + |
| 144 | + restore_kv_cache_order(variables, assign_index_by_var_id); |
| 145 | + |
| 146 | + ov::ParameterVector new_parameters; |
| 147 | + ov::ResultVector new_results; |
| 148 | + new_parameters.reserve(variables.size()); |
| 149 | + new_results.reserve(variables.size()); |
| 150 | + |
| 151 | + for (const auto& variable_id : variables) { |
| 152 | + auto future_param = future_params[variable_id.variable_name]; |
| 153 | + auto parameter = ::set_name(std::make_shared<v0::Parameter>(future_param->get_output_element_type(0), |
| 154 | + future_param->get_output_partial_shape(0)), |
| 155 | + variable_id.input_name); |
| 156 | + |
| 157 | + replace_node(future_param, parameter); |
| 158 | + |
| 159 | + auto assign = assigns_by_var_id[variable_id.variable_name]; |
| 160 | + auto result = ::set_name(std::make_shared<v0::Result>(assign->input_value(0)), variable_id.output_name); |
| 161 | + |
| 162 | + model->remove_sink(assign); // Don't do replace_node(assign, result)! It will lead to silently incorrect model. |
| 163 | + model->remove_variable(model->get_variable_by_id(variable_id.variable_name)); |
| 164 | + new_parameters.push_back(parameter); |
| 165 | + new_results.push_back(result); |
| 166 | + } |
| 167 | + |
| 168 | + model->add_parameters(new_parameters); |
| 169 | + model->add_results(new_results); |
| 170 | + |
| 171 | + return true; |
| 172 | +} |
0 commit comments