@@ -142,13 +142,13 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
142
142
const auto & group_norm_gamma = pattern_map.at (group_norm_gamma_m);
143
143
if (group_norm_gamma.get_element_type () != T)
144
144
return false ;
145
- if (group_norm_gamma.get_partial_shape (). rank (). get_length () != num_channels)
145
+ if (ov::shape_size ( group_norm_gamma.get_shape ()) != static_cast < size_t >( num_channels) )
146
146
return false ;
147
147
148
148
const auto & group_norm_beta = pattern_map.at (group_norm_beta_m);
149
149
if (group_norm_beta.get_element_type () != T)
150
150
return false ;
151
- if (group_norm_beta.get_partial_shape (). rank (). get_length () != num_channels)
151
+ if (ov::shape_size ( group_norm_beta.get_shape ()) != static_cast < size_t >( num_channels) )
152
152
return false ;
153
153
154
154
ov::NodeVector nodes;
@@ -180,7 +180,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
180
180
const auto & instance_norm_beta = pattern_map.at (instance_norm_beta_m);
181
181
if (instance_norm_beta.get_element_type () != T)
182
182
return false ;
183
- if (instance_norm_beta.get_partial_shape (). rank (). get_length () != num_groups)
183
+ if (ov::shape_size ( instance_norm_beta.get_shape ()) != static_cast < size_t >( num_groups) )
184
184
return false ;
185
185
186
186
// ensure that instance_norm_beta will have shape compatible
@@ -219,7 +219,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
219
219
const auto & instance_norm_gamma = pattern_map.at (instance_norm_gamma_m);
220
220
if (instance_norm_gamma.get_element_type () != T)
221
221
return false ;
222
- if (instance_norm_gamma.get_partial_shape (). rank (). get_length () != num_groups)
222
+ if (ov::shape_size ( instance_norm_gamma.get_shape ()) != static_cast < size_t >( num_groups) )
223
223
return false ;
224
224
225
225
// ensure that instance_norm_gamma will have shape compatible
0 commit comments