@@ -72,8 +72,8 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
72
72
73
73
const auto & pre_mvn_reshape_out_ps = pattern_map.at (pre_mvn_reshape_m).get_partial_shape ();
74
74
75
- const size_t num_channels = static_cast < size_t >( input_ps[1 ].get_max_length () );
76
- const size_t num_groups = static_cast < size_t >( pre_mvn_reshape_out_ps[1 ].get_max_length () );
75
+ const auto num_channels = input_ps[1 ].get_max_length ();
76
+ const auto num_groups = pre_mvn_reshape_out_ps[1 ].get_max_length ();
77
77
78
78
// we expect to reshape input in a way that would merge all spatial dimensions
79
79
// but leave batch and channel dimensions untouched
@@ -142,13 +142,13 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
142
142
const auto & group_norm_gamma = pattern_map.at (group_norm_gamma_m);
143
143
if (group_norm_gamma.get_element_type () != T)
144
144
return false ;
145
- if (ov::shape_size ( group_norm_gamma.get_shape () ) != num_channels)
145
+ if (group_norm_gamma.get_partial_shape (). rank (). get_length ( ) != num_channels)
146
146
return false ;
147
147
148
148
const auto & group_norm_beta = pattern_map.at (group_norm_beta_m);
149
149
if (group_norm_beta.get_element_type () != T)
150
150
return false ;
151
- if (ov::shape_size ( group_norm_beta.get_shape () ) != num_channels)
151
+ if (group_norm_beta.get_partial_shape (). rank (). get_length ( ) != num_channels)
152
152
return false ;
153
153
154
154
ov::NodeVector nodes;
@@ -157,7 +157,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
157
157
nodes.push_back (group_norm_gamma_1d_m);
158
158
const auto & group_norm_gamma_1d_out_ps = group_norm_gamma_1d_m->get_output_partial_shape (0 );
159
159
160
- auto expected_param_shape = ov::PartialShape ({static_cast <ov::Dimension>( num_channels) });
160
+ auto expected_param_shape = ov::PartialShape ({num_channels});
161
161
if (group_norm_gamma_1d_out_ps != expected_param_shape)
162
162
return false ;
163
163
@@ -171,16 +171,16 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
171
171
auto gather_axis_const_m = op::v0::Constant::create (element::i64, Shape{1 }, {0 });
172
172
nodes.push_back (gather_axis_const_m);
173
173
auto gather_indices_vals = std::vector<int64_t >();
174
- for (auto i = 0ull ; i < num_groups; i++)
174
+ for (auto i = 0 ; i < num_groups; i++)
175
175
gather_indices_vals.insert (gather_indices_vals.end (), num_channels / num_groups, i);
176
- auto gather_indices_const_m = op::v0::Constant::create (element::i64, Shape{num_channels}, gather_indices_vals);
176
+ auto gather_indices_const_m = op::v0::Constant::create (element::i64, Shape{static_cast < size_t >( num_channels) }, gather_indices_vals);
177
177
nodes.push_back (gather_indices_const_m);
178
178
179
179
if (pattern_map.count (instance_norm_beta_m) > 0 ) {
180
180
const auto & instance_norm_beta = pattern_map.at (instance_norm_beta_m);
181
181
if (instance_norm_beta.get_element_type () != T)
182
182
return false ;
183
- if (ov::shape_size ( instance_norm_beta.get_shape () ) != num_groups)
183
+ if (instance_norm_beta.get_partial_shape (). rank (). get_length ( ) != num_groups)
184
184
return false ;
185
185
186
186
// ensure that instance_norm_beta will have shape compatible
@@ -219,7 +219,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
219
219
const auto & instance_norm_gamma = pattern_map.at (instance_norm_gamma_m);
220
220
if (instance_norm_gamma.get_element_type () != T)
221
221
return false ;
222
- if (ov::shape_size ( instance_norm_gamma.get_shape () ) != num_groups)
222
+ if (instance_norm_gamma.get_partial_shape (). rank (). get_length ( ) != num_groups)
223
223
return false ;
224
224
225
225
// ensure that instance_norm_gamma will have shape compatible
0 commit comments