5
5
#pragma once
6
6
7
7
#include " snippets/lowered/linear_ir.hpp"
8
+ #include < typeinfo>
8
9
#if defined(SNIPPETS_DEBUG_CAPS) && !defined(_WIN32)
9
10
#include < cxxabi.h>
10
11
#endif
@@ -27,14 +28,11 @@ class KernelExecutorBase {
27
28
virtual bool is_completed () const = 0;
28
29
29
30
/* ** Return deep copy of the config */
30
- virtual std::shared_ptr <GenericConfig> clone () const = 0;
31
+ virtual std::unique_ptr <GenericConfig> get_clone_ptr () const = 0;
31
32
32
33
/* ** Compute hash for fast comparison operations or caching support */
33
34
virtual size_t hash () const = 0;
34
35
35
- bool operator ==(const GenericConfig& rhs) const { return hash () == rhs.hash (); }
36
- bool operator !=(const GenericConfig& rhs) const { return hash () != rhs.hash (); }
37
-
38
36
virtual ~GenericConfig () = default ;
39
37
/* * serialize config for debug purposes */
40
38
#ifdef SNIPPETS_DEBUG_CAPS
@@ -45,14 +43,14 @@ class KernelExecutorBase {
45
43
* @brief Update current kernel config in accordance with the passed expression. Corresponding kernel is recompiled if necessary.
46
44
* This method should be called to update KernelExecutor based on runtime info (e.g. shapes) available through expression ptr
47
45
*/
48
- virtual void update_by_expression (const ov::snippets:: lowered::ExpressionPtr& expr) = 0;
46
+ virtual void update_by_expression (const lowered::ExpressionPtr& expr) = 0;
49
47
/* *
50
48
* @brief Replace current kernel config with the provided value. Corresponding kernel is recompiled if necessary.
51
49
* This method should be called to restore a saved state of the executor, that was configured using update_by_expression().
52
50
*/
53
- virtual void update_by_config (const std::shared_ptr< const GenericConfig> & new_config) = 0;
51
+ virtual void update_by_config (const GenericConfig& new_config) = 0;
54
52
55
- virtual std::shared_ptr< const GenericConfig> get_config () const = 0;
53
+ virtual const GenericConfig& get_config () const = 0;
56
54
/* * serialize for debug purposes */
57
55
#ifdef SNIPPETS_DEBUG_CAPS
58
56
virtual std::string to_string () const = 0;
@@ -67,27 +65,27 @@ class KernelExecutorBase {
67
65
68
66
template <typename Conf, typename KernelType,
69
67
typename std::enable_if<std::is_base_of<KernelExecutorBase::GenericConfig, Conf>::value, bool >::type = true >
70
- class KernelExecutor : public snippets :: KernelExecutorBase {
68
+ class KernelExecutor : public KernelExecutorBase {
71
69
public:
72
- explicit KernelExecutor (std::shared_ptr< Conf> c) : KernelExecutorBase(), m_config{std::move (c)} {}
70
+ explicit KernelExecutor (Conf c) : KernelExecutorBase(), m_config{std::move (c)} {}
73
71
74
72
// Note: override when final is redundant, but needed to avoid warnings on some compilers
75
- void update_by_expression (const ov::snippets::lowered::ExpressionPtr& expr) override final { // NOLINT
76
- m_config = std::static_pointer_cast<Conf>(m_config->clone ());
73
+ void update_by_expression (const lowered::ExpressionPtr& expr) override final { // NOLINT
77
74
update_config (expr, m_config);
78
- OPENVINO_ASSERT (m_config && m_config-> is_completed (), " Failed to update kernel config in update_by_expression" );
75
+ OPENVINO_ASSERT (m_config. is_completed (), " Failed to update kernel config in update_by_expression" );
79
76
update_kernel (m_config, m_kernel);
80
77
OPENVINO_ASSERT (m_kernel, " Failed to compile kernel executor" );
81
78
}
82
- void update_by_config (const std::shared_ptr< const GenericConfig> & new_config) override final { // NOLINT
83
- if (* m_config == * new_config)
79
+ void update_by_config (const GenericConfig& new_config) override final { // NOLINT
80
+ if (m_config. hash () == new_config. hash () )
84
81
return ;
85
- m_config = std::static_pointer_cast<Conf>(std::const_pointer_cast<GenericConfig>(new_config));
86
- OPENVINO_ASSERT (m_config && m_config->is_completed (), " Failed to update kernel config in get_config" );
82
+ const auto & new_ptr = dynamic_cast <const Conf*>(&new_config);
83
+ OPENVINO_ASSERT (new_config.is_completed () && new_ptr, " Failed to update kernel config in get_config" );
84
+ m_config = *new_ptr;
87
85
update_kernel (m_config, m_kernel);
88
86
OPENVINO_ASSERT (m_kernel, " Failed to compile kernel executor" );
89
87
}
90
- std::shared_ptr< const GenericConfig> get_config () const override { return m_config; }
88
+ const GenericConfig& get_config () const override { return m_config; }
91
89
std::shared_ptr<const KernelType> get_kernel () const { return m_kernel; }
92
90
#ifdef SNIPPETS_DEBUG_CAPS
93
91
std::string to_string () const override {
@@ -99,20 +97,20 @@ class KernelExecutor : public snippets::KernelExecutorBase {
99
97
std::free);
100
98
type_name = demangled_name.get ();
101
99
#endif
102
- return " KernelExecutorType: " + std::string (type_name) + " KernelConfig: " + m_config-> to_string ();
100
+ return " KernelExecutorType: " + std::string (type_name) + " KernelConfig: " + m_config. to_string ();
103
101
}
104
102
#endif
105
103
106
104
protected:
107
105
/* ** Updates stored kernel config based on runtime info from expression (e.g. new input shapes). */
108
- virtual void update_config (const ov::snippets:: lowered::ExpressionPtr& expr, std::shared_ptr< Conf> & config) const = 0;
106
+ virtual void update_config (const lowered::ExpressionPtr& expr, Conf& config) const = 0;
109
107
/* ** Updates stored kernel in accordance with the passed config. Recompilation of the kernel is
110
- * performed only if necessary, otherwise an appropriate kernel is retrieved from cache . */
111
- virtual void update_kernel (const std::shared_ptr< const Conf> & c, std::shared_ptr<KernelType>& kernel) const = 0;
108
+ * performed if necessary. */
109
+ virtual void update_kernel (const Conf& c, std::shared_ptr<KernelType>& kernel) const = 0;
112
110
113
111
private:
114
112
/* * Contains all the necessary information to compile a desired kernel*/
115
- std::shared_ptr< Conf> m_config = nullptr ;
113
+ Conf m_config {} ;
116
114
/* * Stores pointer to compiled kernel since the last update_kernel() call */
117
115
std::shared_ptr<KernelType> m_kernel = nullptr ;
118
116
};
@@ -122,13 +120,12 @@ class KernelExecutorTable {
122
120
/* ** Register KernelExecutor in the KernelExecutorTable so it can be later updated in runtime. */
123
121
template <typename T, class ...C,
124
122
typename std::enable_if<std::is_base_of<KernelExecutorBase, T>::value, bool >::type = true >
125
- std::shared_ptr<T> register_kernel (const snippets::lowered::ExpressionPtr& expr, C... args) {
126
- OPENVINO_ASSERT (!m_table.count (expr), " This expression already has an alterable kernel" );
123
+ std::shared_ptr<T> register_kernel (const lowered::ExpressionPtr& expr, C... args) {
127
124
const auto & instance = std::make_shared<T>(args...);
128
- m_table[ expr] = instance;
125
+ OPENVINO_ASSERT ( m_table. insert ({ expr, instance}). second , " This expression already has an alterable kernel " ) ;
129
126
return instance;
130
127
}
131
- std::shared_ptr<KernelExecutorBase> get_kernel_executor (const snippets:: lowered::ExpressionPtr& expr) const {
128
+ const std::shared_ptr<KernelExecutorBase>& get_kernel_executor (const lowered::ExpressionPtr& expr) const {
132
129
OPENVINO_ASSERT (m_table.count (expr), " This expression doesn't have a registered kernel executor" );
133
130
return m_table.at (expr);
134
131
}
@@ -150,13 +147,13 @@ class KernelExecutorTable {
150
147
* be accessible from RuntimeConfigurator. In order to replace these cloned ExpressionPtrs with the original ones,
151
148
* we need to call this method.
152
149
*/
153
- void replace_key_expression (const snippets:: lowered::ExpressionPtr& from, const snippets:: lowered::ExpressionPtr& to);
150
+ void replace_key_expression (const lowered::ExpressionPtr& from, const lowered::ExpressionPtr& to);
154
151
155
152
virtual ~KernelExecutorTable () = default ;
156
153
157
154
protected:
158
- std::unordered_map<snippets:: lowered::ExpressionPtr, std::shared_ptr<KernelExecutorBase>> m_table{};
159
- typedef std::vector<std::pair<snippets:: lowered::ExpressionPtr, std::shared_ptr<const KernelExecutorBase::GenericConfig>>> ExecTableState;
155
+ std::unordered_map<lowered::ExpressionPtr, std::shared_ptr<KernelExecutorBase>> m_table{};
156
+ typedef std::vector<std::pair<lowered::ExpressionPtr, std::shared_ptr<const KernelExecutorBase::GenericConfig>>> ExecTableState;
160
157
161
158
/* ** Restore the table state previously obtained by get_state() */
162
159
void reset_state (const ExecTableState& state);
0 commit comments