Skip to content

Commit ccfd8fd

Browse files
authored
[Transformations] Convert precision use optimized version for bf16 -> f16 (#25491)
### Details: - Use vectorized version of bf16 -> f16 in `ConvertPrecision`. ### Tickets: - N/A
1 parent eab547b commit ccfd8fd

File tree

4 files changed

+97
-0
lines changed

4 files changed

+97
-0
lines changed

src/common/transformations/src/transformations/convert_precision.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -1091,6 +1091,26 @@ std::shared_ptr<Node> change_constant_precision<ov::element::Type_t::f32, ov::el
10911091
return new_constant;
10921092
}
10931093

1094+
template <>
1095+
std::shared_ptr<Node> change_constant_precision<ov::element::Type_t::bf16, ov::element::Type_t::f16>(
1096+
std::shared_ptr<ov::op::v0::Constant>& constant) {
1097+
using src_type = typename element_type_traits<ov::element::Type_t::bf16>::value_type;
1098+
using dst_type = typename element_type_traits<ov::element::Type_t::f16>::value_type;
1099+
1100+
const auto* src_data = constant->get_data_ptr<src_type>();
1101+
const auto size = shape_size(constant->get_shape());
1102+
1103+
auto new_constant = std::make_shared<ov::op::v0::Constant>(ov::element::Type_t::f16, constant->get_shape());
1104+
new_constant->output(0).set_names(constant->output(0).get_names());
1105+
auto* dst_data = const_cast<dst_type*>(reinterpret_cast<const dst_type*>(new_constant->get_data_ptr()));
1106+
if (dst_data == nullptr)
1107+
OPENVINO_THROW("Can't get destination data pointer");
1108+
1109+
ov::reference::convert_from_bf16_to_f16_with_clamp(src_data, dst_data, size);
1110+
1111+
return new_constant;
1112+
}
1113+
10941114
template <>
10951115
std::shared_ptr<Node> change_constant_precision<ov::element::Type_t::f16, ov::element::Type_t::f32>(
10961116
std::shared_ptr<ov::op::v0::Constant>& constant) {
@@ -1326,6 +1346,8 @@ bool fuse_type_to_constant(const std::shared_ptr<ov::Node>& node,
13261346
new_const = change_constant_precision<ov::element::Type_t::f64, ov::element::Type_t::f32>(constant);
13271347
} else if (from == ov::element::bf16 && to == ov::element::f32) {
13281348
new_const = change_constant_precision<ov::element::Type_t::bf16, ov::element::Type_t::f32>(constant);
1349+
} else if (from == ov::element::bf16 && to == ov::element::f16) {
1350+
new_const = change_constant_precision<ov::element::Type_t::bf16, ov::element::Type_t::f16>(constant);
13291351
} else if (from == ov::element::f32 && to == ov::element::f16) {
13301352
new_const = change_constant_precision<ov::element::Type_t::f32, ov::element::Type_t::f16>(constant);
13311353
} else if (from == ov::element::f16 && to == ov::element::f32) {

src/common/transformations/tests/utils/convert_precision.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,40 @@ TEST(TransformationTests, ConvertPrecision_Convert_clamp_1) {
382382
ASSERT_TRUE(res.valid) << res.message;
383383
}
384384

385+
TEST(TransformationTests, ConvertPrecision_Convert_clamp_bf16_f16) {
386+
// fp16 out of range should be clamped to [fp16_min, fp16_max]
387+
std::shared_ptr<Model> model(nullptr), model_ref(nullptr);
388+
{
389+
auto input = std::make_shared<opset4::Parameter>(element::f16, Shape{1, 1000, 3});
390+
auto const_node = opset10::Constant::create(element::bf16, Shape{3}, {100000.0f, -100000.0f, 10.0f});
391+
auto convert = std::make_shared<opset4::Convert>(const_node, element::f16);
392+
auto add_1 = make_shared<opset10::Add>(input, convert);
393+
model = std::make_shared<Model>(NodeVector{add_1}, ParameterVector{input});
394+
395+
pass::Manager manager;
396+
static const precisions_map precisions = {{element::bf16, element::f16}};
397+
manager.register_pass<pass::InitNodeInfo>();
398+
manager.register_pass<pass::ConvertPrecision>(precisions);
399+
manager.run_passes(model);
400+
}
401+
402+
{
403+
auto max_fp16 = static_cast<float>(std::numeric_limits<ov::float16>::max());
404+
auto input = std::make_shared<opset4::Parameter>(element::f16, Shape{1, 1000, 3});
405+
auto const_node = opset10::Constant::create(element::f16, Shape{3}, {max_fp16, -max_fp16, 10.0f});
406+
auto add_1 = make_shared<opset10::Add>(input, const_node);
407+
408+
model_ref = std::make_shared<Model>(NodeVector{add_1}, ParameterVector{input});
409+
}
410+
ASSERT_NO_THROW(check_rt_info(model));
411+
const auto fc = FunctionsComparator::with_default()
412+
.enable(FunctionsComparator::PRECISIONS)
413+
.enable(FunctionsComparator::CONST_VALUES)
414+
.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
415+
const auto res = fc.compare(model, model_ref);
416+
ASSERT_TRUE(res.valid) << res.message;
417+
}
418+
385419
#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
386420
TEST(TransformationTests, ConvertPrecision_Convert_clamp_2) {
387421
#else

src/core/reference/include/openvino/reference/convert.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -82,5 +82,8 @@ size_t count_out_of_f16_range(const float* arg, size_t count);
8282

8383
// Convert values from f32 to f16 with clamping to f16 min/max when value is out of normal finite numbers range
8484
void convert_from_f32_to_f16_with_clamp(const float* arg, float16* out, size_t count);
85+
86+
// Convert values from bf16 to f16 with clamping to f16 min/max when value is out of normal finite numbers range
87+
void convert_from_bf16_to_f16_with_clamp(const bfloat16* arg, float16* out, size_t count);
8588
} // namespace reference
8689
} // namespace ov

src/core/reference/src/op/convert.cpp

+38
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,22 @@ void jit_convert_vec<bfloat16, float16>(jit::Generator& gen, const Xbyak::RegExp
6666
gen.vmovdqu(gen.xword[dst], f16vec); // move result to destination
6767
}
6868

69+
template <>
70+
void jit_convert_vec<bfloat16, float16, true>(jit::Generator& gen, const Xbyak::RegExp& src, const Xbyak::RegExp& dst) {
71+
const auto f32vec = gen.ymm4;
72+
const auto f16vec = gen.xmm3;
73+
74+
auto upper_bound = gen.ymm5;
75+
auto lower_bound = gen.ymm6;
76+
77+
gen.vpmovzxwd(f32vec, gen.yword[src]); // load bf16 into tmp
78+
gen.vpslld(f32vec, f32vec, 16); // convert bf16->f32 by bit shift
79+
gen.vminps(f32vec, f32vec, upper_bound); // clamp f16 max
80+
gen.vmaxps(f32vec, f32vec, lower_bound); // clamp f16 lowest
81+
gen.vcvtps2ph(f16vec, f32vec, 0); // convert f32 -> f16
82+
gen.vmovdqu(gen.xword[dst], f16vec); // move result to destination
83+
}
84+
6985
template <>
7086
void jit_convert_vec<bfloat16, float>(jit::Generator& gen, const Xbyak::RegExp& src, const Xbyak::RegExp& dst) {
7187
const auto f32vec = gen.ymm4;
@@ -92,6 +108,11 @@ void jit_convert_vec_prepare<float, float16, true>(jit::Generator& gen) {
92108
gen.vmovdqu(lower_bound, gen.yword[addr]);
93109
}
94110

111+
template <>
112+
void jit_convert_vec_prepare<bfloat16, float16, true>(jit::Generator& gen) {
113+
jit_convert_vec_prepare<float, float16, true>(gen);
114+
}
115+
95116
template <>
96117
void jit_convert_vec<float, float16, true>(jit::Generator& gen, const Xbyak::RegExp& src, const Xbyak::RegExp& dst) {
97118
auto f16vec = gen.xmm3;
@@ -552,6 +573,23 @@ void convert_from_f32_to_f16_with_clamp(const float* arg, float16* out, size_t c
552573
#endif // defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
553574
}
554575

576+
void convert_from_bf16_to_f16_with_clamp(const bfloat16* arg, float16* out, size_t count) {
577+
#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
578+
convert_impl<bfloat16, float16, true>(arg, out, count);
579+
#else
580+
// FIXME CVS-125496: duplicate and stub for ARM, provide optimized solution
581+
for (size_t i = 0; i < count; ++i) {
582+
if (arg[i] > std::numeric_limits<ov::float16>::max()) {
583+
out[i] = std::numeric_limits<ov::float16>::max();
584+
} else if (arg[i] < std::numeric_limits<ov::float16>::lowest()) {
585+
out[i] = std::numeric_limits<ov::float16>::lowest();
586+
} else {
587+
out[i] = static_cast<ov::float16>(arg[i]);
588+
}
589+
}
590+
#endif // defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
591+
}
592+
555593
size_t count_out_of_f16_range(const float* arg, size_t count) {
556594
size_t num_out_of_range = 0;
557595

0 commit comments

Comments
 (0)