Skip to content

Commit bc911a5

Browse files
wip
1 parent 871ab4a commit bc911a5

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
7272

7373
const auto& pre_mvn_reshape_out_ps = pattern_map.at(pre_mvn_reshape_m).get_partial_shape();
7474

75-
const size_t num_channels = static_cast<size_t>(input_ps[1].get_max_length());
76-
const size_t num_groups = static_cast<size_t>(pre_mvn_reshape_out_ps[1].get_max_length());
75+
const auto num_channels = input_ps[1].get_max_length();
76+
const auto num_groups = pre_mvn_reshape_out_ps[1].get_max_length();
7777

7878
// we expect to reshape input in a way that would merge all spatial dimensions
7979
// but leave batch and channel dimensions untouched
@@ -142,13 +142,13 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
142142
const auto& group_norm_gamma = pattern_map.at(group_norm_gamma_m);
143143
if (group_norm_gamma.get_element_type() != T)
144144
return false;
145-
if (ov::shape_size(group_norm_gamma.get_shape()) != num_channels)
145+
if (group_norm_gamma.get_partial_shape().rank().get_length() != num_channels)
146146
return false;
147147

148148
const auto& group_norm_beta = pattern_map.at(group_norm_beta_m);
149149
if (group_norm_beta.get_element_type() != T)
150150
return false;
151-
if (ov::shape_size(group_norm_beta.get_shape()) != num_channels)
151+
if (group_norm_beta.get_partial_shape().rank().get_length() != num_channels)
152152
return false;
153153

154154
ov::NodeVector nodes;
@@ -157,7 +157,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
157157
nodes.push_back(group_norm_gamma_1d_m);
158158
const auto& group_norm_gamma_1d_out_ps = group_norm_gamma_1d_m->get_output_partial_shape(0);
159159

160-
auto expected_param_shape = ov::PartialShape({static_cast<ov::Dimension>(num_channels)});
160+
auto expected_param_shape = ov::PartialShape({num_channels});
161161
if (group_norm_gamma_1d_out_ps != expected_param_shape)
162162
return false;
163163

@@ -171,16 +171,16 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
171171
auto gather_axis_const_m = op::v0::Constant::create(element::i64, Shape{1}, {0});
172172
nodes.push_back(gather_axis_const_m);
173173
auto gather_indices_vals = std::vector<int64_t>();
174-
for (auto i = 0ull; i < num_groups; i++)
174+
for (auto i = 0; i < num_groups; i++)
175175
gather_indices_vals.insert(gather_indices_vals.end(), num_channels / num_groups, i);
176-
auto gather_indices_const_m = op::v0::Constant::create(element::i64, Shape{num_channels}, gather_indices_vals);
176+
auto gather_indices_const_m = op::v0::Constant::create(element::i64, Shape{static_cast<size_t>(num_channels)}, gather_indices_vals);
177177
nodes.push_back(gather_indices_const_m);
178178

179179
if (pattern_map.count(instance_norm_beta_m) > 0) {
180180
const auto& instance_norm_beta = pattern_map.at(instance_norm_beta_m);
181181
if (instance_norm_beta.get_element_type() != T)
182182
return false;
183-
if (ov::shape_size(instance_norm_beta.get_shape()) != num_groups)
183+
if (instance_norm_beta.get_partial_shape().rank().get_length() != num_groups)
184184
return false;
185185

186186
// ensure that instance_norm_beta will have shape compatible
@@ -219,7 +219,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
219219
const auto& instance_norm_gamma = pattern_map.at(instance_norm_gamma_m);
220220
if (instance_norm_gamma.get_element_type() != T)
221221
return false;
222-
if (ov::shape_size(instance_norm_gamma.get_shape()) != num_groups)
222+
if (instance_norm_gamma.get_partial_shape().rank().get_length() != num_groups)
223223
return false;
224224

225225
// ensure that instance_norm_gamma will have shape compatible

0 commit comments

Comments
 (0)