28
28
29
29
namespace pool {
30
30
31
- /* cfgs definition
32
- * arrays: SRC, UNUSED, UNUSED, DST
33
- * params: {data_type, min, max, f_min, f_max, eps}
34
- */
31
+ cfg_t::cfg_t (const prb_t *prb, const std::vector<data_kind_t > &kinds) {
32
+ output_data_kind_ = (prb->dir & FLAG_FWD) ? DST : SRC;
33
+ for (const auto kind : kinds) {
34
+ auto orig_data_type = prb->get_dt (kind);
35
+ auto data_type = deduce_cfg_data_type (orig_data_type, prb->attr , kind);
36
+ cfg_entry_.emplace (kind,
37
+ cfg_entry_t {
38
+ kind, orig_data_type, data_type, get_cfg_map (kind)});
39
+ }
35
40
36
- // though integers are expected, eps is needed to cover division error
37
- const dt_conf_t conf_entry_f32
38
- = {dnnl_f32, -FLT_MAX, FLT_MAX, -2048 , 2048 , 5e-7 };
39
- const dt_conf_t conf_entry_s32 = {dnnl_s32, INT_MIN, INT_MAX, -2048 , 2048 , 0 .};
40
- const dt_conf_t conf_entry_s8
41
- = {dnnl_s8, INT8_MIN, INT8_MAX, INT8_MIN, INT8_MAX, 0 .};
42
- const dt_conf_t conf_entry_u8 = {dnnl_u8, 0 , UINT8_MAX, 0 , UINT8_MAX, 0 .};
41
+ // Keep values for average algorithms positive to prevent cancellation err.
42
+ if (prb->alg != alg_t ::max) {
43
+ set_range_min (SRC, 0 );
44
+ set_range_min (DST, 0 );
45
+ }
43
46
44
- const float16_t flt16_max = dnnl::impl::nstl::numeric_limits<float16_t >::max();
45
- const dt_conf_t conf_entry_f16
46
- = {dnnl_f16, -flt16_max, flt16_max, -32 , 32 , 2e-2 };
47
-
48
- #define BFLT16_MAX 3 .38953138925153547590470800371487866880e+38F
49
- /* Although integers are expected, eps is needed to cover
50
- * for the division error */
51
- const dt_conf_t conf_entry_bf16
52
- = {dnnl_bf16, -BFLT16_MAX, BFLT16_MAX, -32 , 32 , 5e-2 };
53
- #undef BFLT16_MAX
47
+ BENCHDNN_PRINT (6 , " %s SRC_%s=[%d;%d]\n " , " [FILL_CFG]" ,
48
+ dt2str (this ->get_dt (SRC)), get_range_min (SRC), get_range_max (SRC));
49
+ }
54
50
55
- // Configurations with same SRC and DST datatypes
56
- const _dt_conf_t conf_f32 = {conf_entry_f32, {}, {}, conf_entry_f32};
57
- const _dt_conf_t conf_f64 = {conf_entry_f32, {}, {}, conf_entry_f32};
58
- const _dt_conf_t conf_s32 = {conf_entry_s32, {}, {}, conf_entry_s32};
59
- const _dt_conf_t conf_f16 = {conf_entry_f16, {}, {}, conf_entry_f16};
60
- const _dt_conf_t conf_bf16 = {conf_entry_bf16, {}, {}, conf_entry_bf16};
61
- const _dt_conf_t conf_s8 = {conf_entry_s8, {}, {}, conf_entry_s8};
62
- const _dt_conf_t conf_u8 = {conf_entry_u8, {}, {}, conf_entry_u8};
51
+ cfg_t ::cfg_entry_t ::cfg_map_t cfg_t::get_cfg_map (data_kind_t kind) const {
52
+ static const cfg_t ::cfg_entry_t ::cfg_map_t cfg_map = {
53
+ {{dnnl_f64}, {-2048 , 2048 }},
54
+ {{dnnl_f32}, {-2048 , 2048 }},
55
+ {{dnnl_s32}, {-2048 , 2048 }},
56
+ {{dnnl_bf16}, {-32 , 32 }},
57
+ {{dnnl_f16}, {-32 , 32 }},
58
+ {{dnnl_s8}, {INT8_MIN, INT8_MAX}},
59
+ {{dnnl_u8}, {0 , UINT8_MAX}},
60
+ };
63
61
64
- // Configurations with different SRC and DST datatypes
65
- const _dt_conf_t conf_s8u8 {conf_entry_s8, {}, {}, conf_entry_u8};
66
- const _dt_conf_t conf_u8s8 {conf_entry_u8, {}, {}, conf_entry_s8};
67
- const _dt_conf_t conf_s8f32 {conf_entry_s8, {}, {}, conf_entry_f32};
68
- const _dt_conf_t conf_f32s8 {conf_entry_f32, {}, {}, conf_entry_s8};
69
- const _dt_conf_t conf_u8f32 {conf_entry_u8, {}, {}, conf_entry_f32};
70
- const _dt_conf_t conf_f32u8 {conf_entry_f32, {}, {}, conf_entry_u8};
71
- const _dt_conf_t conf_s8f16 {conf_entry_s8, {}, {}, conf_entry_f16};
72
- const _dt_conf_t conf_f16s8 {conf_entry_f16, {}, {}, conf_entry_s8};
73
- const _dt_conf_t conf_u8f16 {conf_entry_u8, {}, {}, conf_entry_f16};
74
- const _dt_conf_t conf_f16u8 {conf_entry_f16, {}, {}, conf_entry_u8};
62
+ switch (kind) {
63
+ case SRC: return cfg_map;
64
+ case DST: return cfg_map;
65
+ default : assert (!" unsupported data kind" ); break ;
66
+ }
67
+ static cfg_t ::cfg_entry_t ::cfg_map_t dummy;
68
+ return dummy;
69
+ }
75
70
76
- const dt_conf_t * str2cfg (const char *str) {
71
+ std::string str2cfg (const char *str) {
77
72
#define CASE (cfg ) \
78
- if (!strcasecmp (STRINGIFY (cfg), str)) return CONCAT2 (conf_, cfg)
73
+ if (!strcasecmp (STRINGIFY (cfg), str)) return str
79
74
CASE (f32);
80
75
CASE (f64);
81
76
CASE (s32);
@@ -93,38 +88,54 @@ const dt_conf_t *str2cfg(const char *str) {
93
88
CASE (f16s8);
94
89
CASE (u8f16);
95
90
CASE (f16u8);
96
-
97
91
#undef CASE
98
- []() {
99
- SAFE (FAIL, CRIT);
100
- return 0 ;
101
- }();
102
- return (const dt_conf_t *)1 ;
92
+ BENCHDNN_PRINT (0 , " Config name \' %s\' is not supported.\n " , str);
93
+ SAFE_V (CRIT);
94
+ return std::string ();
103
95
}
104
96
105
- std::ostream &operator <<(std::ostream &s, const dt_conf_t *cfg) {
106
- #define CASE (_cfg ) \
107
- if (cfg == CONCAT2 (conf_, _cfg)) return s << STRINGIFY (_cfg)
108
- CASE (f32);
109
- CASE (f64);
110
- CASE (s32);
111
- CASE (f16);
112
- CASE (bf16);
113
- CASE (s8);
114
- CASE (u8);
115
- CASE (s8u8);
116
- CASE (u8s8);
117
- CASE (s8f32);
118
- CASE (f32s8);
119
- CASE (u8f32);
120
- CASE (f32u8);
121
- CASE (s8f16);
122
- CASE (f16s8);
123
- CASE (u8f16);
124
- CASE (f16u8);
125
- #undef CASE
126
- SAFE_V (FAIL);
127
- return s;
97
+ int handle_legacy_cfg (
98
+ std::vector<dnnl_data_type_t > &dt, const std::string &cfg) {
99
+ if (cfg == " f32" )
100
+ dt = {dnnl_f32};
101
+ else if (cfg == " f64" )
102
+ dt = {dnnl_f64};
103
+ else if (cfg == " s32" )
104
+ dt = {dnnl_s32};
105
+ else if (cfg == " f16" )
106
+ dt = {dnnl_f16};
107
+ else if (cfg == " bf16" )
108
+ dt = {dnnl_bf16};
109
+ else if (cfg == " s8" )
110
+ dt = {dnnl_s8};
111
+ else if (cfg == " u8" )
112
+ dt = {dnnl_u8};
113
+ else if (cfg == " u8s8" )
114
+ dt = {dnnl_u8, dnnl_s8};
115
+ else if (cfg == " s8u8" )
116
+ dt = {dnnl_s8, dnnl_u8};
117
+ else if (cfg == " u8f32" )
118
+ dt = {dnnl_u8, dnnl_f32};
119
+ else if (cfg == " f32u8" )
120
+ dt = {dnnl_f32, dnnl_u8};
121
+ else if (cfg == " s8f32" )
122
+ dt = {dnnl_s8, dnnl_f32};
123
+ else if (cfg == " f32s8" )
124
+ dt = {dnnl_f32, dnnl_s8};
125
+ else if (cfg == " u8f16" )
126
+ dt = {dnnl_u8, dnnl_f16};
127
+ else if (cfg == " f16u8" )
128
+ dt = {dnnl_f16, dnnl_u8};
129
+ else if (cfg == " s8f16" )
130
+ dt = {dnnl_s8, dnnl_f16};
131
+ else if (cfg == " f16s8" )
132
+ dt = {dnnl_f16, dnnl_s8};
133
+ else {
134
+ BENCHDNN_PRINT (0 , " Error: Config name \' %s\' is not supported.\n " ,
135
+ cfg.c_str ());
136
+ return FAIL;
137
+ }
138
+ return OK;
128
139
}
129
140
130
141
} // namespace pool
0 commit comments