|
4 | 4 |
|
5 | 5 | #include "transformations/common_optimizations/move_eltwise_up_data_movement.hpp"
|
6 | 6 |
|
| 7 | +#include <algorithm> |
7 | 8 | #include <memory>
|
8 |
| -#include <numeric> |
9 | 9 | #include <openvino/opsets/opset8.hpp>
|
10 | 10 |
|
11 | 11 | #include "itt.hpp"
|
12 | 12 | #include "openvino/core/rt_info.hpp"
|
13 | 13 | #include "openvino/core/type.hpp"
|
| 14 | +#include "openvino/op/constant.hpp" |
| 15 | +#include "openvino/op/reshape.hpp" |
| 16 | +#include "openvino/op/squeeze.hpp" |
| 17 | +#include "openvino/op/unsqueeze.hpp" |
14 | 18 | #include "openvino/pass/pattern/op/wrap_type.hpp"
|
15 | 19 | #include "transformations/utils/utils.hpp"
|
16 | 20 |
|
@@ -50,9 +54,9 @@ std::vector<ov::DiscreteTypeInfo> ov::pass::MoveEltwiseUpThroughDataMov::get_def
|
50 | 54 | };
|
51 | 55 | }
|
52 | 56 |
|
53 |
| -ov::pass::MoveEltwiseUpThroughDataMov::MoveEltwiseUpThroughDataMov( |
| 57 | +ov::pass::MoveEltwiseUpThroughDataMovScalar::MoveEltwiseUpThroughDataMovScalar( |
54 | 58 | std::vector<DiscreteTypeInfo> allowed_data_movement_ops) {
|
55 |
| - MATCHER_SCOPE(MoveEltwiseUpThroughDataMov); |
| 59 | + MATCHER_SCOPE(MoveEltwiseUpThroughDataMovScalar); |
56 | 60 | auto eltwise_pattern = ov::pass::pattern::wrap_type<ov::op::util::UnaryElementwiseArithmetic,
|
57 | 61 | ov::op::util::BinaryElementwiseArithmetic,
|
58 | 62 | ov::op::v0::FakeQuantize>(ov::pass::pattern::has_static_rank());
|
@@ -126,3 +130,102 @@ ov::pass::MoveEltwiseUpThroughDataMov::MoveEltwiseUpThroughDataMov(
|
126 | 130 | auto m = std::make_shared<ov::pass::pattern::Matcher>(eltwise_pattern, matcher_name);
|
127 | 131 | register_matcher(m, callback);
|
128 | 132 | }
|
| 133 | + |
| 134 | +ov::pass::MoveEltwiseUpThroughDataMovPerChannel::MoveEltwiseUpThroughDataMovPerChannel() { |
| 135 | + MATCHER_SCOPE(MoveEltwiseUpThroughDataMovPerChannel); |
| 136 | + |
| 137 | + auto const_predicate = [](const ov::Output<ov::Node>& output) { |
| 138 | + auto constant_op = std::dynamic_pointer_cast<ov::opset8::Constant>(output.get_node_shared_ptr()); |
| 139 | + if (!constant_op) |
| 140 | + return false; |
| 141 | + |
| 142 | + if (output.get_target_inputs().size() != 1) |
| 143 | + return false; |
| 144 | + |
| 145 | + const auto& shape = constant_op->get_shape(); |
| 146 | + return std::count_if(shape.begin(), shape.end(), [](size_t v) { |
| 147 | + return v > 1; |
| 148 | + }) == 1; |
| 149 | + }; |
| 150 | + |
| 151 | + auto eltw_predicate = [](const ov::Output<ov::Node>& output) { |
| 152 | + if (output.get_target_inputs().size() != 1) |
| 153 | + return false; |
| 154 | + |
| 155 | + auto node = output.get_node(); |
| 156 | + |
| 157 | + if (node->get_output_partial_shape(0).rank().is_dynamic()) |
| 158 | + return false; |
| 159 | + |
| 160 | + const size_t const_idx = ov::is_type<ov::op::v0::Constant>(node->get_input_node_ptr(0)) ? 0 : 1; |
| 161 | + const size_t data_flow_idx = (const_idx + 1) % 2; |
| 162 | + |
| 163 | + if (node->get_input_partial_shape(data_flow_idx).size() < node->get_input_partial_shape(const_idx).size()) |
| 164 | + return false; |
| 165 | + |
| 166 | + return true; |
| 167 | + }; |
| 168 | + |
| 169 | + auto eltw_data_flow_in = |
| 170 | + ov::pass::pattern::wrap_type<ov::op::v1::Reshape, ov::op::v0::Squeeze, ov::op::v0::Unsqueeze>(); |
| 171 | + auto eltw_const_in = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(const_predicate); |
| 172 | + auto eltwise_pattern = |
| 173 | + ov::pass::pattern::wrap_type<ov::op::util::BinaryElementwiseArithmetic>({eltw_data_flow_in, eltw_const_in}, |
| 174 | + eltw_predicate); |
| 175 | + |
| 176 | + ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) { |
| 177 | + const auto& pattern_map = m.get_pattern_value_map(); |
| 178 | + |
| 179 | + auto eltwise = pattern_map.at(eltwise_pattern).get_node_shared_ptr(); |
| 180 | + if (transformation_callback(eltwise)) { |
| 181 | + return false; |
| 182 | + } |
| 183 | + |
| 184 | + const size_t const_idx = ov::is_type<ov::op::v0::Constant>(eltwise->get_input_node_ptr(0)) ? 0 : 1; |
| 185 | + const size_t data_flow_idx = (const_idx + 1) % 2; |
| 186 | + |
| 187 | + auto const_shape = eltwise->get_input_shape(const_idx); |
| 188 | + size_t channel_idx = 0; |
| 189 | + size_t channel_val = 0; |
| 190 | + for (size_t i = 0; i < const_shape.size(); i++) { |
| 191 | + if (const_shape[i] > 1) { |
| 192 | + channel_idx = i; |
| 193 | + channel_val = const_shape[i]; |
| 194 | + } |
| 195 | + } |
| 196 | + |
| 197 | + auto parent = eltwise->get_input_node_shared_ptr(data_flow_idx); |
| 198 | + const auto& parent_in_pshape = parent->get_input_partial_shape(0); |
| 199 | + auto parent_in_channel_dim = |
| 200 | + parent_in_pshape.size() <= channel_idx ? ov::Dimension(1) : parent_in_pshape[channel_idx]; |
| 201 | + auto parent_out_channel_dim = parent->get_output_partial_shape(0)[channel_idx]; |
| 202 | + if (parent_in_channel_dim.is_dynamic() || parent_in_channel_dim != channel_val || |
| 203 | + parent_out_channel_dim.is_dynamic() || parent_out_channel_dim != channel_val) |
| 204 | + return false; |
| 205 | + |
| 206 | + auto new_shape = ov::Shape(parent->get_input_partial_shape(0).size(), 1); |
| 207 | + |
| 208 | + new_shape[channel_idx] = const_shape[channel_idx]; |
| 209 | + auto old_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(eltwise->get_input_node_shared_ptr(const_idx)); |
| 210 | + auto new_const = std::make_shared<ov::op::v0::Constant>(*old_const, new_shape); |
| 211 | + ov::replace_node_update_name(old_const, new_const); |
| 212 | + ov::replace_output_update_name(eltwise->output(0), eltwise->input_value(data_flow_idx)); |
| 213 | + |
| 214 | + ov::OutputVector eltwise_inputs = eltwise->input_values(); |
| 215 | + eltwise_inputs[data_flow_idx] = parent->input_value(0); |
| 216 | + auto new_eltwise = eltwise->clone_with_new_inputs(eltwise_inputs); |
| 217 | + ov::copy_runtime_info(eltwise, new_eltwise); |
| 218 | + |
| 219 | + ov::OutputVector parent_inputs = parent->input_values(); |
| 220 | + parent_inputs[0] = new_eltwise; |
| 221 | + auto new_parent = parent->clone_with_new_inputs(parent_inputs); |
| 222 | + ov::copy_runtime_info(parent, new_parent); |
| 223 | + new_parent->set_friendly_name(parent->get_friendly_name()); |
| 224 | + |
| 225 | + ov::replace_node(parent, new_parent); |
| 226 | + return true; |
| 227 | + }; |
| 228 | + |
| 229 | + auto m = std::make_shared<ov::pass::pattern::Matcher>(eltwise_pattern, matcher_name); |
| 230 | + register_matcher(m, callback); |
| 231 | +} |
0 commit comments