@@ -893,6 +893,50 @@ struct ConvertFrom4BitPrecision<std::tuple<src_t, dst_t>> {
893
893
}
894
894
};
895
895
896
// (source, destination) precision pairs for conversions whose destination is a
// 4-bit type. Each entry is an INTEL_CPU_CVT(src, dst) case; the list is meant
// to be handed to OV_SWITCH together with std::tie(srcPrc, dstPrc).
#define INTEL_CPU_CVT_TO_4BIT_LIST \
    INTEL_CPU_CVT(f32, nf4),       \
    INTEL_CPU_CVT(f16, nf4),       \
    INTEL_CPU_CVT(bf16, nf4)
898
+
899
+ struct ConvertTo4BitContext {
900
+ ov::element::Type_t outType;
901
+ const void * srcPtr;
902
+ void * dstPtr;
903
+ size_t size;
904
+ bool converted;
905
+ };
906
+
907
+ template <typename T>
908
+ struct ConvertTo4BitPrecision ;
909
+
910
+ template <typename src_t , typename dst_t >
911
+ struct ConvertTo4BitPrecision <std::tuple<src_t , dst_t >> {
912
+ void operator ()(ConvertTo4BitContext& ctx) {
913
+ auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t {
914
+ uint8_t shift = high_half ? 4 : 0 ;
915
+ return dst | (uint8_t ) (val << shift);
916
+ };
917
+
918
+ auto src = static_cast <const src_t *>(ctx.srcPtr );
919
+ auto dst = static_cast <uint8_t *>(ctx.dstPtr );
920
+ // each byte must be fully processed within same thread
921
+ auto work_amount = ctx.size / 2 ;
922
+ auto has_tail = ctx.size % work_amount != 0 ;
923
+ if (ctx.outType == ov::element::nf4) {
924
+ parallel_for (work_amount, [&](size_t ib) {
925
+ size_t idx = ib*2 ;
926
+ const auto val = insert_half_byte (0 , ConvertNF4::quantize (static_cast <float >(src[idx])), false );
927
+ dst[ib] = insert_half_byte (val, ConvertNF4::quantize (static_cast <float >(src[idx+1 ])), true );
928
+ });
929
+
930
+ if (has_tail) {
931
+ dst[work_amount] = insert_half_byte (0 , ConvertNF4::quantize (static_cast <float >(src[2 *work_amount])), false );
932
+ }
933
+ } else {
934
+ OPENVINO_THROW (" cpu_convert doesn't support output data type: " , ctx.outType , " . Not implemented." );
935
+ }
936
+ ctx.converted = true ;
937
+ }
938
+ };
939
+
896
940
// (source, destination) precision pairs for conversions that read f8e8m0 input.
// Each entry is an INTEL_CPU_CVT(src, dst) case consumed through OV_SWITCH.
#define INTEL_CPU_CVT_FROM_BYTE_FP_LIST \
    INTEL_CPU_CVT(f8e8m0, f32),         \
    INTEL_CPU_CVT(f8e8m0, bf16),        \
    INTEL_CPU_CVT(f8e8m0, f16)
898
942
@@ -1017,6 +1061,12 @@ void cpu_convert(const void* srcPtr,
1017
1061
if (!ctx.converted ) {
1018
1062
OPENVINO_THROW (" cpu_convert can't convert from: " , srcPrc, " precision to: " , dstPrc);
1019
1063
}
1064
+ } else if (dstPrc.bitwidth () == 4u ) {
1065
+ ConvertTo4BitContext ctx{dstPrc, srcPtr, dstPtr, size, false };
1066
+ OV_SWITCH (intel_cpu, ConvertTo4BitPrecision, ctx, std::tie (srcPrc, dstPrc), INTEL_CPU_CVT_TO_4BIT_LIST);
1067
+ if (!ctx.converted ) {
1068
+ OPENVINO_THROW (" cpu_convert can't convert from: " , srcPrc, " precision to: " , dstPrc);
1069
+ }
1020
1070
} else if (srcPrc == ov::element::f8e8m0) {
1021
1071
ConvertFromByteFPContext ctx{srcPrc, srcPtr, dstPtr, size, false };
1022
1072
OV_SWITCH (intel_cpu,
@@ -1063,6 +1113,7 @@ bool is_supported_convert(ov::element::Type srcPrc, ov::element::Type dstPrc) {
1063
1113
OV_SWITCH (intel_cpu, isSupported, ctx, std::tie (srcPrc, dstPrc), INTEL_CPU_CVT_FROM_BIN_LIST);
1064
1114
OV_SWITCH (intel_cpu, isSupported, ctx, std::tie (srcPrc, dstPrc), INTEL_CPU_CVT_FROM_4BIT_LIST);
1065
1115
OV_SWITCH (intel_cpu, isSupported, ctx, std::tie (srcPrc, dstPrc), INTEL_CPU_CVT_FROM_BYTE_FP_LIST);
1116
+ OV_SWITCH (intel_cpu, isSupported, ctx, std::tie (srcPrc, dstPrc), INTEL_CPU_CVT_TO_4BIT_LIST);
1066
1117
return ctx.isSupported ;
1067
1118
}
1068
1119
0 commit comments