Skip to content

Commit 9ae372b

Browse files
committed
benchdnn: pool: move from --cfg option to --dt
1 parent cd4ff0d commit 9ae372b

File tree

9 files changed

+246
-176
lines changed

9 files changed

+246
-176
lines changed

tests/benchdnn/doc/driver_pool.md

+7-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@ where *pool-knobs* are:
99

1010
- `--dir={FWD_D [default], FWD_I, BWD_D}` -- dnnl_prop_kind_t.
1111
Refer to [direction](knobs_dir.md) for details.
12-
- `--cfg={f32 [default], ...}` -- Refer to ``Configurations`` below.
12+
- `--dt={f32 [default], ...}` -- source and destination data
13+
types. Interface supports broadcasting, when a single input is
14+
provided, e.g., `--dt=f32`, and the value will be applied for all
15+
tensors. Refer to [data types](knobs_dt.md) for details.
16+
- `--cfg={f32 [default], ...}` -- Deprecated setting.
17+
Refer to ``Configurations`` below.
1318
- `--tag={nchw [default], ...}` -- physical src and dst memory layout.
1419
Refer to [tags](knobs_tag.md) for details.
1520
- `--alg={max [default], avg_np, avg_p}` -- pooling algorithm.
@@ -89,7 +94,7 @@ Run a named problem with single precision src/dst, iterating by:
8994
3) all algorithm combinations,
9095
4) using default minibatch of 96 and 5:
9196
``` sh
92-
./benchdnn --pool --cfg=f32 --tag=nChw8c,nChw16c \
97+
./benchdnn --pool --dt=f32 --tag=nChw8c,nChw16c \
9398
--dir=FWD_D,FWD_I,BWD_D --alg=max,avg_np,avg_p --mb=0,5 \
9499
mb96ic768_ih17oh17_kh3sh1ph1n"googlenet_v3:ave_pool_mixed_4_pool"
95100
```

tests/benchdnn/graph/setting_handler.cpp

+8-16
Original file line numberDiff line numberDiff line change
@@ -1297,10 +1297,9 @@ bool get_pool_dir(const deserialized_op &base_op_ref, dir_t &dir) {
12971297
return ret;
12981298
}
12991299

