Skip to content

Commit 9d51a16

Browse files
[CPU] Enabled float (fp32/fp16/bf16) to nf4 precision conversion (#28829)
### Details:
- This PR adds FP32/BF16/FP16 to NF4 support for Convert op in CPU Plugin

### Tickets:
- [CVS-153213](https://jira.devtools.intel.com/browse/CVS-153213)
1 parent c04d828 commit 9d51a16

File tree

5 files changed

+112
-24
lines changed

5 files changed

+112
-24
lines changed

src/plugins/intel_cpu/src/node.cpp

+2-20
Original file line numberDiff line numberDiff line change
@@ -1588,24 +1588,6 @@ ov::element::Type Node::getRuntimePrecision() const {
15881588
}
15891589

15901590
Node* Node::NodesFactory::create(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& context) {
1591-
// getExceptionDescWithoutStatus removes redundant information from the exception message. For instance, the
1592-
// NotImplemented exception is generated in the form: full_path_to_src_file:line_number [ NOT_IMPLEMENTED ] reason.
1593-
// An example for gather node:
1594-
// /path-to-openVino-root/src/plugins/intel_cpu/nodes/gather.cpp:42 [ NOT_IMPLEMENTED ] Only opset7 Gather operation
1595-
// is supported The most important part of the message is the reason, so the lambda trims everything up to "]" Note
1596-
// that the op type and its friendly name will also be provided if we fail to create the node.
1597-
auto getExceptionDescWithoutStatus = [](const ov::Exception& ex) {
1598-
std::string desc = ex.what();
1599-
size_t pos = desc.find(']');
1600-
if (pos != std::string::npos) {
1601-
if (desc.size() == pos + 1) {
1602-
desc.erase(0, pos + 1);
1603-
} else {
1604-
desc.erase(0, pos + 2);
1605-
}
1606-
}
1607-
return desc;
1608-
};
16091591
Node* newNode = nullptr;
16101592
std::string errorMessage;
16111593
if (newNode == nullptr) {
@@ -1616,7 +1598,7 @@ Node* Node::NodesFactory::create(const std::shared_ptr<ov::Node>& op, const Grap
16161598
}
16171599
} catch (const ov::Exception& ex) {
16181600
if (dynamic_cast<const ov::NotImplemented*>(&ex) != nullptr) {
1619-
errorMessage += getExceptionDescWithoutStatus(ex);
1601+
errorMessage += ex.what();
16201602
} else {
16211603
throw;
16221604
}
@@ -1631,7 +1613,7 @@ Node* Node::NodesFactory::create(const std::shared_ptr<ov::Node>& op, const Grap
16311613
}
16321614
} catch (const ov::Exception& ex) {
16331615
if (dynamic_cast<const ov::NotImplemented*>(&ex) != nullptr) {
1634-
const auto currErrorMess = getExceptionDescWithoutStatus(ex);
1616+
const std::string currErrorMess = ex.what();
16351617
if (!currErrorMess.empty()) {
16361618
errorMessage += errorMessage.empty() ? currErrorMess : "\n" + currErrorMess;
16371619
}

src/plugins/intel_cpu/src/nodes/common/cpu_convert.cpp

+51
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,50 @@ struct ConvertFrom4BitPrecision<std::tuple<src_t, dst_t>> {
893893
}
894894
};
895895

896+
#define INTEL_CPU_CVT_TO_4BIT_LIST \
897+
INTEL_CPU_CVT(f32, nf4), INTEL_CPU_CVT(f16, nf4), INTEL_CPU_CVT(bf16, nf4)
898+
899+
struct ConvertTo4BitContext {
900+
ov::element::Type_t outType;
901+
const void* srcPtr;
902+
void* dstPtr;
903+
size_t size;
904+
bool converted;
905+
};
906+
907+
template <typename T>
908+
struct ConvertTo4BitPrecision;
909+
910+
template <typename src_t, typename dst_t>
911+
struct ConvertTo4BitPrecision<std::tuple<src_t, dst_t>> {
912+
void operator()(ConvertTo4BitContext& ctx) {
913+
auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t {
914+
uint8_t shift = high_half ? 4 : 0;
915+
return dst | (uint8_t) (val << shift);
916+
};
917+
918+
auto src = static_cast<const src_t*>(ctx.srcPtr);
919+
auto dst = static_cast<uint8_t*>(ctx.dstPtr);
920+
// each byte must be fully processed within same thread
921+
auto work_amount = ctx.size / 2;
922+
auto has_tail = ctx.size % work_amount != 0;
923+
if (ctx.outType == ov::element::nf4) {
924+
parallel_for(work_amount, [&](size_t ib) {
925+
size_t idx = ib*2;
926+
const auto val = insert_half_byte(0, ConvertNF4::quantize(static_cast<float>(src[idx])), false);
927+
dst[ib] = insert_half_byte(val, ConvertNF4::quantize(static_cast<float>(src[idx+1])), true);
928+
});
929+
930+
if (has_tail) {
931+
dst[work_amount] = insert_half_byte(0, ConvertNF4::quantize(static_cast<float>(src[2*work_amount])), false);
932+
}
933+
} else {
934+
OPENVINO_THROW("cpu_convert doesn't support output data type: ", ctx.outType, ". Not implemented.");
935+
}
936+
ctx.converted = true;
937+
}
938+
};
939+
896940
#define INTEL_CPU_CVT_FROM_BYTE_FP_LIST \
897941
INTEL_CPU_CVT(f8e8m0, f32), INTEL_CPU_CVT(f8e8m0, bf16), INTEL_CPU_CVT(f8e8m0, f16)
898942

@@ -1017,6 +1061,12 @@ void cpu_convert(const void* srcPtr,
10171061
if (!ctx.converted) {
10181062
OPENVINO_THROW("cpu_convert can't convert from: ", srcPrc, " precision to: ", dstPrc);
10191063
}
1064+
} else if (dstPrc.bitwidth() == 4u) {
1065+
ConvertTo4BitContext ctx{dstPrc, srcPtr, dstPtr, size, false};
1066+
OV_SWITCH(intel_cpu, ConvertTo4BitPrecision, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_TO_4BIT_LIST);
1067+
if (!ctx.converted) {
1068+
OPENVINO_THROW("cpu_convert can't convert from: ", srcPrc, " precision to: ", dstPrc);
1069+
}
10201070
} else if (srcPrc == ov::element::f8e8m0) {
10211071
ConvertFromByteFPContext ctx{srcPrc, srcPtr, dstPtr, size, false};
10221072
OV_SWITCH(intel_cpu,
@@ -1063,6 +1113,7 @@ bool is_supported_convert(ov::element::Type srcPrc, ov::element::Type dstPrc) {
10631113
OV_SWITCH(intel_cpu, isSupported, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_FROM_BIN_LIST);
10641114
OV_SWITCH(intel_cpu, isSupported, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_FROM_4BIT_LIST);
10651115
OV_SWITCH(intel_cpu, isSupported, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_FROM_BYTE_FP_LIST);
1116+
OV_SWITCH(intel_cpu, isSupported, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_TO_4BIT_LIST);
10661117
return ctx.isSupported;
10671118
}
10681119

src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/conversion.cpp

+49-4
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ void ConvertCPULayerTest::SetUp() {
104104
#if defined(OPENVINO_ARCH_ARM64)
105105
if (inPrc == ov::element::u4 || inPrc == ov::element::i4 ||
106106
inPrc == ov::element::f8e4m3 || inPrc == ov::element::f8e5m2 ||
107-
outPrc == ov::element::f8e4m3 || outPrc == ov::element::f8e5m2) {
107+
outPrc == ov::element::f8e4m3 || outPrc == ov::element::f8e5m2 ||
108+
outPrc == ov::element::nf4) {
108109
primitive = "ref";
109110
} else if (shapes.first.is_static() &&
110111
inPrc != ov::element::bf16 && outPrc != ov::element::bf16 &&
@@ -151,8 +152,16 @@ void ConvertCPULayerTest::generate_inputs(const std::vector<ov::Shape>& targetIn
151152
const auto& funcInputs = function->inputs();
152153
for (size_t i = 0; i < funcInputs.size(); ++i) {
153154
const auto& funcInput = funcInputs[i];
154-
ov::Tensor tensor =
155-
ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i]);
155+
ov::Tensor tensor;
156+
if (outPrc == ov::element::nf4) {
157+
tensor = ov::test::utils::create_and_fill_tensor_real_distribution(funcInput.get_element_type(),
158+
targetInputStaticShapes[i],
159+
-1.f,
160+
1.f,
161+
1);
162+
} else {
163+
tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i]);
164+
}
156165
if (special_value != ov::test::SpecialValue::none) {
157166
if (inPrc == ov::element::f32) {
158167
modify_value<float>(tensor, special_value);
@@ -176,6 +185,40 @@ void ConvertCPULayerTest::validate_out_prc() const {
176185
FAIL() << "ConvertCPULayerTest supports only non boolean output prc";
177186
}
178187

188+
void ConvertCPULayerTest::validate() {
189+
if (outPrc == ov::element::nf4) {
190+
// Use custom bit-exact validation, because common tests infra doesn't support 4bits tensors comparision
191+
auto actualOutputs = get_plugin_outputs();
192+
auto expectedOutputs = calculate_refs();
193+
ASSERT_EQ(expectedOutputs.size(), actualOutputs.size());
194+
ASSERT_EQ(expectedOutputs.size(), 1);
195+
ASSERT_EQ(expectedOutputs[0].get_shape(), actualOutputs[0].get_shape());
196+
ASSERT_EQ(expectedOutputs[0].get_element_type(), ov::element::nf4);
197+
ASSERT_EQ(expectedOutputs[0].get_element_type(), actualOutputs[0].get_element_type());
198+
199+
auto expected_data = reinterpret_cast<const uint8_t*>(expectedOutputs[0].data());
200+
auto actual_data = reinterpret_cast<const uint8_t*>(actualOutputs[0].data());
201+
size_t byte_count = shape_size(expectedOutputs[0].get_shape()) / 2;
202+
bool has_tile = shape_size(expectedOutputs[0].get_shape()) % 2 != 0;
203+
for (size_t i = 0; i < byte_count; ++i) {
204+
uint8_t expected_value = expected_data[i];
205+
uint8_t actual_value = actual_data[i];
206+
ASSERT_EQ(expected_value, actual_value);
207+
}
208+
209+
// Convert operation doc doesn't specify behavior for odd amount of elements: should upper 4 bits of last byte be filled with zeros or not.
210+
// CPU Plugin fills these bits with zeros as it better fits optimized kernels which get NF4 inputs.
211+
// In general it is considered as UB, so skip the check for last 4 bits.
212+
if (has_tile) {
213+
ASSERT_EQ(expected_data[byte_count] & 0x0F, actual_data[byte_count] & 0x0F);
214+
}
215+
216+
return;
217+
}
218+
219+
SubgraphBaseTest::validate();
220+
}
221+
179222
void ConvertToBooleanCPULayerTest::validate_out_prc() const {
180223
if (outPrc != ov::element::boolean)
181224
FAIL() << "ConvertToBooleanCPULayerTest supports only boolean output prc";
@@ -274,7 +317,9 @@ const std::vector<InputShape>& inShapes_4D_dynamic() {
274317
{
275318
{2, 4, 4, 1},
276319
{2, 17, 5, 4},
277-
{1, 2, 3, 4}
320+
{1, 2, 3, 4},
321+
// odd number of elements
322+
{1, 3, 3, 3}
278323
}
279324
},
280325
{

src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/conversion.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class ConvertCPULayerTest : public testing::WithParamInterface<convertLayerTestP
2929
protected:
3030
void SetUp() override;
3131
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
32+
void validate() override;
3233
virtual void validate_out_prc() const;
3334

3435
ov::element::Type inPrc, outPrc;

src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/common/conversion.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ const std::vector<ov::element::Type> float_precisions = {
6464
ov::element::bf16,
6565
};
6666

67+
// Instantiates Convert tests for float (f32/f16/bf16) -> nf4 on the reference
// ("ref") implementation over dynamic 4D shapes; the shape set includes a case
// with an odd number of elements, exercising the partially filled tail byte
// of the packed 4-bit output.
INSTANTIATE_TEST_SUITE_P(smoke_ConvertCPULayerTest_float_to_nf4, ConvertCPULayerTest,
                        ::testing::Combine(
                                ::testing::ValuesIn(inShapes_4D_dynamic()),
                                ::testing::ValuesIn(float_precisions),
                                ::testing::Values(ov::element::nf4),
                                ::testing::Values(ov::test::SpecialValue::none),
                                ::testing::Values(CPUSpecificParams({nchw}, {nchw}, {}, {"ref"}))),
                        ConvertCPULayerTest::getTestCaseName);
75+
6776
const std::vector<ov::element::Type> f8_precisions = {
6877
ov::element::f8e4m3,
6978
ov::element::f8e5m2,

0 commit comments

Comments
 (0)