Skip to content

Commit 55b8e29

Browse files
wzt1997 authored and TaoLv committed
benchdnn: graph: support test for int4 data types with grouped quantization
1 parent 6f4f6fb commit 55b8e29

16 files changed

+705
-4
lines changed

tests/benchdnn/graph/deserialize.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ logical_tensor::data_type deserialized_lt::get_data_type() const {
8383
return logical_tensor::data_type::f8_e5m2;
8484
} else if (data_type_ == "f8_e4m3") {
8585
return logical_tensor::data_type::f8_e4m3;
86+
} else if (data_type_ == "s4") {
87+
return logical_tensor::data_type::s4;
88+
} else if (data_type_ == "u4") {
89+
return logical_tensor::data_type::u4;
8690
} else {
8791
return logical_tensor::data_type::undef;
8892
}

tests/benchdnn/graph/graph.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,11 @@ int doit(const prb_t *prb, res_t *res) {
524524
case logical_tensor::data_type::f8_e4m3:
525525
in_out_dt.emplace_back(dnnl_f8_e4m3);
526526
break;
527+
case logical_tensor::data_type::s4:
528+
in_out_dt.emplace_back(dnnl_s4);
529+
break;
530+
case logical_tensor::data_type::u4:
531+
in_out_dt.emplace_back(dnnl_u4);
527532
default: break;
528533
}
529534
}

tests/benchdnn/graph/setting_handler.cpp

+35-4
Original file line numberDiff line numberDiff line change
@@ -1676,6 +1676,12 @@ bool get_reorder_dt(const deserialized_op &base_op_ref, dnnl_data_type_t &sdt,
16761676
dnnl_data_type_t &ddt) {
16771677
sdt = convert_dt(base_op_ref.in_lts_.front().get_data_type());
16781678
ddt = convert_dt(base_op_ref.out_lts_.front().get_data_type());
1679+
1680+
const auto &op_kind = base_op_ref.kind_;
1681+
// As we always use f32 computation in the reference path, to link
1682+
// arguments correctly in the reference path, we need to always create
1683+
// dequantize ops with f32 output.
1684+
if (op_kind == "DynamicDequantize") { ddt = dnnl_f32; }
16791685
return true;
16801686
}
16811687

