Skip to content

Commit f148732

Browse files
committed
tests: gtests: update quantization attr test with new supported cases
1 parent 64b717e commit f148732

File tree

1 file changed

+46
-6
lines changed

1 file changed

+46
-6
lines changed

tests/gtests/test_iface_attr_quantization.cpp

+46-6
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,17 @@ class attr_quantization_test_t : public ::testing::Test {
4040
return attr;
4141
}
4242

43-
static primitive_attr gen_attr_with_scales(int arg, int mask = 0) {
43+
static primitive_attr gen_attr_with_scales(int arg, int mask = 0,
44+
data_type dt = data_type::f32, const memory::dims &groups = {}) {
4445
primitive_attr attr;
45-
attr.set_scales_mask(arg, mask);
46+
attr.set_scales(arg, mask, groups, dt);
4647
return attr;
4748
}
4849

49-
static primitive_attr gen_attr_with_zp(int arg, int mask = 0) {
50+
static primitive_attr gen_attr_with_zp(int arg, int mask = 0,
51+
data_type dt = data_type::s32, const memory::dims &groups = {}) {
5052
primitive_attr attr;
51-
attr.set_zero_points_mask(arg, mask);
53+
attr.set_zero_points(arg, mask, groups, dt);
5254
return attr;
5355
}
5456

@@ -434,17 +436,55 @@ CPU_TEST_F(attr_quantization_test_t, TestMatmul) {
434436
// zpoints: common mask
435437
CHECK_OK(matmul::primitive_desc(
436438
eng, a_md, b_md, c_md, gen_attr_with_zp(arg)));
439+
// zpoints: per_oc mask
440+
CHECK_OK(matmul::primitive_desc(
441+
eng, a_md, b_md, c_md, gen_attr_with_zp(arg, 1 << 1)));
442+
// zpoints: per_ocic mask
443+
if (arg == DNNL_ARG_WEIGHTS) {
444+
CHECK_OK(matmul::primitive_desc(eng, a_md, b_md, c_md,
445+
gen_attr_with_zp(arg, (1 << 1) + (1 << 0))));
446+
CHECK_OK(matmul::primitive_desc(eng, a_md, b_md, c_md,
447+
gen_attr_with_zp(
448+
arg, (1 << 1) + (1 << 0), b_dt, {3, 1})));
449+
} else {
450+
CHECK_UNIMPL(matmul::primitive_desc(eng, a_md, b_md, c_md,
451+
gen_attr_with_zp(arg, (1 << 1) + (1 << 0))));
452+
}
437453
}
438454
// scales: common mask
439455
CHECK_OK(matmul::primitive_desc(
440456
eng, a_md, b_md, c_md, gen_attr_with_scales(arg)));
441457
// scales: per_oc mask
442-
if (arg == DNNL_ARG_WEIGHTS)
458+
if (arg == DNNL_ARG_WEIGHTS) {
443459
CHECK_OK(matmul::primitive_desc(eng, a_md, b_md, c_md,
444460
gen_attr_with_scales(arg, 1 << 1)));
445-
else
461+
CHECK_OK(matmul::primitive_desc(eng, a_md, b_md, c_md,
462+
gen_attr_with_scales(arg, (1 << 1) + (1 << 0))));
463+
if (b_dt == data_type::s8) {
464+
CHECK_OK(matmul::primitive_desc(eng, a_md, b_md, c_md,
465+
gen_attr_with_scales(arg, (1 << 1) + (1 << 0),
466+
data_type::f32, {3, 1})));
467+
} else {
468+
CHECK_UNIMPL(matmul::primitive_desc(eng, a_md, b_md, c_md,
469+
gen_attr_with_scales(arg, (1 << 1) + (1 << 0),
470+
data_type::f32, {3, 1})));
471+
}
472+
} else if (arg == DNNL_ARG_SRC) {
446473
CHECK_UNIMPL(matmul::primitive_desc(eng, a_md, b_md, c_md,
447474
gen_attr_with_scales(arg, 1 << 1)));
475+
if (a_dt == data_type::u8) {
476+
CHECK_OK(matmul::primitive_desc(eng, a_md, b_md, c_md,
477+
gen_attr_with_scales(
478+
arg, 1 << 1, data_type::f32, {1, 3})));
479+
} else {
480+
CHECK_UNIMPL(matmul::primitive_desc(eng, a_md, b_md, c_md,
481+
gen_attr_with_scales(
482+
arg, 1 << 1, data_type::f32, {1, 3})));
483+
}
484+
} else {
485+
CHECK_UNIMPL(matmul::primitive_desc(eng, a_md, b_md, c_md,
486+
gen_attr_with_scales(arg, 1 << 1)));
487+
}
448488
//scales: unsupported mask
449489
CHECK_UNIMPL(matmul::primitive_desc(
450490
eng, a_md, b_md, c_md, gen_attr_with_scales(arg, 1 << 2)));

0 commit comments

Comments
 (0)