Skip to content

Commit 4f95a72

Browse files
revert some changes
1 parent bc911a5 commit 4f95a72

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,13 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
142142
const auto& group_norm_gamma = pattern_map.at(group_norm_gamma_m);
143143
if (group_norm_gamma.get_element_type() != T)
144144
return false;
145-
if (group_norm_gamma.get_partial_shape().rank().get_length() != num_channels)
145+
if (ov::shape_size(group_norm_gamma.get_shape()) != static_cast<size_t>(num_channels))
146146
return false;
147147

148148
const auto& group_norm_beta = pattern_map.at(group_norm_beta_m);
149149
if (group_norm_beta.get_element_type() != T)
150150
return false;
151-
if (group_norm_beta.get_partial_shape().rank().get_length() != num_channels)
151+
if (ov::shape_size(group_norm_beta.get_shape()) != static_cast<size_t>(num_channels))
152152
return false;
153153

154154
ov::NodeVector nodes;
@@ -180,7 +180,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
180180
const auto& instance_norm_beta = pattern_map.at(instance_norm_beta_m);
181181
if (instance_norm_beta.get_element_type() != T)
182182
return false;
183-
if (instance_norm_beta.get_partial_shape().rank().get_length() != num_groups)
183+
if (ov::shape_size(instance_norm_beta.get_shape()) != static_cast<size_t>(num_groups))
184184
return false;
185185

186186
// ensure that instance_norm_beta will have shape compatible
@@ -219,7 +219,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
219219
const auto& instance_norm_gamma = pattern_map.at(instance_norm_gamma_m);
220220
if (instance_norm_gamma.get_element_type() != T)
221221
return false;
222-
if (instance_norm_gamma.get_partial_shape().rank().get_length() != num_groups)
222+
if (ov::shape_size(instance_norm_gamma.get_shape()) != static_cast<size_t>(num_groups))
223223
return false;
224224

225225
// ensure that instance_norm_gamma will have shape compatible

0 commit comments

Comments
 (0)