@@ -1704,10 +1710,16 @@ bool get_reorder_attrs(const deserialized_op &base_op_ref,
17041710
// scale
17051711
attr_t::policy_t scale_policy = attr_t::policy_t::COMMON;
17061712
int64_t axis = 1;
1713+
std::vector<dnnl_dim_t> groups;
1714+
dnnl_data_type_t scale_dt, zp_dt;
1715+
1716+
const int ndims
1717+
= static_cast<int>(base_op_ref.in_lts_.front().shape_.size());
1718+
base_op_ref.get_attr_s64(axis, "axis");
1719+
if (axis < 0) axis += ndims;
1720+
1721+
// per dimension
17071722
if (qtype == "per_channel") {
1708-
// per dimension
1709-
base_op_ref.get_attr_s64(axis, "axis");
1710-
const auto ndims = base_op_ref.in_lts_.front().shape_.size();
17111723
if (axis < 0) axis += ndims;
17121724
if (axis == 0) {
17131725
scale_policy = attr_t::PER_DIM_0;
@@ -1720,6 +1732,14 @@ bool get_reorder_attrs(const deserialized_op &base_op_ref,
17201732
} else {
17211733
assert(!"unsupported axis");
17221734
}
1735+
} else if (qtype == "per_group") {
1736+
scale_policy = attr_t::PER_TENSOR;
1737+
1738+
std::vector<int64_t> group_shape;
1739+
base_op_ref.get_attr_s64_vector(group_shape, "group_shape");
1740+
groups = {group_shape[ndims - 2], group_shape[ndims - 1]};
1741+
scale_dt = static_cast<dnnl_data_type_t>(
1742+
base_op_ref.in_lts_[1].get_data_type());
17231743
}
17241744

17251745
if (op_kind == "Dequantize" || op_kind == "Quantize") {
@@ -1734,18 +1754,29 @@ bool get_reorder_attrs(const deserialized_op &base_op_ref,
17341754
if (has_zps && !zps.empty())
17351755
zp.set(arg, attr_t::policy_t::COMMON, zps.front());
17361756
} else if (op_kind == "DynamicDequantize" || op_kind == "DynamicQuantize") {
1757+
// For reference path, it always use f32 for computation.
1758+
scale_dt = dnnl_f32;
1759+
17371760
// TODO: benchdnn needs to alloc memory based on is_def() function.
17381761
// so add tmp value for per_tensor scales && zps to make is_def()
17391762
// return false to alloc memory.
17401763
if (qtype == "per_tensor") {
17411764
arg_scales.set(arg, {scale_policy, 2});
1765+
} else if (qtype == "per_group") {
1766+
arg_scales.set(arg, {scale_policy, 1.f, scale_dt, groups});
17421767
} else {
17431768
arg_scales.set(arg, {scale_policy});
17441769
}
17451770
// zps is optional for DynamicDequantize/DynamicQuantize, default is
17461771
// symmetric quantization
17471772
if (base_op_ref.in_lts_.size() == 3) {
1748-
zp.set(arg, attr_t::policy_t::COMMON, 1);
1773+
if (qtype == "per_group") {
1774+
zp_dt = static_cast<dnnl_data_type_t>(
1775+
base_op_ref.in_lts_[2].get_data_type());
1776+
zp.set(arg, {scale_policy, 0, zp_dt, groups});
1777+
} else {
1778+
zp.set(arg, attr_t::policy_t::COMMON, 1);
1779+
}
17491780
}
17501781
}
17511782
return true;

tests/benchdnn/graph/utils.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ dnnl::graph::op::attr attrstr2kind(const std::string &attr_name) {
375375
{"axis", dnnl::graph::op::attr::axis},
376376
{"begin_norm_axis", dnnl::graph::op::attr::begin_norm_axis},
377377
{"groups", dnnl::graph::op::attr::groups},
378+
{"group_shape", dnnl::graph::op::attr::group_shape},
378379
// int64_t vector attributes. The value of these attributes can be a
379380
// vector of int64 numbers.
380381
{"axes", dnnl::graph::op::attr::axes},
@@ -1259,6 +1260,8 @@ dnnl_data_type_t convert_dt(const dnnl::graph::logical_tensor::data_type dt) {
12591260
case graph_dt::boolean: return dnnl_u8;
12601261
case graph_dt::f8_e5m2: return dnnl_f8_e5m2;
12611262
case graph_dt::f8_e4m3: return dnnl_f8_e4m3;
1263+
case graph_dt::s4: return dnnl_s4;
1264+
case graph_dt::u4: return dnnl_u4;
12621265
case graph_dt::undef:
12631266
default: return dnnl_data_type_undef;
12641267
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
{
2+
"version": "3.7.0",
3+
"engine_kind": "cpu",
4+
"fpmath_mode": "bf16",
5+
"fpmath_mode_apply_to_int": "true",
6+
"graph": [
7+
{
8+
"id": 0,
9+
"name": "aten::dequantize",
10+
"kind": "DynamicDequantize",
11+
"attrs": {
12+
"qtype": {
13+
"type": "string",
14+
"value": "per_group"
15+
},
16+
"group_shape": {
17+
"type": "s64[]",
18+
"value": [
19+
1,
20+
1,
21+
8,
22+
1
23+
]
24+
},
25+
"axis": {
26+
"type": "s64",
27+
"value": 2
28+
}
29+
},
30+
"inputs": [
31+
{
32+
"id": 0,
33+
"dtype": "s4",
34+
"shape": [
35+
1,
36+
32,
37+
128,
38+
32
39+
],
40+
"stride": [
41+
131072,
42+
4096,
43+
1,
44+
128
45+
],
46+
"layout_type": "strided",
47+
"property_type": "variable"
48+
},
49+
{
50+
"id": 1,
51+
"dtype": "bf16",
52+
"shape": [
53+
1,
54+
32,
55+
16,
56+
32
57+
],
58+
"stride": [
59+
16384,
60+
512,
61+
32,
62+
1
63+
],
64+
"layout_type": "strided",
65+
"property_type": "undef"
66+
},
67+
{
68+
"id": 2,
69+
"dtype": "s4",
70+
"shape": [
71+
1,
72+
32,
73+
16,
74+
32
75+
],
76+
"stride": [
77+
16384,
78+
512,
79+
32,
80+
1
81+
],
82+
"layout_type": "strided",
83+
"property_type": "undef"
84+
}
85+
],
86+
"outputs": [
87+
{
88+
"id": 10,
89+
"dtype": "bf16",
90+
"shape": [
91+
1,
92+
32,
93+
128,
94+
32
95+
],
96+
"stride": [
97+
131072,
98+
4096,
99+
32,
100+
1
101+
],
102+
"layout_type": "strided",
103+
"property_type": "variable"
104+
}
105+
]
106+
}
107+
]
108+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
{
2+
"version": "3.7.0",
3+
"engine_kind": "cpu",
4+
"fpmath_mode": "bf16",
5+
"fpmath_mode_apply_to_int": "true",
6+
"graph": [
7+
{
8+
"id": 0,
9+
"name": "aten::dequantize",
10+
"kind": "DynamicDequantize",
11+
"attrs": {
12+
"qtype": {
13+
"type": "string",
14+
"value": "per_group"
15+
},
16+
"group_shape": {
17+
"type": "s64[]",
18+
"value": [
19+
1,
20+
1,
21+
8,
22+
1
23+
]
24+
},
25+
"axis": {
26+
"type": "s64",
27+
"value": 2
28+
}
29+
},
30+
"inputs": [
31+
{
32+
"id": 0,
33+
"dtype": "u4",
34+
"shape": [
35+
1,
36+
32,
37+
128,
38+
32
39+
],
40+
"stride": [
41+
131072,
42+
4096,
43+
1,
44+
128
45+
],
46+
"layout_type": "strided",
47+
"property_type": "variable"
48+
},
49+
{
50+
"id": 1,
51+
"dtype": "bf16",
52+
"shape": [
53+
1,
54+
32,
55+
16,
56+
32
57+
],
58+
"stride": [
59+
16384,
60+
512,
61+
32,
62+
1
63+
],
64+
"layout_type": "strided",
65+
"property_type": "undef"
66+
},
67+
{
68+
"id": 2,
69+
"dtype": "u4",
70+
"shape": [
71+
1,
72+
32,
73+
16,
74+
32
75+
],
76+
"stride": [
77+
16384,
78+
512,
79+
32,
80+
1
81+
],
82+
"layout_type": "strided",
83+
"property_type": "undef"
84+
}
85+
],
86+
"outputs": [
87+
{
88+
"id": 10,
89+
"dtype": "bf16",
90+
"shape": [
91+
1,
92+
32,
93+
128,
94+
32
95+
],
96+
"stride": [
97+
131072,
98+
4096,
99+
32,
100+
1
101+
],
102+
"layout_type": "strided",
103+
"property_type": "variable"
104+
}
105+
]
106+
}
107+
]
108+
}

0 commit comments

Comments
 (0)