@@ -40,15 +40,17 @@ class attr_quantization_test_t : public ::testing::Test {
40
40
return attr;
41
41
}
42
42
43
- static primitive_attr gen_attr_with_scales (int arg, int mask = 0 ) {
43
+ static primitive_attr gen_attr_with_scales (int arg, int mask = 0 ,
44
+ data_type dt = data_type::f32, const memory::dims &groups = {}) {
44
45
primitive_attr attr;
45
- attr.set_scales_mask (arg, mask);
46
+ attr.set_scales (arg, mask, groups, dt );
46
47
return attr;
47
48
}
48
49
49
- static primitive_attr gen_attr_with_zp (int arg, int mask = 0 ) {
50
+ static primitive_attr gen_attr_with_zp (int arg, int mask = 0 ,
51
+ data_type dt = data_type::s32, const memory::dims &groups = {}) {
50
52
primitive_attr attr;
51
- attr.set_zero_points_mask (arg, mask);
53
+ attr.set_zero_points (arg, mask, groups, dt );
52
54
return attr;
53
55
}
54
56
@@ -434,17 +436,55 @@ CPU_TEST_F(attr_quantization_test_t, TestMatmul) {
434
436
// zpoints: common mask
435
437
CHECK_OK (matmul::primitive_desc (
436
438
eng, a_md, b_md, c_md, gen_attr_with_zp (arg)));
439
+ // zpoints: per_oc mask
440
+ CHECK_OK (matmul::primitive_desc (
441
+ eng, a_md, b_md, c_md, gen_attr_with_zp (arg, 1 << 1 )));
442
+ // zpoints: per_ocic mask
443
+ if (arg == DNNL_ARG_WEIGHTS) {
444
+ CHECK_OK (matmul::primitive_desc (eng, a_md, b_md, c_md,
445
+ gen_attr_with_zp (arg, (1 << 1 ) + (1 << 0 ))));
446
+ CHECK_OK (matmul::primitive_desc (eng, a_md, b_md, c_md,
447
+ gen_attr_with_zp (
448
+ arg, (1 << 1 ) + (1 << 0 ), b_dt, {3 , 1 })));
449
+ } else {
450
+ CHECK_UNIMPL (matmul::primitive_desc (eng, a_md, b_md, c_md,
451
+ gen_attr_with_zp (arg, (1 << 1 ) + (1 << 0 ))));
452
+ }
437
453
}
438
454
// scales: common mask
439
455
CHECK_OK (matmul::primitive_desc (
440
456
eng, a_md, b_md, c_md, gen_attr_with_scales (arg)));
441
457
// scales: per_oc mask
442
- if (arg == DNNL_ARG_WEIGHTS)
458
+ if (arg == DNNL_ARG_WEIGHTS) {
443
459
CHECK_OK (matmul::primitive_desc (eng, a_md, b_md, c_md,
444
460
gen_attr_with_scales (arg, 1 << 1 )));
445
- else
461
+ CHECK_OK (matmul::primitive_desc (eng, a_md, b_md, c_md,
462
+ gen_attr_with_scales (arg, (1 << 1 ) + (1 << 0 ))));
463
+ if (b_dt == data_type::s8) {
464
+ CHECK_OK (matmul::primitive_desc (eng, a_md, b_md, c_md,
465
+ gen_attr_with_scales (arg, (1 << 1 ) + (1 << 0 ),
466
+ data_type::f32, {3 , 1 })));
467
+ } else {
468
+ CHECK_UNIMPL (matmul::primitive_desc (eng, a_md, b_md, c_md,
469
+ gen_attr_with_scales (arg, (1 << 1 ) + (1 << 0 ),
470
+ data_type::f32, {3 , 1 })));
471
+ }
472
+ } else if (arg == DNNL_ARG_SRC) {
446
473
CHECK_UNIMPL (matmul::primitive_desc (eng, a_md, b_md, c_md,
447
474
gen_attr_with_scales (arg, 1 << 1 )));
475
+ if (a_dt == data_type::u8) {
476
+ CHECK_OK (matmul::primitive_desc (eng, a_md, b_md, c_md,
477
+ gen_attr_with_scales (
478
+ arg, 1 << 1 , data_type::f32, {1 , 3 })));
479
+ } else {
480
+ CHECK_UNIMPL (matmul::primitive_desc (eng, a_md, b_md, c_md,
481
+ gen_attr_with_scales (
482
+ arg, 1 << 1 , data_type::f32, {1 , 3 })));
483
+ }
484
+ } else {
485
+ CHECK_UNIMPL (matmul::primitive_desc (eng, a_md, b_md, c_md,
486
+ gen_attr_with_scales (arg, 1 << 1 )));
487
+ }
448
488
// scales: unsupported mask
449
489
CHECK_UNIMPL (matmul::primitive_desc (
450
490
eng, a_md, b_md, c_md, gen_attr_with_scales (arg, 1 << 2 )));
0 commit comments