Skip to content

Commit ad8dedd

Browse files
[TRANSFORMATIONS] Fix potential integer underflow (openvinotoolkit#29476)
Fix potential integer underflow ### Tickets: - [CVS-164224](https://jira.devtools.intel.com/browse/CVS-164224) Signed-off-by: Andrii Staikov <andrii.staikov@intel.com>
1 parent 7af7a1e commit ad8dedd

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp

+10-9
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
7272

7373
const auto& pre_mvn_reshape_out_ps = pattern_map.at(pre_mvn_reshape_m).get_partial_shape();
7474

75-
const size_t num_channels = static_cast<size_t>(input_ps[1].get_max_length());
76-
const size_t num_groups = static_cast<size_t>(pre_mvn_reshape_out_ps[1].get_max_length());
75+
const auto num_channels = input_ps[1].get_max_length();
76+
const auto num_groups = pre_mvn_reshape_out_ps[1].get_max_length();
7777

7878
// we expect to reshape input in a way that would merge all spatial dimensions
7979
// but leave batch and channel dimensions untouched
@@ -142,13 +142,13 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
142142
const auto& group_norm_gamma = pattern_map.at(group_norm_gamma_m);
143143
if (group_norm_gamma.get_element_type() != T)
144144
return false;
145-
if (ov::shape_size(group_norm_gamma.get_shape()) != num_channels)
145+
if (ov::shape_size(group_norm_gamma.get_shape()) != static_cast<size_t>(num_channels))
146146
return false;
147147

148148
const auto& group_norm_beta = pattern_map.at(group_norm_beta_m);
149149
if (group_norm_beta.get_element_type() != T)
150150
return false;
151-
if (ov::shape_size(group_norm_beta.get_shape()) != num_channels)
151+
if (ov::shape_size(group_norm_beta.get_shape()) != static_cast<size_t>(num_channels))
152152
return false;
153153

154154
ov::NodeVector nodes;
@@ -157,7 +157,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
157157
nodes.push_back(group_norm_gamma_1d_m);
158158
const auto& group_norm_gamma_1d_out_ps = group_norm_gamma_1d_m->get_output_partial_shape(0);
159159

160-
auto expected_param_shape = ov::PartialShape({static_cast<ov::Dimension>(num_channels)});
160+
auto expected_param_shape = ov::PartialShape({num_channels});
161161
if (group_norm_gamma_1d_out_ps != expected_param_shape)
162162
return false;
163163

@@ -171,16 +171,17 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
171171
auto gather_axis_const_m = op::v0::Constant::create(element::i64, Shape{1}, {0});
172172
nodes.push_back(gather_axis_const_m);
173173
auto gather_indices_vals = std::vector<int64_t>();
174-
for (auto i = 0ull; i < num_groups; i++)
174+
for (auto i = 0; i < num_groups; i++)
175175
gather_indices_vals.insert(gather_indices_vals.end(), num_channels / num_groups, i);
176-
auto gather_indices_const_m = op::v0::Constant::create(element::i64, Shape{num_channels}, gather_indices_vals);
176+
auto gather_indices_const_m =
177+
op::v0::Constant::create(element::i64, Shape{static_cast<size_t>(num_channels)}, gather_indices_vals);
177178
nodes.push_back(gather_indices_const_m);
178179

179180
if (pattern_map.count(instance_norm_beta_m) > 0) {
180181
const auto& instance_norm_beta = pattern_map.at(instance_norm_beta_m);
181182
if (instance_norm_beta.get_element_type() != T)
182183
return false;
183-
if (ov::shape_size(instance_norm_beta.get_shape()) != num_groups)
184+
if (ov::shape_size(instance_norm_beta.get_shape()) != static_cast<size_t>(num_groups))
184185
return false;
185186

186187
// ensure that instance_norm_beta will have shape compatible
@@ -219,7 +220,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
219220
const auto& instance_norm_gamma = pattern_map.at(instance_norm_gamma_m);
220221
if (instance_norm_gamma.get_element_type() != T)
221222
return false;
222-
if (ov::shape_size(instance_norm_gamma.get_shape()) != num_groups)
223+
if (ov::shape_size(instance_norm_gamma.get_shape()) != static_cast<size_t>(num_groups))
223224
return false;
224225

225226
// ensure that instance_norm_gamma will have shape compatible

0 commit comments

Comments
 (0)