9
9
#include < limits>
10
10
11
11
#include " openvino/core/type/float_util.hpp"
12
+ #include " openvino/core/type/float16.hpp"
12
13
13
14
namespace ov {
14
15
@@ -48,33 +49,40 @@ constexpr uint8_t f8e4m3_e_max = 0x0f; // f8e4m3 exponent max value
48
49
constexpr uint8_t f8e4m3_m_size = 3 ; // f8e4m3 mantissa bits size
49
50
constexpr uint8_t f8e4m3_m_mask = 0x07 ; // f8e4m3 mantissa bit mask
50
51
51
- uint8_t f32_to_f8e4m3_bits (const float value) {
52
- constexpr uint32_t f32_s_mask = 0x80000000 ; // f32 sign bit mask
53
- constexpr uint32_t f32_e_mask = 0x7F800000 ; // f32 exponent bits mask
54
- constexpr uint32_t f32_e_bias = 127 ; // f32 exponent bias
55
- constexpr uint32_t f32_e_size = 8 ; // f32 exponent bits size
56
- constexpr uint32_t f32_m_mask = 0x007fffff ; // f32 mantissa bits mask
57
- constexpr uint32_t f32_m_size = 23 ; // f32 mantissa bits size
52
+ uint8_t f16_to_f8e4m3_bits (const float16 value) {
53
+ constexpr uint16_t f16_s_mask = 0x8000 ; // f16 sign bit mask
54
+ constexpr uint16_t f16_e_mask = 0x7C00 ; // f16 exponent bits mask
55
+ constexpr uint16_t f16_e_bias = 15 ; // f16 exponent bias
56
+ constexpr uint16_t f16_e_size = 5 ; // f16 exponent bits size
57
+ constexpr uint16_t f16_m_mask = 0x03ff ; // f16 mantissa bits mask
58
+ constexpr uint16_t f16_m_size = 10 ; // f16 mantissa bits size
58
59
59
- constexpr uint32_t f8_e_mask = f8e4m3_e_mask << three_bytes_shift; // f8 exponent bits mask (on u32)
60
- constexpr uint32_t f8_m_mask = f8e4m3_m_mask << three_bytes_shift; // f8 mantissa bits mask (on u32)
61
- constexpr uint32_t f8_m_hidden_one_mask = 0x08000000 ; // f8 mantissa hidden one bits mask (on u32)
60
+ constexpr uint8_t byte_shift = 8 ;
62
61
63
- constexpr uint32_t round_half = 0x01ffffff ; // value for half to even round for f8
64
- constexpr uint32_t round_norm = 0x007fffff ; // value for normal round for f8
65
- constexpr uint32_t round_even = 0x00800000 ; // value for half to even round for f8
66
- constexpr uint32_t round_odd = 0x01800000 ; // value for an non-half to even round for f8
62
+ constexpr uint16_t f8_e_mask = f8e4m3_e_mask << byte_shift; // f8 exponent bits mask (on u16)
63
+ constexpr uint16_t f8_m_mask = f8e4m3_m_mask << byte_shift; // f8 mantissa bits mask (on u16)
64
+ constexpr uint16_t f8_m_hidden_one_mask = 0x0800 ; // f8 mantissa hidden one bits mask (on u16)
67
65
68
- const auto input = util::f32_to_u32_bits (value);
69
- auto f8_bits = static_cast <uint8_t >((input & f32_s_mask) >> three_bytes_shift);
66
+ constexpr uint16_t round_half = 0x01ff ; // value for half to even round for f8
67
+ constexpr uint16_t round_norm = 0x007f ; // value for normal round for f8
68
+ constexpr uint16_t round_even = 0x0080 ; // value for half to even round for f8
69
+ constexpr uint16_t round_odd = 0x0180 ; // value for an non-half to even round for f8
70
70
71
- uint32_t f32_e_field = input & f32_e_mask;
71
+ // f8 exponent min value for subnormal
72
+ // For f8_exp less than -10, the hidden 1 is shifted beyond the rounding bit.
73
+ // So the 3 mantissa bits and the rounding bit are all 0, and the f8 value is always 0.
74
+ constexpr int16_t f8_e_subnormal_min = -10 ;
72
75
73
- if (f32_e_field == f32_e_mask) {
76
+ const uint16_t input = value.to_bits ();
77
+ uint8_t f8_bits = static_cast <uint8_t >((input & f16_s_mask) >> byte_shift);
78
+
79
+ uint16_t f16_e_field = input & f16_e_mask;
80
+
81
+ if (f16_e_field == f16_e_mask) {
74
82
f8_bits |= (f8e4m3_e_mask | f8e4m3_m_mask);
75
- } else if (f32_e_field != 0 ) {
76
- int32_t f8_biased_exp = (f32_e_field >> f32_m_size ) - (f32_e_bias - f8e4m3_e_bias);
77
- uint32_t fractional = (input & f32_m_mask ) << (f32_e_size - f8e4m3_e_size);
83
+ } else if (f16_e_field != 0 ) {
84
+ int16_t f8_biased_exp = (f16_e_field >> f16_m_size ) - (f16_e_bias - f8e4m3_e_bias);
85
+ uint16_t fractional = (input & f16_m_mask ) << (f16_e_size - f8e4m3_e_size);
78
86
79
87
// for normalized values, apply rounding, which may change the f8 fractional and biased exponent
80
88
if ((fractional & round_half) == round_odd || (fractional & round_norm) != 0 ) {
@@ -91,22 +99,24 @@ uint8_t f32_to_f8e4m3_bits(const float value) {
91
99
// Use NAN as this type has no infinity
92
100
f8_bits |= (f8e4m3_e_mask | f8e4m3_m_mask);
93
101
} else if (f8_biased_exp > 0 ) {
94
- f8_bits |= (f8_biased_exp << f8e4m3_m_size) | (fractional >> three_bytes_shift );
102
+ f8_bits |= (f8_biased_exp << f8e4m3_m_size) | (fractional >> byte_shift );
95
103
} else {
96
104
// Restore the hidden 1 in f8 mantissa for subnormal calculation
97
- fractional = f8_m_hidden_one_mask | (input & f32_m_mask) << (f32_e_size - f8e4m3_e_size);
98
- // Will any bits be shifted off?
99
- int32_t shift = f8_biased_exp < -(f8e4m3_e_max) ? 0 : (1U << (1 - f8_biased_exp));
100
- uint32_t sticky = (fractional & (shift - 1 )) ? 1 : 0 ;
101
-
102
- fractional = ((1 + f8_biased_exp) > f8e4m3_e_max) ? 0 : fractional >> (1 - f8_biased_exp);
103
- fractional |= sticky;
105
+ fractional = f8_m_hidden_one_mask | (input & f16_m_mask) << (f16_e_size - f8e4m3_e_size);
106
+ int16_t f8_exp = f8_biased_exp - f8e4m3_e_bias;
107
+ int16_t shift = 1 - f8_exp;
108
+ int16_t sticky_mask = f8_exp < f8_e_subnormal_min ? 0 : ((1 << shift) - 1 );
109
+ uint16_t sticky = (fractional & sticky_mask) ? 1 : 0 ;
110
+
111
+ // A subnormal mantissa has fewer significant bits for a smaller exponent
112
+ fractional = f8_exp < f8_e_subnormal_min ? 0 : fractional >> (1 - f8_biased_exp);
104
113
// apply rounding
105
- if (((fractional & round_half) == round_odd) || ((fractional & round_norm) != 0 )) {
114
+ if (((fractional & round_half) == round_odd && sticky == 0 ) || (fractional & round_norm) != 0 ||
115
+ sticky != 0 ) {
106
116
fractional += round_even;
107
117
}
108
118
109
- f8_bits |= fractional >> three_bytes_shift ;
119
+ f8_bits |= fractional >> byte_shift ;
110
120
}
111
121
}
112
122
@@ -118,7 +128,7 @@ float8_e4m3::float8_e4m3(const uint32_t sign, const uint32_t biased_exponent, co
118
128
: m_value(((sign & 0x01U ) << (f8e4m3_e_size + f8e4m3_m_size)) |
119
129
(biased_exponent & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size | (fraction & f8e4m3_m_mask)) {}
120
130
121
- float8_e4m3::float8_e4m3 (const float value) : m_value{f32_to_f8e4m3_bits ( value)} {}
131
+ float8_e4m3::float8_e4m3 (const float value) : m_value{f16_to_f8e4m3_bits ( static_cast <float16>( value) )} {}
122
132
123
133
float8_e4m3::operator float () const {
124
134
auto f32_bits = util::f32_to_u32_bits (f8_to_float_lut[m_value & (f8e4m3_e_mask | f8e4m3_m_mask)]);
0 commit comments