Skip to content

Commit 6b7bc52

Browse files
authored
[Snippets] Use ov::is_type_any_of for ov::is_type chains (openvinotoolkit#29146)
### Details: - Use ov::is_type_any_of where applicable in snippets common component ### Tickets: - N/A
1 parent 0c3bc3b commit 6b7bc52

18 files changed

+117
-127
lines changed

src/common/snippets/src/lowered/linear_ir.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ void LinearIR::unregister_expression(const ExpressionPtr& expr) {
205205

206206
const auto& node = expr->get_node();
207207
m_node2expression_map.erase(node);
208-
OPENVINO_ASSERT(!ov::is_type<ov::op::v0::Parameter>(node) && !ov::is_type<ov::op::v0::Result>(node),
208+
OPENVINO_ASSERT((!ov::is_type_any_of<ov::op::v0::Parameter, ov::op::v0::Result>(node)),
209209
"unregister_expression mustn't be called for parameter or result expressions");
210210
if (const auto buffer_expr = ov::as_type_ptr<BufferExpression>(expr)) {
211211
const auto& it = std::find(m_buffer_expressions.cbegin(), m_buffer_expressions.cend(), buffer_expr);

src/common/snippets/src/lowered/pass/assign_registers.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ AssignRegisters::RegMap AssignRegisters::assign_regs_manually(const LinearIR& li
5151
for (const auto& pd : buffer->get_output_port_descriptors())
5252
all_equal &= pd->get_reg() == out_reg;
5353
OPENVINO_ASSERT(all_equal, "Buffer must have same register on all inputs and outputs");
54-
} else if (ov::is_type<op::HorizonMax>(op) || ov::is_type<op::HorizonSum>(op)) {
54+
} else if (ov::is_type_any_of<op::HorizonMax, op::HorizonSum>(op)) {
5555
// Only in ReduceDecomposition Reduce ops use HorizonMax/HorizonSum and VectorBuffer.
5656
// We should manually set the one vector register for VectorBuffer and Max/Sum output to simulate a accumulator
5757
// TODO [96351]: We should rewrite accumulator pattern using another way

src/common/snippets/src/lowered/pass/fuse_loops.cpp

+3-8
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,7 @@ bool FuseLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, l
190190
for (auto expr_it = begin; expr_it != end; expr_it++) {
191191
const auto expr = *expr_it;
192192
const auto& node = expr->get_node();
193-
if (ov::is_type<ov::op::v0::Parameter>(node) ||
194-
ov::is_type<ov::op::v0::Constant>(node) ||
195-
ov::is_type<ov::op::v0::Result>(node))
193+
if (ov::is_type_any_of<ov::op::v0::Parameter, ov::op::v0::Constant, ov::op::v0::Result>(node))
196194
continue;
197195

198196
// Outer Loop ----> Inner Loop
@@ -224,9 +222,7 @@ bool FuseLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, l
224222
const auto parent_expr_output = *input_port.get_expr_port()->get_connected_ports().begin();
225223
const auto& parent_expr = parent_expr_output.get_expr();
226224
const auto parent = parent_expr->get_node();
227-
if (ov::is_type<ov::op::v0::Constant>(parent) ||
228-
ov::is_type<ov::op::v0::Parameter>(parent) ||
229-
ov::is_type<op::Buffer>(parent)) {
225+
if (ov::is_type_any_of<ov::op::v0::Constant, ov::op::v0::Parameter, op::Buffer>(parent)) {
230226
continue;
231227
}
232228

@@ -270,8 +266,7 @@ bool FuseLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, l
270266
for (const auto& consumer_expr_input : consumer_exprs_inputs) {
271267
const auto& consumer_expr = consumer_expr_input.get_expr();
272268
const auto consumer = consumer_expr->get_node();
273-
if (ov::is_type<ov::op::v0::Result>(consumer) ||
274-
ov::is_type<op::Buffer>(consumer)) {
269+
if (ov::is_type_any_of<ov::op::v0::Result, op::Buffer>(consumer)) {
275270
continue;
276271
}
277272

src/common/snippets/src/lowered/pass/init_live_ranges.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ inline bool pass_through_expr(const ExpressionPtr& expr) {
1717
const auto& node = expr->get_node();
1818
return op::Subgraph::is_shape_infer_op(node)
1919
#ifdef SNIPPETS_DEBUG_CAPS
20-
|| ov::is_type<op::PerfCountBeginBase>(node)
21-
|| ov::is_type<op::PerfCountEndBase>(node)
20+
|| ov::is_type_any_of<op::PerfCountBeginBase, op::PerfCountEndBase>(node)
2221
#endif
2322
|| ov::is_type<BufferExpression>(expr);
2423
}

src/common/snippets/src/lowered/pass/insert_broadcastmove.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ bool InsertBroadcastMove::is_broadcasting_needed(const std::shared_ptr<ov::Node>
2727
// - VectorBuffer has scalar output shape to avoid broadcast conflicts and manually shape insertion.
2828
// - Fill can be inserted only after VectorBuffer, and should be ignored as well.
2929
return !utils::is_scalar_constant(n) &&
30-
!ov::is_type<ov::snippets::op::VectorBuffer>(n) &&
31-
!ov::is_type<ov::snippets::op::Fill>(n);
30+
!ov::is_type_any_of<ov::snippets::op::VectorBuffer, ov::snippets::op::Fill>(n);
3231
}
3332

3433
std::vector<size_t> InsertBroadcastMove::get_last_dims(const ExpressionPtr& expr) {

src/common/snippets/src/lowered/pass/insert_buffers.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,11 @@ void InsertBuffers::insertion(LinearIR& linear_ir,
9292
}
9393
const auto& parent_port = parent_expr_output.get_index();
9494
const auto& parent = parent_expr->get_node();
95-
if (ov::is_type<op::Buffer>(parent) ||
96-
ov::is_type<op::VectorBuffer>(parent) ||
97-
ov::is_type<ov::op::v0::Parameter>(parent) ||
98-
ov::is_type<ov::op::v0::Constant>(parent) ||
99-
is_type<op::RankNormalization>(parent))
95+
if (ov::is_type_any_of<op::Buffer,
96+
op::VectorBuffer,
97+
ov::op::v0::Parameter,
98+
ov::op::v0::Constant,
99+
op::RankNormalization>(parent))
100100
continue;
101101

102102
// Each MemoryAccess op needs Buffer

src/common/snippets/src/lowered/pass/insert_loops.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,7 @@ bool InsertLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin,
5454
for (auto expr_it = begin; expr_it != end; expr_it++) {
5555
const auto expr = *expr_it;
5656
const auto& node = expr->get_node();
57-
if (ov::is_type<op::LoopBase>(node) ||
58-
ov::is_type<ov::op::v0::Parameter>(node) ||
59-
ov::is_type<ov::op::v0::Result>(node))
57+
if (ov::is_type_any_of<op::LoopBase, ov::op::v0::Parameter, ov::op::v0::Result>(node))
6058
continue;
6159

6260
// Outer Loop ----> Inner Loop

src/common/snippets/src/lowered/pass/mark_invariant_shape_path.cpp

+8-10
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,18 @@ static bool is_shape_broadcastable_op(const ExpressionPtr& expr) {
2828

2929
static bool is_not_affecting_op(const ExpressionPtr& expr) {
3030
const auto& node = expr->get_node();
31-
return ov::is_type<ov::snippets::op::HorizonMax>(node) ||
32-
ov::is_type<ov::snippets::op::HorizonSum>(node) ||
33-
ov::is_type<ov::snippets::op::ReduceMax>(node) ||
34-
ov::is_type<ov::snippets::op::ReduceSum>(node) ||
35-
ov::is_type<ov::snippets::op::VectorBuffer>(node) ||
36-
ov::is_type<ov::snippets::op::BroadcastMove>(node) ||
37-
ov::is_type<ov::snippets::op::Scalar>(node);
31+
return ov::is_type_any_of<ov::snippets::op::HorizonMax,
32+
ov::snippets::op::HorizonSum,
33+
ov::snippets::op::ReduceMax,
34+
ov::snippets::op::ReduceSum,
35+
ov::snippets::op::VectorBuffer,
36+
ov::snippets::op::BroadcastMove,
37+
ov::snippets::op::Scalar>(node);
3838
}
3939

4040
static bool is_affecting_op(const ExpressionPtr& expr) {
4141
const auto& node = expr->get_node();
42-
return ov::is_type<ov::snippets::op::Brgemm>(node) ||
43-
ov::is_type<ov::snippets::op::Reshape>(node) ||
44-
ov::is_type<ov::snippets::op::LoadReorder>(node);
42+
return ov::is_type_any_of<ov::snippets::op::Brgemm, ov::snippets::op::Reshape, ov::snippets::op::LoadReorder>(node);
4543
}
4644
} // namespace
4745

src/common/snippets/src/lowered/pass/mark_loops.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ bool MarkLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, l
2424

2525
// Parameters, Results or Constants are ignored. They can't be in Loops
2626
auto is_loop_outside_op = [](const std::shared_ptr<ov::Node>& node) {
27-
return ov::is_type<ov::op::v0::Result>(node) ||
28-
ov::is_type<ov::op::v0::Constant>(node) ||
29-
ov::is_type<ov::op::v0::Parameter>(node) ||
30-
ov::is_type<op::RankNormalization>(node) ||
31-
ov::is_type<op::Reshape>(node);
27+
return ov::is_type_any_of<ov::op::v0::Result,
28+
ov::op::v0::Constant,
29+
ov::op::v0::Parameter,
30+
op::RankNormalization,
31+
op::Reshape>(node);
3232
};
3333

3434
auto are_conflicted = [](const ExpressionPort& lhs, const ExpressionPort& rhs) {

src/common/snippets/src/lowered/pass/propagate_subtensors.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ void propagate_updated_subtensor_through_loop(const LinearIR& linear_ir,
136136
expr_it = inner_end;
137137
continue;
138138
}
139-
if ((ov::is_type<snippets::op::BroadcastMove>(expr_it->get()->get_node()) ||
140-
ov::is_type<snippets::op::BroadcastLoad>(expr_it->get()->get_node())) &&
139+
if ((ov::is_type_any_of<snippets::op::BroadcastMove, snippets::op::BroadcastLoad>(
140+
expr_it->get()->get_node())) &&
141141
loop_by_last_dim) {
142142
// WA: we have to break subtensor propagation if we try to propagate new last dim through Broadcast nodes
143143
// which broadcast last dim in original dimension value anyway

src/common/snippets/src/lowered/pass/validate.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,9 @@ bool Validate::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lo
155155
}
156156
bool bypass_output_size_check =
157157
#ifdef SNIPPETS_DEBUG_CAPS
158-
ov::is_type<snippets::op::PerfCountBegin>(node) || ov::is_type<snippets::op::PerfCountEnd>(node) ||
158+
ov::is_type_any_of<snippets::op::PerfCountBegin, snippets::op::PerfCountEnd>(node) ||
159159
#endif // SNIPPETS_DEBUG_CAPS
160-
ov::is_type<op::LoopEnd>(node) || ov::is_type<ov::op::v0::Result>(node);
160+
ov::is_type_any_of<op::LoopEnd, ov::op::v0::Result>(node);
161161

162162
OPENVINO_ASSERT(expr->get_output_count() == node->get_output_size() || bypass_output_size_check,
163163
"Incorrect count of output port descriptors!");

src/common/snippets/src/op/subgraph.cpp

+14-14
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,16 @@ void Subgraph::set_virtual_port_count(const size_t count) {
8484
}
8585

8686
auto Subgraph::is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bool {
87-
return ov::is_type<ov::op::v1::Transpose>(op) ||
88-
ov::is_type<ov::op::v1::Softmax>(op) ||
89-
ov::is_type<ov::op::v8::Softmax>(op) ||
90-
ov::is_type<ov::op::v0::MatMul>(op) ||
91-
ov::is_type<ov::op::v1::Broadcast>(op) || // Broadcast is domain sensetive op because the output shape depends on
92-
ov::is_type<ov::op::v3::Broadcast>(op) || // the both input and broadcast shapes (the both - are inputs of op). Note: is used only in MHA pattern
93-
ov::is_type<ov::op::v12::GroupNormalization>(op) ||
94-
ov::is_type<op::Reshape>(op);
87+
// Broadcast is domain sensitive op because the output shape depends on
88+
// both the input and broadcast shapes (both are inputs of the op). Note: is used only in MHA pattern
89+
return ov::is_type_any_of<ov::op::v1::Transpose,
90+
ov::op::v1::Softmax,
91+
ov::op::v8::Softmax,
92+
ov::op::v0::MatMul,
93+
ov::op::v1::Broadcast,
94+
ov::op::v3::Broadcast,
95+
ov::op::v12::GroupNormalization,
96+
op::Reshape>(op);
9597
}
9698

9799
auto Subgraph::is_shape_infer_op(const std::shared_ptr<ov::Node>& op) -> bool {
@@ -104,7 +106,7 @@ void Subgraph::init_config() {
104106
for (const auto& op : ops) {
105107
update(config.m_is_quantized, ov::is_type<ov::op::v0::FakeQuantize>(op));
106108
update(config.m_has_domain_sensitive_ops, is_domain_sensitive_op(op));
107-
update(config.m_has_broadcast_sensitive_ops, ov::is_type<ov::op::v12::GroupNormalization>(op) || ov::is_type<op::Reshape>(op));
109+
update(config.m_has_broadcast_sensitive_ops, ov::is_type_any_of<ov::op::v12::GroupNormalization, op::Reshape>(op));
108110
}
109111
}
110112

@@ -140,7 +142,7 @@ auto Subgraph::get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t {
140142
if (are_prev_or_next_ops) {
141143
push_prc_size(transpose->get_element_type().size());
142144
}
143-
} else if (ov::is_type<ov::op::v1::Softmax>(op) || ov::is_type<ov::op::v8::Softmax>(op)) {
145+
} else if (ov::is_type_any_of<ov::op::v1::Softmax, ov::op::v8::Softmax>(op)) {
144146
// Softmax always uses 2 FP32 Buffers after decomposition.
145147
// They are inplace and the same, so we can push precision size only once
146148
push_prc_size(ov::element::f32.size());
@@ -283,10 +285,8 @@ void Subgraph::fill_empty_output_names(const Output<Node>& target_output_node, c
283285
}
284286

285287
auto Subgraph::constant_input_should_be_inside_body(const std::shared_ptr<ov::Node>& node) -> bool {
286-
return ov::is_type<ov::op::v1::Transpose>(node) ||
287-
ov::is_type<ov::op::v1::Broadcast>(node) ||
288-
ov::is_type<ov::op::v3::Broadcast>(node) ||
289-
ov::is_type<ov::op::v1::Reshape>(node);
288+
return ov::is_type_any_of<ov::op::v1::Transpose, ov::op::v1::Broadcast, ov::op::v3::Broadcast, ov::op::v1::Reshape>(
289+
node);
290290
}
291291

292292
bool Subgraph::check_broadcast(const std::shared_ptr<const ov::Node>& node) noexcept {

src/common/snippets/src/pass/collapse_subgraph.cpp

+47-47
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ auto is_supported_op(const std::shared_ptr<const Node> &n) -> bool {
5050
if (transpose) {
5151
const auto parent = transpose->get_input_node_shared_ptr(0);
5252
const auto child = transpose->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
53-
auto is_brgemm_case = ov::is_type<opset1::MatMul>(parent) || ov::is_type<opset1::MatMul>(child);
53+
auto is_brgemm_case = ov::is_type_any_of<opset1::MatMul, opset1::MatMul>(child);
5454
auto decomposition_case = true;
5555
// Check for Transpose parent is MatMul inside Subgraph
5656
if (const auto subgraph = ov::as_type_ptr<const op::Subgraph>(parent)) {
@@ -81,51 +81,51 @@ auto is_supported_op(const std::shared_ptr<const Node> &n) -> bool {
8181
return ov::is_type<ov::op::v1::Select>(n);
8282
};
8383

84-
auto is_supported_binary_eltwise_op = [](const std::shared_ptr<const Node> &n) -> bool {
85-
return ov::is_type<ov::op::v1::Add>(n)
86-
|| ov::is_type<ov::op::v1::Divide>(n)
87-
|| ov::is_type<ov::op::v1::Equal>(n)
88-
|| ov::is_type<ov::op::v1::FloorMod>(n)
89-
|| ov::is_type<ov::op::v1::Greater>(n)
90-
|| ov::is_type<ov::op::v1::GreaterEqual>(n)
91-
|| ov::is_type<ov::op::v1::Less>(n)
92-
|| ov::is_type<ov::op::v1::LessEqual>(n)
93-
|| ov::is_type<ov::op::v1::LogicalAnd>(n)
94-
|| ov::is_type<ov::op::v1::LogicalOr>(n)
95-
|| ov::is_type<ov::op::v1::LogicalXor>(n)
96-
|| ov::is_type<ov::op::v1::Maximum>(n)
97-
|| ov::is_type<ov::op::v1::Minimum>(n)
98-
|| ov::is_type<ov::op::v1::Mod>(n)
99-
|| ov::is_type<ov::op::v1::Multiply>(n)
100-
|| ov::is_type<ov::op::v1::NotEqual>(n)
101-
|| ov::is_type<ov::op::v0::PRelu>(n)
102-
|| ov::is_type<ov::op::v1::Power>(n)
103-
|| ov::is_type<ov::op::v0::SquaredDifference>(n)
104-
|| ov::is_type<ov::op::v1::Subtract>(n)
105-
|| ov::is_type<ov::op::v0::Xor>(n)
106-
|| ov::is_type<ov::op::v0::Convert>(n);
84+
auto is_supported_binary_eltwise_op = [](const std::shared_ptr<const Node>& n) -> bool {
85+
return ov::is_type_any_of<ov::op::v1::Add,
86+
ov::op::v1::Divide,
87+
ov::op::v1::Equal,
88+
ov::op::v1::FloorMod,
89+
ov::op::v1::Greater,
90+
ov::op::v1::GreaterEqual,
91+
ov::op::v1::Less,
92+
ov::op::v1::LessEqual,
93+
ov::op::v1::LogicalAnd,
94+
ov::op::v1::LogicalOr,
95+
ov::op::v1::LogicalXor,
96+
ov::op::v1::Maximum,
97+
ov::op::v1::Minimum,
98+
ov::op::v1::Mod,
99+
ov::op::v1::Multiply,
100+
ov::op::v1::NotEqual,
101+
ov::op::v0::PRelu,
102+
ov::op::v1::Power,
103+
ov::op::v0::SquaredDifference,
104+
ov::op::v1::Subtract,
105+
ov::op::v0::Xor,
106+
ov::op::v0::Convert>(n);
107107
};
108108

109-
auto is_supported_unary_eltwise_op = [](const std::shared_ptr<const Node> &n) -> bool {
110-
return ov::is_type<ov::op::v0::Abs>(n)
111-
|| ov::is_type<ov::op::v0::Clamp>(n)
112-
|| ov::is_type<ov::op::v0::Floor>(n)
113-
|| ov::is_type<ov::op::v0::Ceiling>(n)
114-
|| ov::is_type<ov::op::v0::Elu>(n)
115-
|| ov::is_type<ov::op::v0::Erf>(n)
116-
|| ov::is_type<ov::op::v0::Exp>(n)
117-
|| ov::is_type<ov::op::v1::LogicalNot>(n)
118-
|| ov::is_type<ov::op::v4::Mish>(n)
119-
|| ov::is_type<ov::op::v0::Negative>(n)
120-
|| ov::is_type<ov::op::v0::Relu>(n)
121-
|| ov::is_type<ov::op::v5::Round>(n)
122-
|| ov::is_type<ov::op::v0::Sigmoid>(n)
123-
|| ov::is_type<ov::op::v0::Sqrt>(n)
124-
|| ov::is_type<ov::op::v0::Tanh>(n)
125-
|| ov::is_type<ov::op::v0::Gelu>(n)
126-
|| ov::is_type<ov::op::v7::Gelu>(n)
127-
|| ov::is_type<ov::op::v4::Swish>(n)
128-
|| ov::is_type<ov::op::v4::HSwish>(n);
109+
auto is_supported_unary_eltwise_op = [](const std::shared_ptr<const Node>& n) -> bool {
110+
return ov::is_type_any_of<ov::op::v0::Abs,
111+
ov::op::v0::Clamp,
112+
ov::op::v0::Floor,
113+
ov::op::v0::Ceiling,
114+
ov::op::v0::Elu,
115+
ov::op::v0::Erf,
116+
ov::op::v0::Exp,
117+
ov::op::v1::LogicalNot,
118+
ov::op::v4::Mish,
119+
ov::op::v0::Negative,
120+
ov::op::v0::Relu,
121+
ov::op::v5::Round,
122+
ov::op::v0::Sigmoid,
123+
ov::op::v0::Sqrt,
124+
ov::op::v0::Tanh,
125+
ov::op::v0::Gelu,
126+
ov::op::v7::Gelu,
127+
ov::op::v4::Swish,
128+
ov::op::v4::HSwish>(n);
129129
};
130130

131131
auto is_supported_softmax = [](const std::shared_ptr<const Node> &n) -> bool {
@@ -156,7 +156,7 @@ auto is_supported_op(const std::shared_ptr<const Node> &n) -> bool {
156156
};
157157

158158
auto is_supported_reduce_op = [](const std::shared_ptr<const Node> &n) -> bool {
159-
if (ov::is_type<const ov::op::v1::ReduceMax>(n) || ov::is_type<const ov::op::v1::ReduceSum>(n)) {
159+
if (ov::is_type_any_of<const ov::op::v1::ReduceMax, const ov::op::v1::ReduceSum>(n)) {
160160
const auto& reduce_base = ov::as_type_ptr<const ov::op::util::ArithmeticReductionKeepDims>(n);
161161
const auto& axis_constant = ov::as_type_ptr<const ov::op::v0::Constant>(n->get_input_node_shared_ptr(1));
162162
const auto rank = n->get_input_partial_shape(0).rank();
@@ -229,8 +229,8 @@ TokenizeSnippets::TokenizeSnippets(const SnippetsTokenization::Config& config) {
229229
// This is a temporary solution. Either modify SnippetsMarkSkipped
230230
// or align this with the custom MHA tokenization pass.
231231
return (GetSnippetsNodeType(n) != SnippetsNodeType::SkippedByPlugin ||
232-
ov::is_type<ov::op::v0::MatMul>(n) || ov::is_type<ov::op::v1::Transpose>(n))
233-
&& AppropriateForSubgraph(n);
232+
ov::is_type_any_of<ov::op::v0::MatMul, ov::op::v1::Transpose>(n)) &&
233+
AppropriateForSubgraph(n);
234234
});
235235
ov::graph_rewrite_callback callback = [=](ov::pass::pattern::Matcher &m) -> bool {
236236
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::CreateSubgraph_callback")

0 commit comments

Comments
 (0)