Skip to content

Commit 3efb13a

Browse files
authored
Implement 2-step conversion from fp32 to fp8 (#28501)
### Details: - *Implement a reference conversion from fp16 to f8e4m3, and use a 2-step conversion for fp32 to fp8 — i.e., convert fp32->fp16 first, then fp16->fp8.* ### Tickets: - *[CVS-160375](https://jira.devtools.intel.com/browse/CVS-160375)*
1 parent ad8e793 commit 3efb13a

File tree

4 files changed

+65
-47
lines changed

4 files changed

+65
-47
lines changed

src/core/src/type/float8_e4m3.cpp

+42-32
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <limits>
1010

1111
#include "openvino/core/type/float_util.hpp"
12+
#include "openvino/core/type/float16.hpp"
1213

1314
namespace ov {
1415

@@ -48,33 +49,40 @@ constexpr uint8_t f8e4m3_e_max = 0x0f; // f8e4m3 exponent max value
4849
constexpr uint8_t f8e4m3_m_size = 3; // f8e4m3 mantissa bits size
4950
constexpr uint8_t f8e4m3_m_mask = 0x07; // f8e4m3 mantissa bit mask
5051

51-
uint8_t f32_to_f8e4m3_bits(const float value) {
52-
constexpr uint32_t f32_s_mask = 0x80000000; // f32 sign bit mask
53-
constexpr uint32_t f32_e_mask = 0x7F800000; // f32 exponent bits mask
54-
constexpr uint32_t f32_e_bias = 127; // f32 exponent bias
55-
constexpr uint32_t f32_e_size = 8; // f32 exponent bits size
56-
constexpr uint32_t f32_m_mask = 0x007fffff; // f32 mantissa bits mask
57-
constexpr uint32_t f32_m_size = 23; // f32 mantissa bits size
52+
uint8_t f16_to_f8e4m3_bits(const float16 value) {
53+
constexpr uint16_t f16_s_mask = 0x8000; // f16 sign bit mask
54+
constexpr uint16_t f16_e_mask = 0x7C00; // f16 exponent bits mask
55+
constexpr uint16_t f16_e_bias = 15; // f16 exponent bias
56+
constexpr uint16_t f16_e_size = 5; // f16 exponent bits size
57+
constexpr uint16_t f16_m_mask = 0x03ff; // f16 mantissa bits mask
58+
constexpr uint16_t f16_m_size = 10; // f16 mantissa bits size
5859

59-
constexpr uint32_t f8_e_mask = f8e4m3_e_mask << three_bytes_shift; // f8 exponent bits mask (on u32)
60-
constexpr uint32_t f8_m_mask = f8e4m3_m_mask << three_bytes_shift; // f8 mantissa bits mask (on u32)
61-
constexpr uint32_t f8_m_hidden_one_mask = 0x08000000; // f8 mantissa hidden one bits mask (on u32)
60+
constexpr uint8_t byte_shift = 8;
6261

63-
constexpr uint32_t round_half = 0x01ffffff; // value for half to even round for f8
64-
constexpr uint32_t round_norm = 0x007fffff; // value for normal round for f8
65-
constexpr uint32_t round_even = 0x00800000; // value for half to even round for f8
66-
constexpr uint32_t round_odd = 0x01800000; // value for an non-half to even round for f8
62+
constexpr uint16_t f8_e_mask = f8e4m3_e_mask << byte_shift; // f8 exponent bits mask (on u16)
63+
constexpr uint16_t f8_m_mask = f8e4m3_m_mask << byte_shift; // f8 mantissa bits mask (on u16)
64+
constexpr uint16_t f8_m_hidden_one_mask = 0x0800; // f8 mantissa hidden one bits mask (on u16)
6765

68-
const auto input = util::f32_to_u32_bits(value);
69-
auto f8_bits = static_cast<uint8_t>((input & f32_s_mask) >> three_bytes_shift);
66+
constexpr uint16_t round_half = 0x01ff; // value for half to even round for f8
67+
constexpr uint16_t round_norm = 0x007f; // value for normal round for f8
68+
constexpr uint16_t round_even = 0x0080; // value for half to even round for f8
69+
constexpr uint16_t round_odd = 0x0180; // value for an non-half to even round for f8
7070

71-
uint32_t f32_e_field = input & f32_e_mask;
71+
// f8 exponent min value for subnormal
72+
// For f8_e less than -10, the hidden 1 is shifted beyond rounding bit.
73+
// So the 3 bits in mantissa and rounding bit are all 0, the f8 value is always 0.
74+
constexpr int16_t f8_e_subnormal_min = -10;
7275

73-
if (f32_e_field == f32_e_mask) {
76+
const uint16_t input = value.to_bits();
77+
uint8_t f8_bits = static_cast<uint8_t>((input & f16_s_mask) >> byte_shift);
78+
79+
uint16_t f16_e_field = input & f16_e_mask;
80+
81+
if (f16_e_field == f16_e_mask) {
7482
f8_bits |= (f8e4m3_e_mask | f8e4m3_m_mask);
75-
} else if (f32_e_field != 0) {
76-
int32_t f8_biased_exp = (f32_e_field >> f32_m_size) - (f32_e_bias - f8e4m3_e_bias);
77-
uint32_t fractional = (input & f32_m_mask) << (f32_e_size - f8e4m3_e_size);
83+
} else if (f16_e_field != 0) {
84+
int16_t f8_biased_exp = (f16_e_field >> f16_m_size) - (f16_e_bias - f8e4m3_e_bias);
85+
uint16_t fractional = (input & f16_m_mask) << (f16_e_size - f8e4m3_e_size);
7886

7987
// for normalized values round apply rounding change f8 fractional and biased exponent
8088
if ((fractional & round_half) == round_odd || (fractional & round_norm) != 0) {
@@ -91,22 +99,24 @@ uint8_t f32_to_f8e4m3_bits(const float value) {
9199
// Use NAN as this type has no infinity
92100
f8_bits |= (f8e4m3_e_mask | f8e4m3_m_mask);
93101
} else if (f8_biased_exp > 0) {
94-
f8_bits |= (f8_biased_exp << f8e4m3_m_size) | (fractional >> three_bytes_shift);
102+
f8_bits |= (f8_biased_exp << f8e4m3_m_size) | (fractional >> byte_shift);
95103
} else {
96104
// Restore the hidden 1 in f8 mantissa for subnormal calculation
97-
fractional = f8_m_hidden_one_mask | (input & f32_m_mask) << (f32_e_size - f8e4m3_e_size);
98-
// Will any bits be shifted off?
99-
int32_t shift = f8_biased_exp < -(f8e4m3_e_max) ? 0 : (1U << (1 - f8_biased_exp));
100-
uint32_t sticky = (fractional & (shift - 1)) ? 1 : 0;
101-
102-
fractional = ((1 + f8_biased_exp) > f8e4m3_e_max) ? 0 : fractional >> (1 - f8_biased_exp);
103-
fractional |= sticky;
105+
fractional = f8_m_hidden_one_mask | (input & f16_m_mask) << (f16_e_size - f8e4m3_e_size);
106+
int16_t f8_exp = f8_biased_exp - f8e4m3_e_bias;
107+
int16_t shift = 1 - f8_exp;
108+
int16_t sticky_mask = f8_exp < f8_e_subnormal_min ? 0 : ((1 << shift) - 1);
109+
uint16_t sticky = (fractional & sticky_mask) ? 1 : 0;
110+
111+
// Subnormal mantissa has less significant bits for smaller exponent
112+
fractional = f8_exp < f8_e_subnormal_min ? 0 : fractional >> (1 - f8_biased_exp);
104113
// apply rounding
105-
if (((fractional & round_half) == round_odd) || ((fractional & round_norm) != 0)) {
114+
if (((fractional & round_half) == round_odd && sticky == 0) || (fractional & round_norm) != 0 ||
115+
sticky != 0) {
106116
fractional += round_even;
107117
}
108118

109-
f8_bits |= fractional >> three_bytes_shift;
119+
f8_bits |= fractional >> byte_shift;
110120
}
111121
}
112122

@@ -118,7 +128,7 @@ float8_e4m3::float8_e4m3(const uint32_t sign, const uint32_t biased_exponent, co
118128
: m_value(((sign & 0x01U) << (f8e4m3_e_size + f8e4m3_m_size)) |
119129
(biased_exponent & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size | (fraction & f8e4m3_m_mask)) {}
120130

121-
float8_e4m3::float8_e4m3(const float value) : m_value{f32_to_f8e4m3_bits(value)} {}
131+
float8_e4m3::float8_e4m3(const float value) : m_value{f16_to_f8e4m3_bits(static_cast<float16>(value))} {}
122132

123133
float8_e4m3::operator float() const {
124134
auto f32_bits = util::f32_to_u32_bits(f8_to_float_lut[m_value & (f8e4m3_e_mask | f8e4m3_m_mask)]);

src/core/tests/float8_e4m3.cpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,15 @@ TEST(F8E4M3Test, f32_gt_zero_le_f8_half_lowest_subnormal) {
162162
EXPECT_EQ(f8.to_bits(), 0x00);
163163
}
164164

165-
TEST(F8E4M3Test, f32_gt_zero_gt_f8_half_lowest_subnormal) {
165+
TEST(F8E4M3Test, f32_in_f16_format_le_zero_gt_f8_half_lowest_subnormal) {
166166
const auto f8 = ov::float8_e4m3(0.00097656273283064365387f);
167167

168+
EXPECT_EQ(f8.to_bits(), 0x00);
169+
}
170+
171+
TEST(F8E4M3Test, f32_in_f16_format_gt_zero_gt_f8_half_lowest_subnormal) {
172+
const auto f8 = ov::float8_e4m3(0.00097751617431640625f);
173+
168174
EXPECT_EQ(f8.to_bits(), 0x01);
169175
}
170176

src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/conversion.cpp

+16-12
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@ void ConvertCPULayerTest::SetUp() {
148148
}
149149

150150
void ConvertCPULayerTest::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
151+
if (outPrc != ov::element::nf4 && special_value == ov::test::SpecialValue::none) {
152+
SubgraphBaseTest::generate_inputs(targetInputStaticShapes);
153+
return;
154+
}
155+
151156
inputs.clear();
152157
const auto& funcInputs = function->inputs();
153158
for (size_t i = 0; i < funcInputs.size(); ++i) {
@@ -162,18 +167,17 @@ void ConvertCPULayerTest::generate_inputs(const std::vector<ov::Shape>& targetIn
162167
} else {
163168
tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i]);
164169
}
165-
if (special_value != ov::test::SpecialValue::none) {
166-
if (inPrc == ov::element::f32) {
167-
modify_value<float>(tensor, special_value);
168-
} else if (inPrc == ov::element::f16) {
169-
modify_value<ov::float16>(tensor, special_value);
170-
} else if (inPrc == ov::element::bf16) {
171-
modify_value<ov::bfloat16>(tensor, special_value);
172-
} else if (inPrc == ov::element::f8e4m3) {
173-
modify_value<ov::float8_e4m3>(tensor, special_value);
174-
} else if (inPrc == ov::element::f8e5m2) {
175-
modify_value<ov::float8_e5m2>(tensor, special_value);
176-
}
170+
171+
if (inPrc == ov::element::f32) {
172+
modify_value<float>(tensor, special_value);
173+
} else if (inPrc == ov::element::f16) {
174+
modify_value<ov::float16>(tensor, special_value);
175+
} else if (inPrc == ov::element::bf16) {
176+
modify_value<ov::bfloat16>(tensor, special_value);
177+
} else if (inPrc == ov::element::f8e4m3) {
178+
modify_value<ov::float8_e4m3>(tensor, special_value);
179+
} else if (inPrc == ov::element::f8e5m2) {
180+
modify_value<ov::float8_e5m2>(tensor, special_value);
177181
}
178182

179183
inputs.insert({funcInput.get_node_shared_ptr(), tensor});

src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,6 @@ std::vector<std::string> disabledTestPatterns() {
173173
R"(.*smoke_TopK/TopKLayerTest.Inference.*_k=21_.*_sort=value_modelType=f16_trgDev=CPU.*)",
174174
// Issue: 121812
175175
R"(.*ConvertCPULayerTest.*outFmts=(nhwc|nChw8c|nChw16c).*)",
176-
// Issue: MFDNN-12917. The oneDNN emitter of conversion from fp32 to fp8 has rounding issue.
177-
R"(.*ConvertCPULayerTest.*(\[1.1.1080.1920\]|\(2.17.5.4\))_.*_inputPRC=f32_targetPRC=f8e4m3_.*)",
178176
// Need to generate sequence exactly in the i64 data type. Enable in scope of i64 enabling.
179177
R"(.*RandomUniformLayerTestCPU.*OutPrc=i64.*)",
180178
// Issue: 123815 (Tests are sensitive to available thread count on testing machines)

0 commit comments

Comments
 (0)