@@ -172,12 +172,99 @@ void LayerTestsCommon::Compare(const std::vector<std::pair<ngraph::element::Type
172
172
}
173
173
}
174
174
175
+ inline void callCompareBool (const std::pair<ngraph::element::Type, std::vector<std::uint8_t >> &expected,
176
+ const uint8_t * actualBuffer, size_t size, float threshold, float abs_threshold) {
177
+ auto expectedBuffer = expected.second .data ();
178
+ #define CASE (X ) \
179
+ case X: { \
180
+ auto typedExpectedBuffer = reinterpret_cast <const ov::element_type_traits<X>::value_type*>(expectedBuffer); \
181
+ for (size_t i = 0 ; i < size; ++i) { \
182
+ ASSERT_EQ (static_cast <bool >(actualBuffer[i]), static_cast <bool >(typedExpectedBuffer[i])) << \
183
+ " Comparison of bool values failed at index: " << i << " expected value: " << \
184
+ (static_cast <bool >(expectedBuffer[i]) ? " True" : " False" ) << \
185
+ " and actual value: " << \
186
+ (static_cast <bool >(actualBuffer[i]) ? " True" : " False" ); \
187
+ } \
188
+ break ; \
189
+ }
190
+
191
+ switch (expected.first ) {
192
+ CASE (ngraph::element::Type_t::boolean)
193
+ CASE (ngraph::element::Type_t::u8)
194
+ CASE (ngraph::element::Type_t::i8)
195
+ CASE (ngraph::element::Type_t::u16)
196
+ CASE (ngraph::element::Type_t::i16)
197
+ CASE (ngraph::element::Type_t::u32)
198
+ CASE (ngraph::element::Type_t::i32)
199
+ CASE (ngraph::element::Type_t::u64)
200
+ CASE (ngraph::element::Type_t::i64)
201
+ CASE (ngraph::element::Type_t::bf16)
202
+ CASE (ngraph::element::Type_t::f16)
203
+ CASE (ngraph::element::Type_t::f32)
204
+ CASE (ngraph::element::Type_t::f64)
205
+ case ngraph::element::Type_t::i4: {
206
+ auto expectedOut = ngraph::helpers::convertOutputPrecision (
207
+ expected.second ,
208
+ expected.first ,
209
+ ngraph::element::Type_t::i8,
210
+ size);
211
+ for (size_t i = 0 ; i < size; ++i) {
212
+ ASSERT_EQ (static_cast <bool >(actualBuffer[i]), static_cast <bool >(expectedOut[i])) <<
213
+ " Comparison of bool values failed at index: " << i << " expected value: " <<
214
+ (static_cast <bool >(expectedOut[i]) ? " True" : " False" ) <<
215
+ " and actual value: " <<
216
+ (static_cast <bool >(actualBuffer[i]) ? " True" : " False" );
217
+ }
218
+ break ;
219
+ }
220
+ case ngraph::element::Type_t::u4: {
221
+ auto expectedOut = ngraph::helpers::convertOutputPrecision (
222
+ expected.second ,
223
+ expected.first ,
224
+ ngraph::element::Type_t::u8,
225
+ size);
226
+ for (size_t i = 0 ; i < size; ++i) {
227
+ ASSERT_EQ (static_cast <bool >(actualBuffer[i]), static_cast <bool >(expectedOut[i])) <<
228
+ " Comparison of bool values failed at index: " << i << " expected value: " <<
229
+ (static_cast <bool >(expectedOut[i]) ? " True" : " False" ) <<
230
+ " and actual value: " <<
231
+ (static_cast <bool >(actualBuffer[i]) ? " True" : " False" );
232
+ }
233
+ break ;
234
+ }
235
+ case ngraph::element::Type_t::dynamic:
236
+ case ngraph::element::Type_t::undefined: {
237
+ auto typedExpectedBuffer = reinterpret_cast <const uint8_t *>(expectedBuffer);
238
+ for (size_t i = 0 ; i < size; ++i) {
239
+ ASSERT_EQ (static_cast <bool >(actualBuffer[i]), static_cast <bool >(typedExpectedBuffer[i])) <<
240
+ " Comparison of bool values failed at index: " << i << " expected value: " <<
241
+ (static_cast <bool >(typedExpectedBuffer[i]) ? " True" : " False" ) <<
242
+ " and actual value: " <<
243
+ (static_cast <bool >(actualBuffer[i]) ? " True" : " False" );
244
+ }
245
+ break ;
246
+ }
247
+ default :
248
+ FAIL () << " Comparator for " << expected.first << " precision isn't supported" ;
249
+ }
250
+ #undef CASE
251
+ return ;
252
+ }
253
+
175
254
template <typename T_IE>
176
255
inline void callCompare (const std::pair<ngraph::element::Type, std::vector<std::uint8_t >> &expected,
177
256
const T_IE* actualBuffer, size_t size, float threshold, float abs_threshold) {
178
257
auto expectedBuffer = expected.second .data ();
179
258
switch (expected.first ) {
180
259
case ngraph::element::Type_t::boolean:
260
+ for (size_t i = 0 ; i < size; ++i) {
261
+ ASSERT_EQ (static_cast <bool >(actualBuffer[i]), static_cast <bool >(expectedBuffer[i])) <<
262
+ " Comparison of bool values failed at index: " << i << " expected value: " <<
263
+ (static_cast <bool >(expectedBuffer[i]) ? " True" : " False" ) <<
264
+ " and actual value: " <<
265
+ (static_cast <bool >(actualBuffer[i]) ? " True" : " False" );
266
+ }
267
+ break ;
181
268
case ngraph::element::Type_t::u8:
182
269
LayerTestsCommon::Compare<T_IE, uint8_t >(reinterpret_cast <const uint8_t *>(expectedBuffer),
183
270
actualBuffer, size, threshold, abs_threshold);
@@ -277,6 +364,8 @@ void LayerTestsCommon::Compare(const std::pair<ngraph::element::Type, std::vecto
277
364
const auto &size = actual->size ();
278
365
switch (precision) {
279
366
case InferenceEngine::Precision::BOOL:
367
+ callCompareBool (expected, reinterpret_cast <const uint8_t *>(actualBuffer), size, threshold, abs_threshold);
368
+ break ;
280
369
case InferenceEngine::Precision::U8:
281
370
callCompare<uint8_t >(expected, reinterpret_cast <const uint8_t *>(actualBuffer), size, threshold, abs_threshold);
282
371
break ;
0 commit comments