Skip to content

Commit fe423b9

Browse files
Enable dynamic MHA tests (#25326)
### Details: - *Enable dynamic MHA and related tests* ### Tickets: - *143261*
1 parent a5f304b commit fe423b9

File tree

23 files changed

+550
-351
lines changed

23 files changed

+550
-351
lines changed

src/common/snippets/include/snippets/kernel_executor_table.hpp

+26-29
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#pragma once
66

77
#include "snippets/lowered/linear_ir.hpp"
8+
#include <typeinfo>
89
#if defined(SNIPPETS_DEBUG_CAPS) && !defined(_WIN32)
910
#include <cxxabi.h>
1011
#endif
@@ -27,14 +28,11 @@ class KernelExecutorBase {
2728
virtual bool is_completed() const = 0;
2829

2930
/*** Return deep copy of the config */
30-
virtual std::shared_ptr<GenericConfig> clone() const = 0;
31+
virtual std::unique_ptr<GenericConfig> get_clone_ptr() const = 0;
3132

3233
/*** Compute hash for fast comparison operations or caching support */
3334
virtual size_t hash() const = 0;
3435

35-
bool operator==(const GenericConfig& rhs) const { return hash() == rhs.hash(); }
36-
bool operator!=(const GenericConfig& rhs) const { return hash() != rhs.hash(); }
37-
3836
virtual ~GenericConfig() = default;
3937
/** serialize config for debug purposes */
4038
#ifdef SNIPPETS_DEBUG_CAPS
@@ -45,14 +43,14 @@ class KernelExecutorBase {
4543
* @brief Update current kernel config in accordance with the passed expression. Corresponding kernel is recompiled if necessary.
4644
* This method should be called to update KernelExecutor based on runtime info (e.g. shapes) available through expression ptr
4745
*/
48-
virtual void update_by_expression(const ov::snippets::lowered::ExpressionPtr& expr) = 0;
46+
virtual void update_by_expression(const lowered::ExpressionPtr& expr) = 0;
4947
/**
5048
* @brief Replace current kernel config with the provided value. Corresponding kernel is recompiled if necessary.
5149
* This method should be called to restore a saved state of the executor, that was configured using update_by_expression().
5250
*/
53-
virtual void update_by_config(const std::shared_ptr<const GenericConfig>& new_config) = 0;
51+
virtual void update_by_config(const GenericConfig& new_config) = 0;
5452

55-
virtual std::shared_ptr<const GenericConfig> get_config() const = 0;
53+
virtual const GenericConfig& get_config() const = 0;
5654
/** serialize for debug purposes */
5755
#ifdef SNIPPETS_DEBUG_CAPS
5856
virtual std::string to_string() const = 0;
@@ -67,27 +65,27 @@ class KernelExecutorBase {
6765

6866
template<typename Conf, typename KernelType,
6967
typename std::enable_if<std::is_base_of<KernelExecutorBase::GenericConfig, Conf>::value, bool>::type = true>
70-
class KernelExecutor : public snippets::KernelExecutorBase {
68+
class KernelExecutor : public KernelExecutorBase {
7169
public:
72-
explicit KernelExecutor(std::shared_ptr<Conf> c) : KernelExecutorBase(), m_config{std::move(c)} {}
70+
explicit KernelExecutor(Conf c) : KernelExecutorBase(), m_config{std::move(c)} {}
7371

7472
// Note: override when final is redundant, but needed to avoid warnings on some compilers
75-
void update_by_expression(const ov::snippets::lowered::ExpressionPtr& expr) override final { // NOLINT
76-
m_config = std::static_pointer_cast<Conf>(m_config->clone());
73+
void update_by_expression(const lowered::ExpressionPtr& expr) override final { // NOLINT
7774
update_config(expr, m_config);
78-
OPENVINO_ASSERT(m_config && m_config->is_completed(), "Failed to update kernel config in update_by_expression");
75+
OPENVINO_ASSERT(m_config.is_completed(), "Failed to update kernel config in update_by_expression");
7976
update_kernel(m_config, m_kernel);
8077
OPENVINO_ASSERT(m_kernel, "Failed to compile kernel executor");
8178
}
82-
void update_by_config(const std::shared_ptr<const GenericConfig>& new_config) override final { // NOLINT
83-
if (*m_config == *new_config)
79+
void update_by_config(const GenericConfig& new_config) override final { // NOLINT
80+
if (m_config.hash() == new_config.hash())
8481
return;
85-
m_config = std::static_pointer_cast<Conf>(std::const_pointer_cast<GenericConfig>(new_config));
86-
OPENVINO_ASSERT(m_config && m_config->is_completed(), "Failed to update kernel config in get_config");
82+
const auto& new_ptr = dynamic_cast<const Conf*>(&new_config);
83+
OPENVINO_ASSERT(new_config.is_completed() && new_ptr, "Failed to update kernel config in get_config");
84+
m_config = *new_ptr;
8785
update_kernel(m_config, m_kernel);
8886
OPENVINO_ASSERT(m_kernel, "Failed to compile kernel executor");
8987
}
90-
std::shared_ptr<const GenericConfig> get_config() const override { return m_config; }
88+
const GenericConfig& get_config() const override { return m_config; }
9189
std::shared_ptr<const KernelType> get_kernel() const { return m_kernel; }
9290
#ifdef SNIPPETS_DEBUG_CAPS
9391
std::string to_string() const override {
@@ -99,20 +97,20 @@ class KernelExecutor : public snippets::KernelExecutorBase {
9997
std::free);
10098
type_name = demangled_name.get();
10199
#endif
102-
return "KernelExecutorType: " + std::string(type_name) + " KernelConfig: " + m_config->to_string();
100+
return "KernelExecutorType: " + std::string(type_name) + " KernelConfig: " + m_config.to_string();
103101
}
104102
#endif
105103

106104
protected:
107105
/*** Updates stored kernel config based on runtime info from expression (e.g. new input shapes). */
108-
virtual void update_config(const ov::snippets::lowered::ExpressionPtr& expr, std::shared_ptr<Conf>& config) const = 0;
106+
virtual void update_config(const lowered::ExpressionPtr& expr, Conf& config) const = 0;
109107
/*** Updates stored kernel in accordance with the passed config. Recompilation of the kernel is
110-
* performed only if necessary, otherwise an appropriate kernel is retrieved from cache. */
111-
virtual void update_kernel(const std::shared_ptr<const Conf>& c, std::shared_ptr<KernelType>& kernel) const = 0;
108+
* performed if necessary. */
109+
virtual void update_kernel(const Conf& c, std::shared_ptr<KernelType>& kernel) const = 0;
112110

113111
private:
114112
/** Contains all the necessary information to compile a desired kernel*/
115-
std::shared_ptr<Conf> m_config = nullptr;
113+
Conf m_config {};
116114
/** Stores pointer to compiled kernel since the last update_kernel() call */
117115
std::shared_ptr<KernelType> m_kernel = nullptr;
118116
};
@@ -122,13 +120,12 @@ class KernelExecutorTable {
122120
/*** Register KernelExecutor in the KernelExecutorTable so it can be later updated in runtime. */
123121
template<typename T, class ...C,
124122
typename std::enable_if<std::is_base_of<KernelExecutorBase, T>::value, bool>::type = true>
125-
std::shared_ptr<T> register_kernel(const snippets::lowered::ExpressionPtr& expr, C... args) {
126-
OPENVINO_ASSERT(!m_table.count(expr), "This expression already has an alterable kernel");
123+
std::shared_ptr<T> register_kernel(const lowered::ExpressionPtr& expr, C... args) {
127124
const auto& instance = std::make_shared<T>(args...);
128-
m_table[expr] = instance;
125+
OPENVINO_ASSERT(m_table.insert({expr, instance}).second, "This expression already has an alterable kernel");
129126
return instance;
130127
}
131-
std::shared_ptr<KernelExecutorBase> get_kernel_executor(const snippets::lowered::ExpressionPtr& expr) const {
128+
const std::shared_ptr<KernelExecutorBase>& get_kernel_executor(const lowered::ExpressionPtr& expr) const {
132129
OPENVINO_ASSERT(m_table.count(expr), "This expression doesn't have a registered kernel executor");
133130
return m_table.at(expr);
134131
}
@@ -150,13 +147,13 @@ class KernelExecutorTable {
150147
* be accessible from RuntimeConfigurator. In order to replace these cloned ExpressionPtrs with the original ones,
151148
* we need to call this method.
152149
*/
153-
void replace_key_expression(const snippets::lowered::ExpressionPtr& from, const snippets::lowered::ExpressionPtr& to);
150+
void replace_key_expression(const lowered::ExpressionPtr& from, const lowered::ExpressionPtr& to);
154151

155152
virtual ~KernelExecutorTable() = default;
156153

157154
protected:
158-
std::unordered_map<snippets::lowered::ExpressionPtr, std::shared_ptr<KernelExecutorBase>> m_table{};
159-
typedef std::vector<std::pair<snippets::lowered::ExpressionPtr, std::shared_ptr<const KernelExecutorBase::GenericConfig>>> ExecTableState;
155+
std::unordered_map<lowered::ExpressionPtr, std::shared_ptr<KernelExecutorBase>> m_table{};
156+
typedef std::vector<std::pair<lowered::ExpressionPtr, std::shared_ptr<const KernelExecutorBase::GenericConfig>>> ExecTableState;
160157

161158
/*** Restore the table state previously obtained by get_state() */
162159
void reset_state(const ExecTableState& state);

src/common/snippets/include/snippets/utils.hpp

+5-34
Original file line numberDiff line numberDiff line change
@@ -247,43 +247,14 @@ std::shared_ptr<ov::Node> get_leaf_node_of_first_child_shape_infer_seq(const std
247247
*/
248248
std::shared_ptr<ov::Node> get_leaf_node_of_first_parent_shape_infer_seq(const std::shared_ptr<ov::Node>& start_node);
249249

250-
/**
251-
* @brief Calculate leading dimension of the shape that should be read according to the layout
252-
* @param shape original (not reordered) input shape
253-
* @param layout specifies the order in what dimensions of in the input shape should be read
254-
* @return stride of the dimension idx = layout[layout.size() - 2] in the original shape
255-
Example:
256-
Original shape (shape) = [1, 49, 2, 23]
257-
Layout (transpose order) = [2, 0, 1, 3]
258-
259-
dim_idx = layout.size() - 2 = 2
260-
// Since layout specifies the order of dimensions in which the shape should be read
261-
dim = layout[dim_idx] = 1
262-
stride(shape[1]) = shape[2] * shape[3] = 2 * 23
263-
*/
264-
size_t get_in_leading_dim(const VectorDims& shape, const std::vector<size_t>& layout);
265-
inline size_t get_in_leading_dim(const lowered::PortDescriptorPtr& pd) {
266-
return get_in_leading_dim(pd->get_shape(), pd->get_layout());
267-
}
268250
/**
269251
*
270-
* @param shape reordered input shape that is stored according to the layout
271-
* @param layout specifies the order in what the dimensions of the input shape are stored
272-
* @return
273-
Output shape is already transposed, we need to correctly write the data with original shape by the order
274-
Example:
275-
Original transposed shape (shape) = [49, 2, 7, 39]
276-
Layout (transpose order) = [2, 0, 1, 3]
277-
278-
dim_idx = layout.size() - 2 = 2
279-
// Since the shape dimensions are already reordered according to the layout
280-
dim = /find dim_idx index in layout/ = 0
281-
stride(shape[0]) = shape[1] * shape[2] * shape[3] = 2 * 7 * 39
252+
* @param Get stride of input/output dimension
253+
* @param expr_port target port that contains shape and layout info
254+
* @param idx index of the target dimension starting from the shape's end (default = 1)
282255
*/
283-
size_t get_out_leading_dim(const VectorDims& shape, const std::vector<size_t>& layout);
284-
inline size_t get_out_leading_dim(const lowered::PortDescriptorPtr& pd) {
285-
return get_out_leading_dim(pd->get_shape(), pd->get_layout());
286-
}
256+
257+
int64_t get_dim_stride(const lowered::ExpressionPort& expr_port, size_t idx = 1);
287258

288259
} // namespace utils
289260
} // namespace snippets

src/common/snippets/src/kernel_executor_table.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ void KernelExecutorTable::reset_state(const ExecTableState& state) {
2222
for (const auto& table_record : m_table) {
2323
const auto& state_record = *state_it++;
2424
OPENVINO_ASSERT(table_record.first == state_record.first, "Invalid state in restore_state: expressions mismatch");
25-
table_record.second->update_by_config(state_record.second);
25+
table_record.second->update_by_config(*state_record.second);
2626
}
2727
}
2828

@@ -31,7 +31,7 @@ KernelExecutorTable::ExecTableState KernelExecutorTable::get_state() const {
3131
// Note: we need to clone configs when saving the state, since the configs still stored in the table can
3232
// be modified e.g. by calling update_by_expression();
3333
for (const auto& record : m_table)
34-
result.emplace_back(std::make_pair(record.first, record.second->get_config()->clone()));
34+
result.emplace_back(std::make_pair(record.first, record.second->get_config().get_clone_ptr()));
3535
return result;
3636
}
3737

src/common/snippets/src/lowered/pass/compute_buffer_allocation_size.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ std::vector<size_t> get_parent_inner_loops(const std::vector<size_t>& parent_loo
2727
// Ticket: 113744
2828
// TODO: This logic covers only several specific cases so it should be generalized.
2929
size_t ComputeBufferAllocationSize::get_allocation_size(const LoopManagerPtr& loop_manager, const ExpressionPtr& buffer_expr, size_t allocation_rank) {
30-
const auto& parent_port = buffer_expr->get_input_port_connector(0)->get_source();
30+
// Note: Buffer expressions can have more than one parent after the loops splitting transformation, but only the last parent
31+
// can be used to access valid loop ports. More info in the ticket: 146646
32+
const auto& parent_port = buffer_expr->get_input_port_connector(buffer_expr->get_input_count() - 1)->get_source();
3133
const auto& parent_loop_ids = get_parent_inner_loops(parent_port.get_expr()->get_loop_ids(), buffer_expr->get_loop_ids());
3234
const auto planar_shape = utils::get_preordered_vdims(parent_port);
3335

src/common/snippets/src/op/subgraph.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ snippets::Schedule Subgraph::generate(const void* compile_params) const {
546546
auto lowering_result = m_generator->generate(linear_ir, compile_params);
547547

548548
// Note: Since the code emission is performed on a copy of LIR, but RuntimeConfigurator works with the initial instance,
549-
// we need to replace cloned expression pointers to original ones in the KernelExecutorTable
549+
// we need to replace cloned expression pointers to original ones in the KernelExecutorTable. Ticket: 129772
550550
const auto& exec_table = m_generator->get_target_machine()->get_runtime_configurator()->get_kernel_executor_table();
551551
for (const auto& expr : *m_linear_ir)
552552
exec_table->replace_key_expression(expression_map.at(expr.get()), expr);

src/common/snippets/src/utils.cpp

+9-16
Original file line numberDiff line numberDiff line change
@@ -291,22 +291,15 @@ std::shared_ptr<ov::Node> get_leaf_node_of_first_parent_shape_infer_seq(const st
291291
return leaf_node;
292292
}
293293

294-
size_t get_in_leading_dim(const VectorDims& shape, const std::vector<size_t>& layout) {
295-
if (layout.empty())
296-
return shape.back();
297-
OPENVINO_ASSERT(layout.back() == layout.size() - 1 && layout.size() == shape.size(),
298-
"detected invalid layout values: check that this shape + layout combination is schedulable");
299-
const auto idx = static_cast<VectorDims::difference_type>(layout[layout.size() - 2]);
300-
return std::accumulate(shape.cbegin() + idx + 1, shape.end(), 1ull, std::multiplies<size_t>());
301-
}
302-
size_t get_out_leading_dim(const VectorDims& shape, const std::vector<size_t>& layout) {
303-
if (layout.empty())
304-
return shape.back();
305-
OPENVINO_ASSERT(layout.back() == layout.size() - 1 && layout.size() == shape.size(),
306-
"detected invalid layout values: check that this shape + layout combination is schedulable");
307-
const auto idx = layout.size() - 2;
308-
const auto dim = std::distance(layout.cbegin(), std::find(layout.cbegin(), layout.cend(), idx));
309-
return std::accumulate(shape.cbegin() + dim + 1, shape.cend(), 1ull, std::multiplies<size_t>());
294+
int64_t get_dim_stride(const lowered::ExpressionPort& expr_port, size_t idx) {
295+
size_t dim_idx = 0;
296+
const auto& layout = expr_port.get_descriptor_ptr()->get_layout();
297+
switch (expr_port.get_type()) {
298+
case lowered::ExpressionPort::Input: dim_idx = utils::get_input_dim_idx(layout, idx); break;
299+
case lowered::ExpressionPort::Output: dim_idx = utils::get_output_dim_idx(layout, idx); break;
300+
default: OPENVINO_THROW("Unsupported expression port type!");
301+
}
302+
return get_stride(dim_idx, expr_port.get_descriptor_ptr()->get_shape());
310303
}
311304

312305
} // namespace utils

src/plugins/intel_cpu/src/emitters/snippets/cpu_kernel_executor_table.hpp

+11-10
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,10 @@ namespace intel_cpu {
1313
template<typename Conf, typename KernelType>
1414
class CPUKernelExecutor : public snippets::KernelExecutor<Conf, KernelType> {
1515
public:
16-
CPUKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, std::shared_ptr<Conf> c) :
17-
snippets::KernelExecutor<Conf, KernelType>(c), m_kernel_cache(std::move(kernel_cache)) {}
18-
struct Key {
19-
explicit Key(const std::shared_ptr<const Conf>& c) : config{c} {}
20-
const std::shared_ptr<const Conf> config;
21-
size_t hash() const { return config->hash(); }
22-
bool operator==(const Key& rhs) const { return *config == *rhs.config; }
23-
};
24-
void update_kernel(const std::shared_ptr<const Conf>& config, std::shared_ptr<KernelType>& kernel) const override final { // NOLINT
16+
CPUKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, Conf c) :
17+
snippets::KernelExecutor<Conf, KernelType>(std::move(c)), m_kernel_cache(std::move(kernel_cache)) {}
18+
19+
void update_kernel(const Conf& config, std::shared_ptr<KernelType>& kernel) const override final { // NOLINT
2520
const auto& cache = m_kernel_cache.lock();
2621
OPENVINO_ASSERT(cache, "Invalid kernel cache pointer in CPUKernelExecutor::update_kernel()");
2722
const auto& lookup_result = cache->getOrCreate(Key(config),
@@ -32,8 +27,14 @@ class CPUKernelExecutor : public snippets::KernelExecutor<Conf, KernelType> {
3227
}
3328

3429
protected:
30+
struct Key {
31+
explicit Key(Conf c) : config{std::move(c)} {}
32+
const Conf config;
33+
size_t hash() const { return config.hash(); }
34+
bool operator==(const Key& rhs) const { return config == rhs.config; }
35+
};
3536
/** Compile kernel managed by KernelExecutor instance. Will be called only if Kernel is not found in the cache */
36-
virtual std::shared_ptr<KernelType> compile_kernel(const std::shared_ptr<const Conf>& c) const = 0;
37+
virtual std::shared_ptr<KernelType> compile_kernel(const Conf& c) const = 0;
3738
/** CPU plugin cache implementation is used to avoid redundant recompilations */
3839
ov::intel_cpu::MultiCacheWeakPtr m_kernel_cache;
3940
};

src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
4343
size_t leading_dimension = *(original_shape.rbegin());
4444
if (!layout.empty()) {
4545
transposed_shape = snippets::utils::get_planar_vdims(original_shape, layout);
46-
leading_dimension = ov::snippets::utils::get_in_leading_dim(original_shape, layout);
46+
leading_dimension = ov::snippets::utils::get_dim_stride(expr->get_input_port(0));
4747
}
4848

4949
const auto& in_subtensor = in_desc->get_subtensor();

0 commit comments

Comments
 (0)