Skip to content

Commit 875ed9d

Browse files
committed
TPP
1 parent bec3062 commit 875ed9d

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

src/plugins/intel_cpu/src/transformations/tpp/x64/pass/eltwise_to_eltwise_tpp.cpp

+10-2
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,18 @@ EltwiseToEltwiseTPP::EltwiseToEltwiseTPP() {
4040
ov::is_type<ov::snippets::op::ReduceBase>(node) ? ov::snippets::utils::get_full_dim_value() : 64;
4141
ov::replace_node_update_name(node, tpp_eltwise);
4242
for (size_t i = 0; i < node->get_input_size(); i++) {
43-
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(tpp_eltwise->input(i), {M_block, N_block});
43+
auto subtensor = snippets::VectorDims{M_block, N_block};
44+
if (tpp_eltwise->input(i).get_partial_shape().size() < 2) {
45+
subtensor = snippets::VectorDims{N_block};
46+
}
47+
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(tpp_eltwise->input(i), subtensor);
4448
}
4549

46-
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(tpp_eltwise->output(0), {M_block, N_block});
50+
auto subtensor = snippets::VectorDims{M_block, N_block};
51+
if (tpp_eltwise->output(0).get_partial_shape().size() < 2) {
52+
subtensor = snippets::VectorDims{N_block};
53+
}
54+
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(tpp_eltwise->output(0), subtensor);
4755

4856
return true;
4957
};

src/plugins/intel_cpu/src/transformations/tpp/x64/pass/fuse_tpp_to_equations.cpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,15 @@ bool FuseTPPToEquations::fuse_from_root(const NodePtr& root, const std::shared_p
8282
kv.second = equation;
8383
replace_nodes(m, {}, node_replace_map);
8484
for (const auto& in : equation->inputs()) {
85-
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(in, root_subtensor);
85+
auto subtensor = root_subtensor;
86+
if (in.get_partial_shape().size() < 2) {
87+
subtensor = {root_subtensor.back()};
88+
}
89+
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(in, subtensor);
90+
}
91+
auto subtensor = root_subtensor;
92+
if (equation->output(0).get_partial_shape().size() < 2) {
93+
subtensor = {root_subtensor.back()};
8694
}
8795
ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(equation->output(0), root_subtensor);
8896
return true;

0 commit comments

Comments
 (0)