Skip to content

Commit b660da8

Browse files
Integrate recompilation infrastructure into RuntimeConfigurator (openvinotoolkit#24955)
### Details: - *Integrate dynamic executors recompilation infrastructure into RuntimeConfigurator* - *Allow RuntimeConfigurator to recompile dynamic kernel executors in runtime* - *Employ this approach to enable dynamic MatMul tests (fp32)* ### Tickets: - *143257*
1 parent 080f22e commit b660da8

31 files changed

+517
-320
lines changed

src/common/snippets/include/snippets/kernel_executor_table.hpp

+102-12
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44

55
#pragma once
66

7-
#include "snippets/lowered/expression.hpp"
8-
7+
#include "snippets/lowered/linear_ir.hpp"
8+
#if defined(SNIPPETS_DEBUG_CAPS) && !defined(_WIN32)
9+
#include <cxxabi.h>
10+
#endif
911
namespace ov {
1012
namespace snippets {
1113

@@ -23,8 +25,38 @@ class KernelExecutorBase {
2325
* while dynamic kernels will be completed only in runtime, when all the shapes are known.
2426
*/
2527
virtual bool is_completed() const = 0;
28+
29+
/*** Return deep copy of the config */
30+
virtual std::shared_ptr<GenericConfig> clone() const = 0;
31+
32+
/*** Compute hash for fast comparison operations or caching support */
33+
virtual size_t hash() const = 0;
34+
35+
// Note: equality is defined purely through hash(), so two configs whose hashes collide compare equal.
bool operator==(const GenericConfig& rhs) const { return hash() == rhs.hash(); }
36+
// Note: inequality is defined purely through hash(): configs are reported as different only if their hashes differ.
bool operator!=(const GenericConfig& rhs) const { return hash() != rhs.hash(); }
37+
2638
virtual ~GenericConfig() = default;
39+
/** serialize config for debug purposes */
40+
#ifdef SNIPPETS_DEBUG_CAPS
41+
virtual std::string to_string() const = 0;
42+
#endif
2743
};
44+
/**
45+
* @brief Update current kernel config in accordance with the passed expression. Corresponding kernel is recompiled if necessary.
46+
* This method should be called to update KernelExecutor based on runtime info (e.g. shapes) available through expression ptr
47+
*/
48+
virtual void update_by_expression(const ov::snippets::lowered::ExpressionPtr& expr) = 0;
49+
/**
50+
* @brief Replace current kernel config with the provided value. Corresponding kernel is recompiled if necessary.
51+
* This method should be called to restore a saved state of the executor, that was configured using update_by_expression().
52+
*/
53+
virtual void update_by_config(const std::shared_ptr<const GenericConfig>& new_config) = 0;
54+
55+
virtual std::shared_ptr<const GenericConfig> get_config() const = 0;
56+
/** serialize for debug purposes */
57+
#ifdef SNIPPETS_DEBUG_CAPS
58+
virtual std::string to_string() const = 0;
59+
#endif
2860
virtual ~KernelExecutorBase() = default;
2961

3062
private:
@@ -38,17 +70,47 @@ template<typename Conf, typename KernelType,
3870
class KernelExecutor : public snippets::KernelExecutorBase {
3971
public:
4072
explicit KernelExecutor(std::shared_ptr<Conf> c) : KernelExecutorBase(), m_config{std::move(c)} {}
41-
/**
42-
* @brief check current config and recompile kernel if necessary. Use kernel caching to avoid redundant recompilations.
43-
* This method must be called only for complete configs. It's the user responsibility to check is_completed() before calling.
44-
*/
45-
virtual void update_kernel() = 0;
73+
74+
// Note: override when final is redundant, but needed to avoid warnings on some compilers
75+
void update_by_expression(const ov::snippets::lowered::ExpressionPtr& expr) override final { // NOLINT
76+
m_config = std::static_pointer_cast<Conf>(m_config->clone());
77+
update_config(expr, m_config);
78+
OPENVINO_ASSERT(m_config && m_config->is_completed(), "Failed to update kernel config in update_by_expression");
79+
update_kernel(m_config, m_kernel);
80+
OPENVINO_ASSERT(m_kernel, "Failed to compile kernel executor");
81+
}
82+
/**
 * @brief Replace the stored config with the provided one and refresh the kernel via update_kernel().
 * Nothing is done if the provided config compares equal (by hash) to the current one.
 * @param new_config config to install; must be non-null and completed
 */
void update_by_config(const std::shared_ptr<const GenericConfig>& new_config) override final { // NOLINT
    OPENVINO_ASSERT(new_config, "Null kernel config passed to update_by_config");
    // Skip the update only when there is a current config to compare with
    if (m_config && *m_config == *new_config)
        return;
    // The executor stores a non-const pointer, so constness has to be cast away here.
    // update_by_expression() clones the config before modifying it.
    m_config = std::static_pointer_cast<Conf>(std::const_pointer_cast<GenericConfig>(new_config));
    OPENVINO_ASSERT(m_config && m_config->is_completed(), "Failed to update kernel config in update_by_config");
    update_kernel(m_config, m_kernel);
    OPENVINO_ASSERT(m_kernel, "Failed to compile kernel executor");
}
90+
/*** Returns the stored m_config as a shared pointer to const GenericConfig */
std::shared_ptr<const GenericConfig> get_config() const override { return m_config; }
91+
/*** Returns the stored m_kernel as a shared pointer to const KernelType */
std::shared_ptr<const KernelType> get_kernel() const { return m_kernel; }
92+
#ifdef SNIPPETS_DEBUG_CAPS
    /**
     * @brief Serialize the executor for debug purposes.
     * @return string made of the kernel type name (demangled where possible) and the serialized config
     */
    std::string to_string() const override {
        std::string type_name = typeid(KernelType).name();
#ifndef _WIN32
        int status = 0;
        std::unique_ptr<char, void (*)(void*)> demangled_name(
            abi::__cxa_demangle(type_name.c_str(), nullptr, nullptr, &status),
            std::free);
        // __cxa_demangle returns nullptr on failure; constructing std::string from a null pointer is
        // undefined behavior, so keep the mangled name in that case
        if (status == 0 && demangled_name)
            type_name = demangled_name.get();
#endif
        return "KernelExecutorType: " + type_name + " KernelConfig: " + m_config->to_string();
    }
#endif
105+
46106
protected:
47-
/**
48-
* @brief Takes shared_ptr to compilation config, returns shared_ptr to compiled kernel.
49-
* Should be called only if actual compilation is required. Kernel caching must be implemented in update_kernel().
50-
*/
51-
virtual std::shared_ptr<KernelType> compile_kernel(const std::shared_ptr<Conf>& c) const = 0;
107+
/*** Updates stored kernel config based on runtime info from expression (e.g. new input shapes). */
108+
virtual void update_config(const ov::snippets::lowered::ExpressionPtr& expr, std::shared_ptr<Conf>& config) const = 0;
109+
/*** Updates stored kernel in accordance with the passed config. Recompilation of the kernel is
110+
* performed only if necessary, otherwise an appropriate kernel is retrieved from cache. */
111+
virtual void update_kernel(const std::shared_ptr<const Conf>& c, std::shared_ptr<KernelType>& kernel) const = 0;
112+
113+
private:
52114
/** Contains all the necessary information to compile a desired kernel*/
53115
std::shared_ptr<Conf> m_config = nullptr;
54116
/** Stores pointer to compiled kernel since the last update_kernel() call */
@@ -57,6 +119,7 @@ class KernelExecutor : public snippets::KernelExecutorBase {
57119

58120
class KernelExecutorTable {
59121
public:
122+
/*** Register KernelExecutor in the KernelExecutorTable so it can be later updated in runtime. */
60123
template<typename T, class ...C,
61124
typename std::enable_if<std::is_base_of<KernelExecutorBase, T>::value, bool>::type = true>
62125
std::shared_ptr<T> register_kernel(const snippets::lowered::ExpressionPtr& expr, C... args) {
@@ -69,10 +132,37 @@ class KernelExecutorTable {
69132
OPENVINO_ASSERT(m_table.count(expr), "This expression doesn't have a registered kernel executor");
70133
return m_table.at(expr);
71134
}
135+
/*** Updates every registered KernelExecutor in accordance with the corresponding expression */
136+
void update_state() const {
137+
for (const auto& record : m_table)
138+
record.second->update_by_expression(record.first);
139+
}
140+
141+
/*** Returns lambda function that contains current state of the table, and restores this state when called */
142+
std::function<void()> get_state_reset() {
143+
auto current_state = get_state();
144+
return [=]() { reset_state(current_state); };
145+
}
146+
147+
/**
148+
* @brief Replace originally registered ExpressionPtr with a new value.
149+
* Note that code emission is performed on a copy of LIR, so all expression pointers visible from emitters won't
150+
* be accessible from RuntimeConfigurator. In order to replace these cloned ExpressionPtrs with the original ones,
151+
* we need to call this method.
152+
*/
153+
void replace_key_expression(const snippets::lowered::ExpressionPtr& from, const snippets::lowered::ExpressionPtr& to);
154+
72155
virtual ~KernelExecutorTable() = default;
73156

74157
protected:
75158
std::unordered_map<snippets::lowered::ExpressionPtr, std::shared_ptr<KernelExecutorBase>> m_table{};
159+
typedef std::vector<std::pair<snippets::lowered::ExpressionPtr, std::shared_ptr<const KernelExecutorBase::GenericConfig>>> ExecTableState;
160+
161+
/*** Restore the table state previously obtained by get_state() */
162+
void reset_state(const ExecTableState& state);
163+
164+
/*** Return cumulative state of all the executors in the table. The returned ExecTableState object can be passed to reset_state */
165+
ExecTableState get_state() const;
76166
};
77167

78168
using KernelExecutorTablePtr = std::shared_ptr<KernelExecutorTable>;

src/common/snippets/include/snippets/lowered/linear_ir_builder.hpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,14 @@ class LinearIRBuilder {
2929
/**
3030
* @brief Make a full copy of LinearIR by rules described in `m_config`
3131
* @param linear_ir Linear IR
32+
* @param expression_map expression map
3233
* @return clone of `linear_ir`
3334
*/
34-
std::shared_ptr<LinearIR> clone(const std::shared_ptr<LinearIR>& linear_ir) const;
35+
std::shared_ptr<LinearIR> clone(const std::shared_ptr<LinearIR>& linear_ir, ExpressionMap& expression_map) const;
36+
inline std::shared_ptr<LinearIR> clone(const std::shared_ptr<LinearIR>& linear_ir) const {
37+
ExpressionMap expression_map;
38+
return clone(linear_ir, expression_map);
39+
}
3540
/**
3641
* @brief Make a copy of LinearIR range by rules described in `m_config`
3742
* @param begin begin iterator of the target range of LinearIR

src/common/snippets/include/snippets/op/brgemm.hpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,16 @@ class Brgemm : virtual public modifier::MemoryAccess, public ov::op::Op {
5555
protected:
5656
ov::element::Type get_output_type() const;
5757
std::vector<ov::PartialShape> get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const;
58-
ov::PartialShape get_output_partial_shape(const std::vector<ov::PartialShape>& input_shapes) const;
58+
ov::PartialShape infer_output_partial_shape(const std::vector<ov::PartialShape>& input_shapes) const;
5959
ov::PartialShape get_planar_output_shape(const ov::PartialShape& output_shape) const;
60-
void compute_block_size_values(size_t blk_size_m, size_t blk_size_k, size_t blk_size_n);
60+
void set_block_size_values(size_t blk_size_m, size_t blk_size_k, size_t blk_size_n);
6161
size_t m_M_blk = 0;
6262
size_t m_K_blk = 0;
6363
size_t m_N_blk = 0;
6464
float m_beta = 0.f;
6565

6666
private:
6767
void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c);
68-
void validate_inputs() const;
6968
};
7069

7170
} // namespace op

src/common/snippets/include/snippets/runtime_configurator.hpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#pragma once
66

77
#include "snippets/lowered/linear_ir.hpp"
8+
#include "snippets/kernel_executor_table.hpp"
89
#include "snippets/lowered/pass/pass.hpp"
910

1011
namespace ov {
@@ -42,7 +43,8 @@ class RuntimeConfig {
4243
ov::snippets::VectorDims master_shape = {};
4344

4445
size_t buffer_scratchpad_size = 0;
45-
std::vector<size_t> buffer_cluster_offsets;
46+
std::vector<size_t> buffer_cluster_offsets {};
47+
KernelExecutorTablePtr kernel_executor_table = std::make_shared<ov::snippets::KernelExecutorTable>();
4648
};
4749

4850
/**
@@ -60,6 +62,8 @@ class RuntimeConfigurator {
6062
* @return updated config
6163
*/
6264
const std::shared_ptr<RuntimeConfig>& get_updated_config(const std::shared_ptr<lowered::LinearIR>& linear_ir);
65+
/*** Returns pointer to KernelExecutorTable owned by the config */
66+
// Note: m_config is dereferenced without a null check; the returned reference points to the shared_ptr member of the config
const std::shared_ptr<KernelExecutorTable>& get_kernel_executor_table() const { return m_config->kernel_executor_table; }
6367

6468
protected:
6569
/**

src/common/snippets/include/snippets/target_machine.hpp

-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
#include "emitter.hpp"
1212
#include "snippets/lowered/expression.hpp"
13-
#include "kernel_executor_table.hpp"
1413

1514
namespace ov {
1615
namespace snippets {
@@ -94,7 +93,6 @@ class TargetMachine {
9493

9594
protected:
9695
std::map<const ov::DiscreteTypeInfo, jitters_value> jitters;
97-
std::shared_ptr<KernelExecutorTable> kernel_executor_table;
9896
std::shared_ptr<RuntimeConfigurator> configurator;
9997
};
10098

src/common/snippets/include/snippets/utils.hpp

+38
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,44 @@ std::shared_ptr<ov::Node> get_leaf_node_of_first_child_shape_infer_seq(const std
243243
*/
244244
std::shared_ptr<ov::Node> get_leaf_node_of_first_parent_shape_infer_seq(const std::shared_ptr<ov::Node>& start_node);
245245

246+
/**
247+
* @brief Calculate leading dimension of the shape that should be read according to the layout
248+
* @param shape original (not reordered) input shape
249+
* @param layout specifies the order in which the dimensions of the input shape should be read
250+
* @return stride of the dimension idx = layout[layout.size() - 2] in the original shape
251+
Example:
252+
Original shape (shape) = [1, 49, 2, 23]
253+
Layout (transpose order) = [2, 0, 1, 3]
254+
255+
dim_idx = layout.size() - 2 = 2
256+
// Since layout specifies the order of dimensions in which the shape should be read
257+
dim = layout[dim_idx] = 1
258+
stride(shape[1]) = shape[2] * shape[3] = 2 * 23
259+
*/
260+
size_t get_in_leading_dim(const VectorDims& shape, const std::vector<size_t>& layout);
261+
// Overload that reads the shape and the layout from the port descriptor and forwards them to the overload above
inline size_t get_in_leading_dim(const lowered::PortDescriptorPtr& pd) {
    const auto& port_shape = pd->get_shape();
    const auto& port_layout = pd->get_layout();
    return get_in_leading_dim(port_shape, port_layout);
}
264+
/**
265+
*
266+
* @param shape reordered input shape that is stored according to the layout
267+
* @param layout specifies the order in which the dimensions of the input shape are stored
268+
* @return stride of the dimension at position i in shape, where layout[i] == layout.size() - 2
269+
The output shape is already transposed, so the data must be written correctly with respect to the original shape, following the layout order
270+
Example:
271+
Original transposed shape (shape) = [49, 2, 7, 39]
272+
Layout (transpose order) = [2, 0, 1, 3]
273+
274+
dim_idx = layout.size() - 2 = 2
275+
// Since the shape dimensions are already reordered according to the layout
276+
dim = /find dim_idx index in layout/ = 0
277+
stride(shape[0]) = shape[1] * shape[2] * shape[3] = 2 * 7 * 39
278+
*/
279+
size_t get_out_leading_dim(const VectorDims& shape, const std::vector<size_t>& layout);
280+
// Overload that reads the shape and the layout from the port descriptor and forwards them to the overload above
inline size_t get_out_leading_dim(const lowered::PortDescriptorPtr& pd) {
    const auto& port_shape = pd->get_shape();
    const auto& port_layout = pd->get_layout();
    return get_out_leading_dim(port_shape, port_layout);
}
283+
246284
} // namespace utils
247285
} // namespace snippets
248286
} // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "snippets/kernel_executor_table.hpp"
6+
7+
namespace ov {
8+
namespace snippets {
9+
10+
/**
 * Re-registers the executor stored under `from` so that it is stored under `to` instead.
 * Does nothing if `from` is not registered; fails if `to` is already present in the table.
 */
void KernelExecutorTable::replace_key_expression(const snippets::lowered::ExpressionPtr& from, const snippets::lowered::ExpressionPtr& to) {
    const auto found = m_table.find(from);
    if (found == m_table.end())
        return;
    OPENVINO_ASSERT(m_table.count(to) == 0, "Attempt to replace a value that is already in the KernelExecutorTable");
    // Keep the executor alive and erase the old record BEFORE inserting the new one:
    // insertion into an unordered_map may trigger rehashing, which invalidates iterators (including `found`)
    const auto executor = found->second;
    m_table.erase(found);
    m_table.insert({to, executor});
}
18+
19+
void KernelExecutorTable::reset_state(const ExecTableState& state) {
20+
OPENVINO_ASSERT(state.size() == m_table.size(), "Invalid state in restore_state: size mismatch");
21+
auto state_it = state.begin();
22+
for (const auto& table_record : m_table) {
23+
const auto& state_record = *state_it++;
24+
OPENVINO_ASSERT(table_record.first == state_record.first, "Invalid state in restore_state: expressions mismatch");
25+
table_record.second->update_by_config(state_record.second);
26+
}
27+
}
28+
29+
/**
 * Collects (expression, config copy) pairs for all the registered executors in the table iteration order.
 */
KernelExecutorTable::ExecTableState KernelExecutorTable::get_state() const {
    ExecTableState result;
    // The final size is known in advance, so allocate once
    result.reserve(m_table.size());
    // Note: we need to clone configs when saving the state, since the configs still stored in the table can
    // be modified e.g. by calling update_by_expression();
    for (const auto& record : m_table)
        result.emplace_back(record.first, record.second->get_config()->clone());
    return result;
}
37+
38+
}// namespace snippets
39+
}// namespace ov

src/common/snippets/src/lowered/linear_ir_builder.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,10 @@ std::vector<std::shared_ptr<ov::Node>> clone_nodes(const std::vector<std::shared
6565
}
6666
} // namespace
6767

68-
std::shared_ptr<LinearIR> LinearIRBuilder::clone(const std::shared_ptr<LinearIR>& linear_ir) const {
68+
std::shared_ptr<LinearIR> LinearIRBuilder::clone(const std::shared_ptr<LinearIR>& linear_ir, ExpressionMap& expression_map) const {
6969
auto cloned = std::make_shared<LinearIR>();
7070
cloned->m_config = linear_ir->m_config;
7171

72-
ExpressionMap expression_map;
7372
cloned->m_expressions = clone_range(linear_ir->m_expressions.cbegin(), linear_ir->m_expressions.cend(), expression_map);
7473
for (const auto& expr : cloned->m_expressions) {
7574
cloned->register_expression(expr, true);

0 commit comments

Comments
 (0)