@@ -72,11 +72,16 @@ bool isFullyConnected(const std::shared_ptr<const ov::Node>& node) {
 bool SupportsFusingWithConvolution_SumActivation(const std::shared_ptr<const Node>& node) {
     // todo: Do all PReLUs are fused? Not sure about round and softRelu
     // EltwiseRoundHalfToEven, EltwiseRoundHalfAwayFromZero, EltwiseSoftRelu
-    return ov::is_type<ov::op::v0::Relu>(node) || ov::is_type<ov::op::v0::PRelu>(node) ||
-           ov::is_type<ov::op::v0::Elu>(node) || ov::is_type<ov::op::v0::Sigmoid>(node) ||
-           ov::is_type<ov::op::v5::HSigmoid>(node) || ov::is_type<ov::op::v0::Clamp>(node) ||
-           ov::is_type<ov::op::v4::Swish>(node) || ov::is_type<ov::op::v4::HSwish>(node) ||
-           ov::is_type<ov::op::v4::Mish>(node) || ov::is_type<ov::op::v5::Round>(node);
+    return ov::is_type_any_of<ov::op::v0::Relu,
+                              ov::op::v0::PRelu,
+                              ov::op::v0::Elu,
+                              ov::op::v0::Sigmoid,
+                              ov::op::v5::HSigmoid,
+                              ov::op::v0::Clamp,
+                              ov::op::v4::Swish,
+                              ov::op::v4::HSwish,
+                              ov::op::v4::Mish,
+                              ov::op::v5::Round>(node);
 }
 
 bool canBePerformedAsScaleShift(const std::shared_ptr<const Node>& node, const int channelAxis) {
@@ -120,8 +125,7 @@ bool canBePerformedAsScaleShift(const std::shared_ptr<const Node>& node, const i
 
     // Prelu and MulAdd are still ignored
     // isConvertablePowerStatic() is ignored
-    return (ov::is_type<ov::opset1::Add>(node) || ov::is_type<ov::opset1::Multiply>(node) ||
-            ov::is_type<ov::opset1::Subtract>(node) || ov::is_type<ov::opset1::Divide>(node)) &&
+    return ov::is_type_any_of<ov::opset1::Add, ov::opset1::Multiply, ov::opset1::Subtract, ov::opset1::Divide>(node) &&
            isBroadcastableToDataInput();
 }
 
@@ -131,15 +135,18 @@ inline bool canBeMatMulExecutedInInt8(const ov::element::Type& firstType, const
 
 bool SupportsFusingWithConvolution_Simple(const std::shared_ptr<const Node>& node,
                                           const int channelAxis = DEFAULT_AXIS) {
-    return SupportsFusingWithConvolution_SumActivation(node) || ov::is_type<ov::op::v0::Tanh>(node) ||
-           ov::is_type<ov::op::v0::Gelu>(node) || ov::is_type<ov::op::v7::Gelu>(node) ||
-           ov::is_type<ov::op::v0::Abs>(node) || ov::is_type<ov::op::v0::Sqrt>(node) ||
-           ov::is_type<ov::op::v0::FakeQuantize>(node) || canBePerformedAsScaleShift(node, channelAxis);
+    return SupportsFusingWithConvolution_SumActivation(node) ||
+           ov::is_type_any_of<ov::op::v0::Tanh,
+                              ov::op::v0::Gelu,
+                              ov::op::v7::Gelu,
+                              ov::op::v0::Abs,
+                              ov::op::v0::Sqrt,
+                              ov::op::v0::FakeQuantize>(node) ||
+           canBePerformedAsScaleShift(node, channelAxis);
 }
 // Convolution is a special case, since it supports peculiar fusings
 bool isSuitableConvolutionParent(const std::shared_ptr<const Node>& node) {
-    const bool is_suitable_node =
-        ov::is_type<ov::op::v1::Convolution>(node) || ov::is_type<ov::op::v1::GroupConvolution>(node);
+    const bool is_suitable_node = ov::is_type_any_of<ov::op::v1::Convolution, ov::op::v1::GroupConvolution>(node);
     // has a single output, connected to a single child
     const auto out = node->outputs();
     const bool has_only_child = (out.size() == 1) && (out[0].get_target_inputs().size() == 1);
@@ -168,14 +175,18 @@ int getChannelAxis(const ov::AxisSet& axes, bool keep_dims) {
     return channelAxis;
 }
 bool isSuitableMiscParent(const std::shared_ptr<const Node>& node) {
-    const bool is_suitable_node =
-        ov::is_type<ov::op::v0::MVN>(node) || ov::is_type<ov::op::v6::MVN>(node) ||
-        ov::is_type<ov::op::v0::NormalizeL2>(node) || ov::is_type<ov::op::v0::Interpolate>(node) ||
-        ov::is_type<ov::op::v4::Interpolate>(node) || ov::is_type<ov::op::v0::LSTMCell>(node) ||
-        ov::is_type<ov::op::v4::LSTMCell>(node) || ov::is_type<ov::opset1::ConvolutionBackpropData>(node) ||
-        ov::is_type<ov::op::util::ArithmeticReductionKeepDims>(node) ||
-        ov::is_type<ov::opset1::GroupConvolutionBackpropData>(node) || ov::is_type<ov::opset1::AvgPool>(node) ||
-        ov::is_type<ov::op::v14::AvgPool>(node);
+    const bool is_suitable_node = ov::is_type_any_of<ov::op::v0::MVN,
+                                                     ov::op::v6::MVN,
+                                                     ov::op::v0::NormalizeL2,
+                                                     ov::op::v0::Interpolate,
+                                                     ov::op::v4::Interpolate,
+                                                     ov::op::v0::LSTMCell,
+                                                     ov::op::v4::LSTMCell,
+                                                     ov::opset1::ConvolutionBackpropData,
+                                                     ov::op::util::ArithmeticReductionKeepDims,
+                                                     ov::opset1::GroupConvolutionBackpropData,
+                                                     ov::opset1::AvgPool,
+                                                     ov::op::v14::AvgPool>(node);
     // has a single output, connected to a single child
     const auto out = node->outputs();
     const bool has_only_child = (out.size() == 1) && (out[0].get_target_inputs().size() == 1);
@@ -307,9 +318,11 @@ bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node>& node,
 
     // MatMul specific checks from ::canFuse()
     if (one_of(updatedChainType, NodeFusingType::FusedWithMatMul, NodeFusingType::FusedWithMatMulI8)) {
-        const auto is_binary_eltwise = ov::is_type<ov::op::v1::Add>(node) || ov::is_type<ov::op::v1::Multiply>(node) ||
-                                       ov::is_type<ov::op::v1::Subtract>(node) ||
-                                       ov::is_type<ov::op::v1::Divide>(node) || ov::is_type<ov::op::v0::PRelu>(node);
+        const auto is_binary_eltwise = ov::is_type_any_of<ov::op::v1::Add,
+                                                          ov::op::v1::Multiply,
+                                                          ov::op::v1::Subtract,
+                                                          ov::op::v1::Divide,
+                                                          ov::op::v0::PRelu>(node);
         const auto rank = node->get_output_partial_shape(0).rank();
         if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) && rank.is_static() && is_binary_eltwise) {
             const auto const1 = ov::is_type<ov::op::v0::Constant>(node->get_input_node_shared_ptr(0));
@@ -490,8 +503,7 @@ bool isSuitableConvert(const std::shared_ptr<const Node>& node) {
 }
 
 auto is_skipped_op(const std::shared_ptr<ov::Node>& op) -> bool {
-    return ov::is_type<ov::op::v0::Constant>(op) || ov::is_type<ov::op::v0::Parameter>(op) ||
-           ov::is_type<ov::op::v0::Result>(op);
+    return ov::is_type_any_of<ov::op::v0::Constant, ov::op::v0::Parameter, ov::op::v0::Result>(op);
 }
 } // namespace
 
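
For context on the change above: every hunk collapses a chain of ov::is_type<T>(node) || ... checks into one variadic ov::is_type_any_of<...>(node) call. The snippet below is only a minimal sketch of how such a helper can be expressed with a C++17 fold expression over ov::is_type; the helper name is_type_any_of_sketch and the include paths are illustrative assumptions, not the actual ov::is_type_any_of definition from the OpenVINO headers.

#include <memory>

#include "openvino/core/node.hpp"
#include "openvino/core/type.hpp"  // assumed location of ov::is_type

// Illustrative sketch only: a variadic wrapper equivalent to chaining
// ov::is_type<T>(node) checks with ||.
template <class... Types>
bool is_type_any_of_sketch(const std::shared_ptr<const ov::Node>& node) {
    // C++17 unary right fold over ||; short-circuits like the original chain.
    return (ov::is_type<Types>(node) || ...);
}

// Hypothetical usage, mirroring is_skipped_op() above:
//   bool skipped = is_type_any_of_sketch<ov::op::v0::Constant,
//                                        ov::op::v0::Parameter,
//                                        ov::op::v0::Result>(op);

Because the fold expands to a || (b || ...) over the built-in ||, it short-circuits exactly like the hand-written chains, so a refactor of this shape does not change the behavior of these predicates.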