1300-
bool get_pool_cfg(const deserialized_op &base_op_ref,
1301-
const ::pool::dt_conf_t **cfg,
1300+
bool get_pool_dt(const deserialized_op &base_op_ref,
1301+
std::vector<dnnl_data_type_t> &dt,
13021302
const std::unordered_set<size_t> &rewrite_lt_ids) {
1303-
std::string cfg_str {"f32"};
13041303
auto src_dt = base_op_ref.in_lts_[0].data_type_;
13051304
auto dst_dt = base_op_ref.out_lts_[0].data_type_;
13061305
if (rewrite_lt_ids.find(base_op_ref.in_lts_[0].id_) != rewrite_lt_ids.end())
@@ -1309,15 +1308,9 @@ bool get_pool_cfg(const deserialized_op &base_op_ref,
13091308
!= rewrite_lt_ids.end())
13101309
dst_dt = "f32";
13111310

1312-
if (src_dt == dst_dt
1313-
&& ((src_dt == "f32" || src_dt == "f16" || src_dt == "bf16"))) {
1314-
// if ((src_dt == "f16" || src_dt == "bf16") && !is_gpu()) return false;
1315-
// temporarily removed, will add the check later for all drivers
1316-
cfg_str = src_dt;
1317-
*cfg = ::pool::str2cfg(cfg_str.c_str());
1318-
return true;
1319-
}
1320-
return false;
1311+
dt = {convert_dt(get_data_type(src_dt)), convert_dt(get_data_type(dst_dt))};
1312+
1313+
return true;
13211314
}
13221315

13231316
bool get_pool_alg(const deserialized_op &base_op_ref, ::pool::alg_t &alg) {
@@ -1356,8 +1349,8 @@ ::pool::settings_t get_setting(const deserialized_op &base_op_ref,
13561349
pool::get_pool_alg(base_op_ref, op_setting.alg.front()), res);
13571350
DNN_GRAPH_CHECK_SETTINGS(
13581351
pool::get_pool_dir(base_op_ref, op_setting.dir.front()), res);
1359-
DNN_GRAPH_CHECK_SETTINGS(pool::get_pool_cfg(base_op_ref,
1360-
&op_setting.cfg.front(), rewrite_lt_ids),
1352+
DNN_GRAPH_CHECK_SETTINGS(pool::get_pool_dt(base_op_ref,
1353+
op_setting.dt.front(), rewrite_lt_ids),
13611354
res);
13621355
DNN_GRAPH_CHECK_SETTINGS(
13631356
get_driver_tag(base_op_ref, op_setting.tag.front()), res);
@@ -1370,9 +1363,8 @@ void set_s8u8_for_prb(::pool::prb_t *prb,
13701363
res_t *res) {
13711364
std::string cfg_str;
13721365
for (size_t offset = 0; offset < map_off_to_dt.size(); offset++) {
1373-
cfg_str += map_off_to_dt.at(offset);
1366+
prb->dt[offset] = convert_dt(get_data_type(map_off_to_dt.at(offset)));
13741367
}
1375-
prb->cfg = ::pool::str2cfg(cfg_str.c_str());
13761368
}
13771369

13781370
} //namespace pool

tests/benchdnn/inputs/pool/test_pool_ci

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
--alg=max,avg_np,avg_p
55
# Training
6-
--cfg=f32,bf16
6+
--dt=f32,bf16
77
--dir=FWD_D,BWD_D
88
--tag=abx,axb
99
--batch=shapes_basic
@@ -13,11 +13,11 @@
1313
--tag=axb
1414

1515
## All inference configs
16-
--cfg=f32,bf16,f16,s32,s8,u8, \
17-
s8u8,u8s8,s8f32,f32s8,u8f32,f32u8,s8f16,f16s8,u8f16,f16u8
16+
--dt=f32,bf16,f16,s32,s8,u8, \
17+
s8:u8,u8:s8,s8:f32,f32:s8,u8:f32,f32:u8,s8:f16,f16:s8,u8:f16,f16:u8
1818
--batch=shapes_basic
1919

2020
## Attributes
21-
--cfg=f32,bf16,f16,s32,s8,u8
21+
--dt=f32,bf16,f16,s32,s8,u8
2222
--attr-post-ops=add:f32:per_oc,linear:0.5:-1
2323
--batch=shapes_basic

tests/benchdnn/pool/bench_pool.cpp

+37-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ using driver_task_executor_t = task_executor_t<prb_t, perf_report_t,
4040
void check_correctness(
4141
const settings_t &s, driver_task_executor_t &task_executor) {
4242
for_(const auto &i_dir : s.dir)
43+
for_(const auto &i_dt_ : s.dt)
4344
for_(const auto &i_cfg : s.cfg)
4445
for_(const auto &i_tag : s.tag)
4546
for_(const auto &i_alg : s.alg)
@@ -50,7 +51,12 @@ void check_correctness(
5051
for (const auto &i_ctx_exe : s.ctx_exe) {
5152
auto attr = settings_t::get_attr(i_post_ops, i_scratchpad_mode);
5253

53-
const prb_t prb(s.desc, i_dir, i_cfg, i_tag, i_alg, attr, i_ctx_init,
54+
auto i_dt = i_dt_;
55+
if (!i_cfg.empty() && i_dt.size() == 1 && i_dt[0] == dnnl_f32) {
56+
handle_legacy_cfg(i_dt, i_cfg);
57+
}
58+
59+
const prb_t prb(s.desc, i_dir, i_dt, i_tag, i_alg, attr, i_ctx_init,
5460
i_ctx_exe, i_mb);
5561
if (s.pattern && !match_regex(prb.str(), s.pattern)) return;
5662

@@ -59,6 +65,33 @@ void check_correctness(
5965
}
6066
}
6167

68+
int verify_input(const settings_t &s) {
69+
for_(const auto &i_dt : s.dt)
70+
for (const auto &i_cfg : s.cfg) {
71+
if (i_cfg.empty()) continue;
72+
73+
if (i_dt.size() != 1 || i_dt[0] != dnnl_f32) {
74+
BENCHDNN_PRINT(0, "%s\n",
75+
"ERROR: `dt` and `cfg` knobs are incompatible with each "
76+
"other. Specify only one of them at a time.");
77+
return FAIL;
78+
}
79+
}
80+
81+
static constexpr int n_inputs = 2;
82+
for (const auto &i_dt : s.dt) {
83+
if (i_dt.size() != 1 && i_dt.size() != n_inputs) {
84+
BENCHDNN_PRINT(0, "%s%d.\n",
85+
"ERROR: `dt` option expects either a single input or two "
86+
"inputs in SRC, DST order. Current size is: ",
87+
static_cast<int>(i_dt.size()));
88+
return FAIL;
89+
}
90+
}
91+
92+
return OK;
93+
}
94+
6295
int bench(int argc, char **argv) {
6396
driver_name = "pool";
6497
using namespace parser;
@@ -69,6 +102,7 @@ int bench(int argc, char **argv) {
69102
const bool parsed_options = parse_bench_settings(argv[0])
70103
|| parse_batch(bench, argv[0])
71104
|| parse_dir(s.dir, def.dir, argv[0])
105+
|| parse_multi_dt(s.dt, def.dt, argv[0], "dt")
72106
|| parse_cfg(s.cfg, def.cfg, str2cfg, argv[0])
73107
|| parse_tag(s.tag, def.tag, argv[0])
74108
|| parse_alg(s.alg, def.alg, str2alg, argv[0])
@@ -86,6 +120,8 @@ int bench(int argc, char **argv) {
86120
catch_unknown_options(argv[0]);
87121

88122
SAFE(str2desc(&s.desc, argv[0]), CRIT);
123+
124+
SAFE(verify_input(s), WARN);
89125
check_correctness(s, task_executor);
90126
}
91127
}

tests/benchdnn/pool/cfg.cpp

+82-71
Original file line numberDiff line numberDiff line change
@@ -28,54 +28,49 @@
2828

2929
namespace pool {
3030

31-
/* cfgs definition
32-
* arrays: SRC, UNUSED, UNUSED, DST
33-
* params: {data_type, min, max, f_min, f_max, eps}
34-
*/
31+
cfg_t::cfg_t(const prb_t *prb, const std::vector<data_kind_t> &kinds) {
32+
output_data_kind_ = (prb->dir & FLAG_FWD) ? DST : SRC;
33+
for (const auto kind : kinds) {
34+
auto orig_data_type = prb->get_dt(kind);
35+
auto data_type = deduce_cfg_data_type(orig_data_type, prb->attr, kind);
36+
cfg_entry_.emplace(kind,
37+
cfg_entry_t {
38+
kind, orig_data_type, data_type, get_cfg_map(kind)});
39+
}
3540

36-
// though integers are expected, eps is needed to cover division error
37-
const dt_conf_t conf_entry_f32
38-
= {dnnl_f32, -FLT_MAX, FLT_MAX, -2048, 2048, 5e-7};
39-
const dt_conf_t conf_entry_s32 = {dnnl_s32, INT_MIN, INT_MAX, -2048, 2048, 0.};
40-
const dt_conf_t conf_entry_s8
41-
= {dnnl_s8, INT8_MIN, INT8_MAX, INT8_MIN, INT8_MAX, 0.};
42-
const dt_conf_t conf_entry_u8 = {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0.};
41+
// Keep values for average algorithms positive to prevent cancellation err.
42+
if (prb->alg != alg_t::max) {
43+
set_range_min(SRC, 0);
44+
set_range_min(DST, 0);
45+
}
4346

44-
const float16_t flt16_max = dnnl::impl::nstl::numeric_limits<float16_t>::max();
45-
const dt_conf_t conf_entry_f16
46-
= {dnnl_f16, -flt16_max, flt16_max, -32, 32, 2e-2};
47-
48-
#define BFLT16_MAX 3.38953138925153547590470800371487866880e+38F
49-
/* Although integers are expected, eps is needed to cover
50-
* for the division error */
51-
const dt_conf_t conf_entry_bf16
52-
= {dnnl_bf16, -BFLT16_MAX, BFLT16_MAX, -32, 32, 5e-2};
53-
#undef BFLT16_MAX
47+
BENCHDNN_PRINT(6, "%s SRC_%s=[%d;%d]\n", "[FILL_CFG]",
48+
dt2str(this->get_dt(SRC)), get_range_min(SRC), get_range_max(SRC));
49+
}
5450

55-
// Configurations with same SRC and DST datatypes
56-
const _dt_conf_t conf_f32 = {conf_entry_f32, {}, {}, conf_entry_f32};
57-
const _dt_conf_t conf_f64 = {conf_entry_f32, {}, {}, conf_entry_f32};
58-
const _dt_conf_t conf_s32 = {conf_entry_s32, {}, {}, conf_entry_s32};
59-
const _dt_conf_t conf_f16 = {conf_entry_f16, {}, {}, conf_entry_f16};
60-
const _dt_conf_t conf_bf16 = {conf_entry_bf16, {}, {}, conf_entry_bf16};
61-
const _dt_conf_t conf_s8 = {conf_entry_s8, {}, {}, conf_entry_s8};
62-
const _dt_conf_t conf_u8 = {conf_entry_u8, {}, {}, conf_entry_u8};
51+
cfg_t::cfg_entry_t::cfg_map_t cfg_t::get_cfg_map(data_kind_t kind) const {
52+
static const cfg_t::cfg_entry_t::cfg_map_t cfg_map = {
53+
{{dnnl_f64}, {-2048, 2048}},
54+
{{dnnl_f32}, {-2048, 2048}},
55+
{{dnnl_s32}, {-2048, 2048}},
56+
{{dnnl_bf16}, {-32, 32}},
57+
{{dnnl_f16}, {-32, 32}},
58+
{{dnnl_s8}, {INT8_MIN, INT8_MAX}},
59+
{{dnnl_u8}, {0, UINT8_MAX}},
60+
};
6361

64-
// Configurations with different SRC and DST datatypes
65-
const _dt_conf_t conf_s8u8 {conf_entry_s8, {}, {}, conf_entry_u8};
66-
const _dt_conf_t conf_u8s8 {conf_entry_u8, {}, {}, conf_entry_s8};
67-
const _dt_conf_t conf_s8f32 {conf_entry_s8, {}, {}, conf_entry_f32};
68-
const _dt_conf_t conf_f32s8 {conf_entry_f32, {}, {}, conf_entry_s8};
69-
const _dt_conf_t conf_u8f32 {conf_entry_u8, {}, {}, conf_entry_f32};
70-
const _dt_conf_t conf_f32u8 {conf_entry_f32, {}, {}, conf_entry_u8};
71-
const _dt_conf_t conf_s8f16 {conf_entry_s8, {}, {}, conf_entry_f16};
72-
const _dt_conf_t conf_f16s8 {conf_entry_f16, {}, {}, conf_entry_s8};
73-
const _dt_conf_t conf_u8f16 {conf_entry_u8, {}, {}, conf_entry_f16};
74-
const _dt_conf_t conf_f16u8 {conf_entry_f16, {}, {}, conf_entry_u8};
62+
switch (kind) {
63+
case SRC: return cfg_map;
64+
case DST: return cfg_map;
65+
default: assert(!"unsupported data kind"); break;
66+
}
67+
static cfg_t::cfg_entry_t::cfg_map_t dummy;
68+
return dummy;
69+
}
7570

76-
const dt_conf_t *str2cfg(const char *str) {
71+
std::string str2cfg(const char *str) {
7772
#define CASE(cfg) \
78-
if (!strcasecmp(STRINGIFY(cfg), str)) return CONCAT2(conf_, cfg)
73+
if (!strcasecmp(STRINGIFY(cfg), str)) return str
7974
CASE(f32);
8075
CASE(f64);
8176
CASE(s32);
@@ -93,38 +88,54 @@ const dt_conf_t *str2cfg(const char *str) {
9388
CASE(f16s8);
9489
CASE(u8f16);
9590
CASE(f16u8);
96-
9791
#undef CASE
98-
[]() {
99-
SAFE(FAIL, CRIT);
100-
return 0;
101-
}();
102-
return (const dt_conf_t *)1;
92+
BENCHDNN_PRINT(0, "Config name \'%s\' is not supported.\n", str);
93+
SAFE_V(CRIT);
94+
return std::string();
10395
}
10496

105-
std::ostream &operator<<(std::ostream &s, const dt_conf_t *cfg) {
106-
#define CASE(_cfg) \
107-
if (cfg == CONCAT2(conf_, _cfg)) return s << STRINGIFY(_cfg)
108-
CASE(f32);
109-
CASE(f64);
110-
CASE(s32);
111-
CASE(f16);
112-
CASE(bf16);
113-
CASE(s8);
114-
CASE(u8);
115-
CASE(s8u8);
116-
CASE(u8s8);
117-
CASE(s8f32);
118-
CASE(f32s8);
119-
CASE(u8f32);
120-
CASE(f32u8);
121-
CASE(s8f16);
122-
CASE(f16s8);
123-
CASE(u8f16);
124-
CASE(f16u8);
125-
#undef CASE
126-
SAFE_V(FAIL);
127-
return s;
97+
int handle_legacy_cfg(
98+
std::vector<dnnl_data_type_t> &dt, const std::string &cfg) {
99+
if (cfg == "f32")
100+
dt = {dnnl_f32};
101+
else if (cfg == "f64")
102+
dt = {dnnl_f64};
103+
else if (cfg == "s32")
104+
dt = {dnnl_s32};
105+
else if (cfg == "f16")
106+
dt = {dnnl_f16};
107+
else if (cfg == "bf16")
108+
dt = {dnnl_bf16};
109+
else if (cfg == "s8")
110+
dt = {dnnl_s8};
111+
else if (cfg == "u8")
112+
dt = {dnnl_u8};
113+
else if (cfg == "u8s8")
114+
dt = {dnnl_u8, dnnl_s8};
115+
else if (cfg == "s8u8")
116+
dt = {dnnl_s8, dnnl_u8};
117+
else if (cfg == "u8f32")
118+
dt = {dnnl_u8, dnnl_f32};
119+
else if (cfg == "f32u8")
120+
dt = {dnnl_f32, dnnl_u8};
121+
else if (cfg == "s8f32")
122+
dt = {dnnl_s8, dnnl_f32};
123+
else if (cfg == "f32s8")
124+
dt = {dnnl_f32, dnnl_s8};
125+
else if (cfg == "u8f16")
126+
dt = {dnnl_u8, dnnl_f16};
127+
else if (cfg == "f16u8")
128+
dt = {dnnl_f16, dnnl_u8};
129+
else if (cfg == "s8f16")
130+
dt = {dnnl_s8, dnnl_f16};
131+
else if (cfg == "f16s8")
132+
dt = {dnnl_f16, dnnl_s8};
133+
else {
134+
BENCHDNN_PRINT(0, "Error: Config name \'%s\' is not supported.\n",
135+
cfg.c_str());
136+
return FAIL;
137+
}
138+
return OK;
128139
}
129140

130141
} // namespace pool

0 commit comments

Comments
 (0)