4
4
5
5
#pragma once
6
6
7
- #include " snippets/lowered/expression.hpp"
8
-
7
+ #include " snippets/lowered/linear_ir.hpp"
8
+ #if defined(SNIPPETS_DEBUG_CAPS) && !defined(_WIN32)
9
+ #include < cxxabi.h>
10
+ #endif
9
11
namespace ov {
10
12
namespace snippets {
11
13
@@ -23,8 +25,38 @@ class KernelExecutorBase {
23
25
* while dynamic kernels will be completed only in runtime, when all the shapes are known.
24
26
*/
25
27
virtual bool is_completed () const = 0;
28
+
29
+ /* ** Return deep copy of the config */
30
+ virtual std::shared_ptr<GenericConfig> clone () const = 0;
31
+
32
+ /* ** Compute hash for fast comparison operations or caching support */
33
+ virtual size_t hash () const = 0;
34
+
35
+ bool operator ==(const GenericConfig& rhs) const { return hash () == rhs.hash (); }
36
+ bool operator !=(const GenericConfig& rhs) const { return hash () != rhs.hash (); }
37
+
26
38
virtual ~GenericConfig () = default ;
39
+ /* * serialize config for debug purposes */
40
+ #ifdef SNIPPETS_DEBUG_CAPS
41
+ virtual std::string to_string () const = 0;
42
+ #endif
27
43
};
44
+ /* *
45
+ * @brief Update current kernel config in accordance with the passed expression. Corresponding kernel is recompiled if necessary.
46
+ * This method should be called to update KernelExecutor based on runtime info (e.g. shapes) available through expression ptr
47
+ */
48
+ virtual void update_by_expression (const ov::snippets::lowered::ExpressionPtr& expr) = 0;
49
+ /* *
50
+ * @brief Replace current kernel config with the provided value. Corresponding kernel is recompiled if necessary.
51
+ * This method should be called to restore a saved state of the executor, that was configured using update_by_expression().
52
+ */
53
+ virtual void update_by_config (const std::shared_ptr<const GenericConfig>& new_config) = 0;
54
+
55
+ virtual std::shared_ptr<const GenericConfig> get_config () const = 0;
56
+ /* * serialize for debug purposes */
57
+ #ifdef SNIPPETS_DEBUG_CAPS
58
+ virtual std::string to_string () const = 0;
59
+ #endif
28
60
virtual ~KernelExecutorBase () = default ;
29
61
30
62
private:
@@ -38,17 +70,47 @@ template<typename Conf, typename KernelType,
38
70
class KernelExecutor : public snippets ::KernelExecutorBase {
39
71
public:
40
72
explicit KernelExecutor (std::shared_ptr<Conf> c) : KernelExecutorBase(), m_config{std::move (c)} {}
41
- /* *
42
- * @brief check current config and recompile kernel if necessary. Use kernel caching to avoid redundant recompilations.
43
- * This method must be called only for complete configs. It's the user responsibility to check is_completed() before calling.
44
- */
45
- virtual void update_kernel () = 0;
73
+
74
+ // Note: override when final is redundant, but needed to avoid warnings on some compilers
75
+ void update_by_expression (const ov::snippets::lowered::ExpressionPtr& expr) override final { // NOLINT
76
+ m_config = std::static_pointer_cast<Conf>(m_config->clone ());
77
+ update_config (expr, m_config);
78
+ OPENVINO_ASSERT (m_config && m_config->is_completed (), " Failed to update kernel config in update_by_expression" );
79
+ update_kernel (m_config, m_kernel);
80
+ OPENVINO_ASSERT (m_kernel, " Failed to compile kernel executor" );
81
+ }
82
+ void update_by_config (const std::shared_ptr<const GenericConfig>& new_config) override final { // NOLINT
83
+ if (*m_config == *new_config)
84
+ return ;
85
+ m_config = std::static_pointer_cast<Conf>(std::const_pointer_cast<GenericConfig>(new_config));
86
+ OPENVINO_ASSERT (m_config && m_config->is_completed (), " Failed to update kernel config in get_config" );
87
+ update_kernel (m_config, m_kernel);
88
+ OPENVINO_ASSERT (m_kernel, " Failed to compile kernel executor" );
89
+ }
90
+ std::shared_ptr<const GenericConfig> get_config () const override { return m_config; }
91
+ std::shared_ptr<const KernelType> get_kernel () const { return m_kernel; }
92
+ #ifdef SNIPPETS_DEBUG_CAPS
93
+ std::string to_string () const override {
94
+ std::string type_name = typeid (KernelType).name ();
95
+ #ifndef _WIN32
96
+ int status;
97
+ std::unique_ptr<char , void (*)(void *)> demangled_name (
98
+ abi::__cxa_demangle (type_name.c_str (), nullptr , nullptr , &status),
99
+ std::free);
100
+ type_name = demangled_name.get ();
101
+ #endif
102
+ return " KernelExecutorType: " + std::string (type_name) + " KernelConfig: " + m_config->to_string ();
103
+ }
104
+ #endif
105
+
46
106
protected:
47
- /* *
48
- * @brief Takes shared_ptr to compilation config, returns shared_ptr to compiled kernel.
49
- * Should be called only if actual compilation is required. Kernel caching must be implemented in update_kernel().
50
- */
51
- virtual std::shared_ptr<KernelType> compile_kernel (const std::shared_ptr<Conf>& c) const = 0;
107
+ /* ** Updates stored kernel config based on runtime info from expression (e.g. new input shapes). */
108
+ virtual void update_config (const ov::snippets::lowered::ExpressionPtr& expr, std::shared_ptr<Conf>& config) const = 0;
109
+ /* ** Updates stored kernel in accordance with the passed config. Recompilation of the kernel is
110
+ * performed only if necessary, otherwise an appropriate kernel is retrieved from cache. */
111
+ virtual void update_kernel (const std::shared_ptr<const Conf>& c, std::shared_ptr<KernelType>& kernel) const = 0;
112
+
113
+ private:
52
114
/* * Contains all the necessary information to compile a desired kernel*/
53
115
std::shared_ptr<Conf> m_config = nullptr ;
54
116
/* * Stores pointer to compiled kernel since the last update_kernel() call */
@@ -57,6 +119,7 @@ class KernelExecutor : public snippets::KernelExecutorBase {
57
119
58
120
class KernelExecutorTable {
59
121
public:
122
+ /* ** Register KernelExecutor in the KernelExecutorTable so it can be later updated in runtime. */
60
123
template <typename T, class ...C,
61
124
typename std::enable_if<std::is_base_of<KernelExecutorBase, T>::value, bool >::type = true >
62
125
std::shared_ptr<T> register_kernel (const snippets::lowered::ExpressionPtr& expr, C... args) {
@@ -69,10 +132,37 @@ class KernelExecutorTable {
69
132
OPENVINO_ASSERT (m_table.count (expr), " This expression doesn't have a registered kernel executor" );
70
133
return m_table.at (expr);
71
134
}
135
+ /* ** Updates every registered KernelExecutor in accordance with the corresponding expression */
136
+ void update_state () const {
137
+ for (const auto & record : m_table)
138
+ record.second ->update_by_expression (record.first );
139
+ }
140
+
141
+ /* ** Returns lambda function that contains current state of the table, and restores this state when called */
142
+ std::function<void ()> get_state_reset () {
143
+ auto current_state = get_state ();
144
+ return [=]() { reset_state (current_state); };
145
+ }
146
+
147
+ /* *
148
+ * @brief Replace originally registered ExpressionPtr with a new value.
149
+ * Note that code emission is performed on a copy of LIR, so all expression pointers visible from emitters won't
150
+ * be accessible from RuntimeConfigurator. In order to replace these cloned ExpressionPtrs with the original ones,
151
+ * we need to call this method.
152
+ */
153
+ void replace_key_expression (const snippets::lowered::ExpressionPtr& from, const snippets::lowered::ExpressionPtr& to);
154
+
72
155
virtual ~KernelExecutorTable () = default ;
73
156
74
157
protected:
75
158
std::unordered_map<snippets::lowered::ExpressionPtr, std::shared_ptr<KernelExecutorBase>> m_table{};
159
+ typedef std::vector<std::pair<snippets::lowered::ExpressionPtr, std::shared_ptr<const KernelExecutorBase::GenericConfig>>> ExecTableState;
160
+
161
+ /* ** Restore the table state previously obtained by get_state() */
162
+ void reset_state (const ExecTableState& state);
163
+
164
+ /* ** Return cumulative state of all the executors in the table. The returned ExecTableState object can be passed to reset_state */
165
+ ExecTableState get_state () const ;
76
166
};
77
167
78
168
using KernelExecutorTablePtr = std::shared_ptr<KernelExecutorTable>;
0 commit comments