@@ -50,7 +50,7 @@ auto is_supported_op(const std::shared_ptr<const Node> &n) -> bool {
50
50
if (transpose) {
51
51
const auto parent = transpose->get_input_node_shared_ptr (0 );
52
52
const auto child = transpose->get_output_target_inputs (0 ).begin ()->get_node ()->shared_from_this ();
53
- auto is_brgemm_case = ov::is_type <opset1::MatMul>(parent) || ov::is_type< opset1::MatMul>(child);
53
+ auto is_brgemm_case = ov::is_type<opset1::MatMul>(parent) || ov::is_type<opset1::MatMul>(child);
54
54
auto decomposition_case = true ;
55
55
// Check for Transpose parent is MatMul inside Subgraph
56
56
if (const auto subgraph = ov::as_type_ptr<const op::Subgraph>(parent)) {
@@ -81,51 +81,51 @@ auto is_supported_op(const std::shared_ptr<const Node> &n) -> bool {
81
81
return ov::is_type<ov::op::v1::Select>(n);
82
82
};
83
83
84
- auto is_supported_binary_eltwise_op = [](const std::shared_ptr<const Node> & n) -> bool {
85
- return ov::is_type <ov::op::v1::Add>(n)
86
- || ov::is_type<ov:: op::v1::Divide>(n)
87
- || ov::is_type<ov:: op::v1::Equal>(n)
88
- || ov::is_type<ov:: op::v1::FloorMod>(n)
89
- || ov::is_type<ov:: op::v1::Greater>(n)
90
- || ov::is_type<ov:: op::v1::GreaterEqual>(n)
91
- || ov::is_type<ov:: op::v1::Less>(n)
92
- || ov::is_type<ov:: op::v1::LessEqual>(n)
93
- || ov::is_type<ov:: op::v1::LogicalAnd>(n)
94
- || ov::is_type<ov:: op::v1::LogicalOr>(n)
95
- || ov::is_type<ov:: op::v1::LogicalXor>(n)
96
- || ov::is_type<ov:: op::v1::Maximum>(n)
97
- || ov::is_type<ov:: op::v1::Minimum>(n)
98
- || ov::is_type<ov:: op::v1::Mod>(n)
99
- || ov::is_type<ov:: op::v1::Multiply>(n)
100
- || ov::is_type<ov:: op::v1::NotEqual>(n)
101
- || ov::is_type<ov:: op::v0::PRelu>(n)
102
- || ov::is_type<ov:: op::v1::Power>(n)
103
- || ov::is_type<ov:: op::v0::SquaredDifference>(n)
104
- || ov::is_type<ov:: op::v1::Subtract>(n)
105
- || ov::is_type<ov:: op::v0::Xor>(n)
106
- || ov::is_type< ov::op::v0::Convert>(n);
84
+ auto is_supported_binary_eltwise_op = [](const std::shared_ptr<const Node>& n) -> bool {
85
+ return ov::is_type_any_of <ov::op::v1::Add,
86
+ ov::op::v1::Divide,
87
+ ov::op::v1::Equal,
88
+ ov::op::v1::FloorMod,
89
+ ov::op::v1::Greater,
90
+ ov::op::v1::GreaterEqual,
91
+ ov::op::v1::Less,
92
+ ov::op::v1::LessEqual,
93
+ ov::op::v1::LogicalAnd,
94
+ ov::op::v1::LogicalOr,
95
+ ov::op::v1::LogicalXor,
96
+ ov::op::v1::Maximum,
97
+ ov::op::v1::Minimum,
98
+ ov::op::v1::Mod,
99
+ ov::op::v1::Multiply,
100
+ ov::op::v1::NotEqual,
101
+ ov::op::v0::PRelu,
102
+ ov::op::v1::Power,
103
+ ov::op::v0::SquaredDifference,
104
+ ov::op::v1::Subtract,
105
+ ov::op::v0::Xor,
106
+ ov::op::v0::Convert>(n);
107
107
};
108
108
109
- auto is_supported_unary_eltwise_op = [](const std::shared_ptr<const Node> & n) -> bool {
110
- return ov::is_type <ov::op::v0::Abs>(n)
111
- || ov::is_type<ov:: op::v0::Clamp>(n)
112
- || ov::is_type<ov:: op::v0::Floor>(n)
113
- || ov::is_type<ov:: op::v0::Ceiling>(n)
114
- || ov::is_type<ov:: op::v0::Elu>(n)
115
- || ov::is_type<ov:: op::v0::Erf>(n)
116
- || ov::is_type<ov:: op::v0::Exp>(n)
117
- || ov::is_type<ov:: op::v1::LogicalNot>(n)
118
- || ov::is_type<ov:: op::v4::Mish>(n)
119
- || ov::is_type<ov:: op::v0::Negative>(n)
120
- || ov::is_type<ov:: op::v0::Relu>(n)
121
- || ov::is_type<ov:: op::v5::Round>(n)
122
- || ov::is_type<ov:: op::v0::Sigmoid>(n)
123
- || ov::is_type<ov:: op::v0::Sqrt>(n)
124
- || ov::is_type<ov:: op::v0::Tanh>(n)
125
- || ov::is_type<ov:: op::v0::Gelu>(n)
126
- || ov::is_type<ov:: op::v7::Gelu>(n)
127
- || ov::is_type<ov:: op::v4::Swish>(n)
128
- || ov::is_type< ov::op::v4::HSwish>(n);
109
+ auto is_supported_unary_eltwise_op = [](const std::shared_ptr<const Node>& n) -> bool {
110
+ return ov::is_type_any_of <ov::op::v0::Abs,
111
+ ov::op::v0::Clamp,
112
+ ov::op::v0::Floor,
113
+ ov::op::v0::Ceiling,
114
+ ov::op::v0::Elu,
115
+ ov::op::v0::Erf,
116
+ ov::op::v0::Exp,
117
+ ov::op::v1::LogicalNot,
118
+ ov::op::v4::Mish,
119
+ ov::op::v0::Negative,
120
+ ov::op::v0::Relu,
121
+ ov::op::v5::Round,
122
+ ov::op::v0::Sigmoid,
123
+ ov::op::v0::Sqrt,
124
+ ov::op::v0::Tanh,
125
+ ov::op::v0::Gelu,
126
+ ov::op::v7::Gelu,
127
+ ov::op::v4::Swish,
128
+ ov::op::v4::HSwish>(n);
129
129
};
130
130
131
131
auto is_supported_softmax = [](const std::shared_ptr<const Node> &n) -> bool {
@@ -156,7 +156,7 @@ auto is_supported_op(const std::shared_ptr<const Node> &n) -> bool {
156
156
};
157
157
158
158
auto is_supported_reduce_op = [](const std::shared_ptr<const Node> &n) -> bool {
159
- if (ov::is_type <const ov::op::v1::ReduceMax>(n) || ov::is_type< const ov::op::v1::ReduceSum>(n)) {
159
+ if (ov::is_type_any_of <const ov::op::v1::ReduceMax, const ov::op::v1::ReduceSum>(n)) {
160
160
const auto & reduce_base = ov::as_type_ptr<const ov::op::util::ArithmeticReductionKeepDims>(n);
161
161
const auto & axis_constant = ov::as_type_ptr<const ov::op::v0::Constant>(n->get_input_node_shared_ptr (1 ));
162
162
const auto rank = n->get_input_partial_shape (0 ).rank ();
@@ -229,8 +229,8 @@ TokenizeSnippets::TokenizeSnippets(const SnippetsTokenization::Config& config) {
229
229
// This is a temporary solution. Either modify SnippetsMarkSkipped
230
230
// or align this with the custom MHA tokenization pass.
231
231
return (GetSnippetsNodeType (n) != SnippetsNodeType::SkippedByPlugin ||
232
- ov::is_type <ov::op::v0::MatMul>(n) || ov::is_type<ov:: op::v1::Transpose>(n))
233
- && AppropriateForSubgraph (n);
232
+ ov::is_type_any_of <ov::op::v0::MatMul, ov::op::v1::Transpose>(n)) &&
233
+ AppropriateForSubgraph (n);
234
234
});
235
235
ov::graph_rewrite_callback callback = [=](ov::pass::pattern::Matcher &m) -> bool {
236
236
OV_ITT_SCOPED_TASK (ov::pass::itt::domains::SnippetsTransform, " Snippets::CreateSubgraph_callback" )
0 commit comments