Skip to content

Commit 32dc361

Browse files
committed
api: add fp4_e3m0 support
1 parent ed96323 commit 32dc361

10 files changed

+132
-3
lines changed

include/oneapi/dnnl/dnnl.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,8 @@ struct memory : public handle<dnnl_memory_t> {
856856
enum class data_type {
857857
/// Undefined data type (used for empty memory descriptors).
858858
undef = dnnl_data_type_undef,
859+
/// 4-bit float data type with 3-bit exponent and 0 bit mantissa.
860+
f4_e3m0 = dnnl_f4_e3m0,
859861
/// [MX-compliant 4-bit float data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2-bit exponent and 1 bit mantissa.
860862
f4_e2m1 = dnnl_f4_e2m1,
861863
/// [MX-compliant 8-bit compliant scale data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 8-bit exponent.

include/oneapi/dnnl/dnnl_common_types.h

+2
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ typedef enum {
106106
dnnl_e8m0 = 13,
107107
/// [MX-compliant 4-bit float data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2-bit exponent and 1 bit mantissa.
108108
dnnl_f4_e2m1 = 14,
109+
/// 4-bit float data type with 3-bit exponent and 0 bit mantissa.
110+
dnnl_f4_e3m0 = 15,
109111

110112
/// Parameter to allow internal only data_types without undefined behavior.
111113
/// This parameter is chosen to be valid for so long as sizeof(int) >= 2.

src/common/c_types_map.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ const alg_kind_t eltwise_stochastic_round
153153
using data_type_t = dnnl_data_type_t;
154154
namespace data_type {
155155
const data_type_t undef = dnnl_data_type_undef;
156+
const data_type_t f4_e3m0 = dnnl_f4_e3m0;
156157
const data_type_t f4_e2m1 = dnnl_f4_e2m1;
157158
const data_type_t e8m0 = dnnl_e8m0;
158159
const data_type_t f8_e5m2 = dnnl_f8_e5m2;

src/common/dnnl_debug_autogenerated.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ const char *dnnl_dt2str(dnnl_data_type_t v) {
5959
if (v == dnnl_u4) return "u4";
6060
if (v == dnnl_e8m0) return "e8m0";
6161
if (v == dnnl_f4_e2m1) return "f4_e2m1";
62+
if (v == dnnl_f4_e3m0) return "f4_e3m0";
6263
if (v == dnnl_data_type_max) return "data_type_max";
6364
assert(!"unknown dt");
6465
return "unknown dt";

src/common/dnnl_traits.hpp

+8
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ struct typesize_traits {}; /* ::data_type_size -> f32 */
3838
template <primitive_kind_t>
3939
struct pkind_traits {}; /* ::desc_type, ::query_d */
4040

41+
template <>
42+
struct prec_traits<data_type::f4_e3m0> {
43+
typedef float4_e3m0_t type;
44+
};
4145
template <>
4246
struct prec_traits<data_type::f4_e2m1> {
4347
typedef float4_e2m1_t type;
@@ -95,6 +99,10 @@ struct prec_traits<data_type::boolean> {
9599
typedef bool type;
96100
};
97101

102+
template <>
103+
struct data_traits<float4_e3m0_t> {
104+
static constexpr data_type_t data_type = data_type::f4_e3m0;
105+
};
98106
template <>
99107
struct data_traits<float4_e2m1_t> {
100108
static constexpr data_type_t data_type = data_type::f4_e2m1;

src/common/float4.cpp

+68-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ uint8_t float2e2m1(float f) {
3232
// There is no NaN or infinity in e2m1, for now we just return zero
3333
// TODO: figure if there is a standard value to return
3434
uint32_t naninf_mask = 0x7f800000;
35-
if ((f_raw & naninf_mask) == naninf_mask) return 0x00000000;
35+
if ((f_raw & naninf_mask) == naninf_mask) return 0x00;
3636

3737
// we convert with naive closest value computation out of 8
3838
float e2m1_val_table[8] = {0.0f, .5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
@@ -91,5 +91,72 @@ float4_e2m1_t::operator float16_t() const {
9191
return e2m1_table[raw_bits_];
9292
}
9393

94+
uint8_t float2e3m0(float f) {
95+
uint32_t f_raw = float2int(f);
96+
uint32_t sign = f_raw & 0x80000000;
97+
98+
// There is no NaN or infinity in e3m0, we just return maxval
99+
uint32_t naninf_mask = 0x7f800000;
100+
if ((f_raw & naninf_mask) == naninf_mask) return 0x7;
101+
102+
// we convert with naive closest value computation out of 8
103+
float e3m0_val_table[8] = {0.0f, .25f, .5f, 1.0f, 2.0f, 4.0f, 8.0f, 16.0f};
104+
105+
float abs_f = int2float(f_raw ^ sign);
106+
107+
int idx = 0;
108+
float min_diff = ::fabsf(e3m0_val_table[idx] - abs_f);
109+
uint8_t raw_bits = idx;
110+
for (++idx; idx < 8; ++idx) {
111+
float diff = ::fabsf(e3m0_val_table[idx] - abs_f);
112+
if (diff < min_diff) {
113+
min_diff = diff;
114+
raw_bits = idx;
115+
}
116+
// Special case for midpoint, we round to even (so even index)
117+
if ((diff == min_diff) && !(idx & 1)) raw_bits = idx;
118+
}
119+
assert(raw_bits < 8);
120+
// reapply sign
121+
if (sign) raw_bits = raw_bits | 0x08;
122+
assert(raw_bits < 16);
123+
return raw_bits;
124+
}
125+
126+
// Assigns from bf16: the value is first converted to f32 and then encoded
// with float2e3m0().
float4_e3m0_t &float4_e3m0_t::operator=(bfloat16_t f) {
    raw_bits_ = float2e3m0(static_cast<float>(f));
    return *this;
}
131+
132+
// Assigns from f16: the value is first converted to f32 and then encoded
// with float2e3m0().
float4_e3m0_t &float4_e3m0_t::operator=(float16_t f) {
    raw_bits_ = float2e3m0(static_cast<float>(f));
    return *this;
}
137+
138+
// Assigns from f32: stores the encoding produced by float2e3m0().
float4_e3m0_t &float4_e3m0_t::operator=(float f) {
    const uint8_t encoded = float2e3m0(f);
    raw_bits_ = encoded;
    return *this;
}
142+
143+
// Decodes the stored bit pattern to f32 through a lookup table.
float4_e3m0_t::operator float() const {
    // Entry i is the value encoded by bit pattern i. Patterns 8..15 carry the
    // same magnitudes as 0..7 with a negative sign.
    static const float decoded[16] = {
            0.0f, .25f, .5f, 1.0f, 2.0f, 4.0f, 8.0f, 16.0f,
            -0.0f, -.25f, -.5f, -1.0f, -2.0f, -4.0f, -8.0f, -16.0f};
    // Only the 16 table entries are valid encodings.
    assert(raw_bits_ < 16);
    return decoded[raw_bits_];
}
151+
152+
// Decodes the stored bit pattern to f16 through a lookup table.
float4_e3m0_t::operator float16_t() const {
    // Entry i is the value encoded by bit pattern i. Patterns 8..15 carry the
    // same magnitudes as 0..7 with a negative sign.
    static const float16_t decoded[16] = {
            0.0f, .25f, .5f, 1.0f, 2.0f, 4.0f, 8.0f, 16.0f,
            -0.0f, -.25f, -.5f, -1.0f, -2.0f, -4.0f, -8.0f, -16.0f};
    // Only the 16 table entries are valid encodings.
    assert(raw_bits_ < 16);
    return decoded[raw_bits_];
}
160+
94161
} // namespace impl
95162
} // namespace dnnl

src/common/float4.hpp

+23
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,29 @@ struct float4_e2m1_t {
4949
};
5050
static_assert(sizeof(float4_e2m1_t) == 1, "float4_e2m1_t must be 1 byte");
5151

52+
struct float4_e3m0_t {
53+
uint8_t raw_bits_;
54+
float4_e3m0_t() = default;
55+
constexpr float4_e3m0_t(uint8_t r, bool = true) : raw_bits_(r) {}
56+
float4_e3m0_t(float f) { (*this) = f; }
57+
float4_e3m0_t(float16_t f) { (*this) = f; }
58+
float4_e3m0_t(bfloat16_t f) { (*this) = f; }
59+
60+
float4_e3m0_t DNNL_API &operator=(float f);
61+
float4_e3m0_t DNNL_API &operator=(float16_t f);
62+
float4_e3m0_t DNNL_API &operator=(bfloat16_t f);
63+
64+
DNNL_API operator float() const;
65+
DNNL_API operator float16_t() const;
66+
DNNL_API operator bfloat16_t() const;
67+
68+
float4_e3m0_t &operator+=(const float a) {
69+
(*this) = float {*this} + a;
70+
return *this;
71+
}
72+
};
73+
static_assert(sizeof(float4_e3m0_t) == 1, "float4_e3m0_t must be 1 byte");
74+
5275
} // namespace impl
5376
} // namespace dnnl
5477

src/common/nstl.hpp

+16
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,22 @@ struct numeric_limits<int8_t> : public std::numeric_limits<int8_t> {};
157157
template <>
158158
struct numeric_limits<uint8_t> : public std::numeric_limits<uint8_t> {};
159159

160+
template <>
161+
struct numeric_limits<float4_e3m0_t> {
162+
static constexpr float4_e3m0_t lowest() { return float4_e3m0_t(0xf, true); }
163+
// Min normal is equal to the value 1.0
164+
static constexpr float4_e3m0_t min() { return float4_e3m0_t(0x1, true); }
165+
// Max normal is equal to the value 6.0
166+
static constexpr float4_e3m0_t max() { return float4_e3m0_t(0x7, true); }
167+
168+
static constexpr int bias = 0x3;
169+
static constexpr int digits = 1; // 1 implicit bit
170+
171+
static constexpr float4_e3m0_t epsilon() {
172+
return float4_e3m0_t(0x3, true);
173+
}
174+
};
175+
160176
template <>
161177
struct numeric_limits<float4_e2m1_t> {
162178
static constexpr float4_e2m1_t lowest() { return float4_e2m1_t(0xf, true); }

src/common/type_helpers.hpp

+10-2
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ namespace types {
9393
inline size_t data_type_size(data_type_t data_type) {
9494
using namespace data_type;
9595
switch ((int)data_type) {
96+
case f4_e3m0: return sizeof(prec_traits<f4_e3m0>::type);
9697
case f4_e2m1: return sizeof(prec_traits<f4_e2m1>::type);
9798
case e8m0: return sizeof(prec_traits<e8m0>::type);
9899
case f8_e5m2: return sizeof(prec_traits<f8_e5m2>::type);
@@ -139,6 +140,7 @@ inline T min_value(data_type_t data_type) {
139140
case x: \
140141
return static_cast<T>(nstl::numeric_limits<prec_traits<x>::type>::min())
141142
switch (data_type) {
143+
CASE(f4_e3m0);
142144
CASE(f4_e2m1);
143145
CASE(e8m0);
144146
CASE(f8_e5m2);
@@ -166,6 +168,7 @@ inline T max_value(data_type_t data_type) {
166168
case x: \
167169
return static_cast<T>(nstl::numeric_limits<prec_traits<x>::type>::max())
168170
switch (data_type) {
171+
CASE(f4_e3m0);
169172
CASE(f4_e2m1);
170173
CASE(e8m0);
171174
CASE(f8_e5m2);
@@ -195,6 +198,7 @@ inline float max_value(data_type_t data_type) {
195198
return static_cast<float>( \
196199
nstl::numeric_limits<prec_traits<x>::type>::max())
197200
switch (data_type) {
201+
CASE(f4_e3m0);
198202
CASE(f4_e2m1);
199203
CASE(e8m0);
200204
CASE(f8_e5m2);
@@ -233,6 +237,7 @@ inline T lowest_value(data_type_t data_type) {
233237
return static_cast<T>( \
234238
nstl::numeric_limits<prec_traits<x>::type>::lowest())
235239
switch (data_type) {
240+
CASE(f4_e3m0);
236241
CASE(f4_e2m1);
237242
CASE(e8m0);
238243
CASE(f8_e5m2);
@@ -261,6 +266,7 @@ inline T digits(data_type_t data_type) {
261266
return static_cast<T>( \
262267
nstl::numeric_limits<prec_traits<x>::type>::digits)
263268
switch (data_type) {
269+
CASE(f4_e3m0);
264270
CASE(f4_e2m1);
265271
CASE(e8m0);
266272
CASE(f8_e5m2);
@@ -419,6 +425,7 @@ inline data_type_t default_accum_data_type(
419425
// true
420426
if (one_of(src_dt, s8, u8, u4, s4) && (dst_dt != f32 || strict)) return s32;
421427

428+
if (one_of(f4_e3m0, src_dt, dst_dt)) return f32;
422429
if (one_of(f4_e2m1, src_dt, dst_dt)) return f32;
423430
if (one_of(f8_e5m2, src_dt, dst_dt)) return f32;
424431
if (one_of(f8_e4m3, src_dt, dst_dt)) return f32;
@@ -461,6 +468,7 @@ inline data_type_t default_accum_data_type(data_type_t src_dt,
461468
return f32;
462469
}
463470

471+
if (one_of(f4_e3m0, src_dt, wei_dt, dst_dt)) return f32;
464472
if (one_of(f4_e2m1, src_dt, wei_dt, dst_dt)) return f32;
465473
if (one_of(f8_e5m2, src_dt, wei_dt, dst_dt)) return f32;
466474
if (one_of(f8_e4m3, src_dt, wei_dt, dst_dt)) return f32;
@@ -1262,8 +1270,8 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims,
12621270
if (ndims == 0) return true;
12631271

12641272
bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS
1265-
&& utils::one_of(data_type, f4_e2m1, e8m0, f8_e5m2, f8_e4m3, f16,
1266-
bf16, f32, f64, s32, s8, u8, s4, u4);
1273+
&& utils::one_of(data_type, f4_e3m0, f4_e2m1, e8m0, f8_e5m2,
1274+
f8_e4m3, f16, bf16, f32, f64, s32, s8, u8, s4, u4);
12671275
if (!ok) return false;
12681276

12691277
bool has_runtime_dims = false;

tests/benchdnn/dnnl_debug_autogenerated.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ dnnl_data_type_t str2dt(const char *str) {
5050
CASE(u4);
5151
CASE(e8m0);
5252
CASE(f4_e2m1);
53+
CASE(f4_e3m0);
5354
CASE(data_type_max);
5455
#undef CASE
5556
if (!strcmp("undef", str) || !strcmp("dnnl_data_type_undef", str))

0 commit comments

Comments
 (0)