9
9
#include " memory_desc/cpu_memory_desc_utils.h"
10
10
#include " nodes/common/cpu_memcpy.h"
11
11
#include " nodes/reorder.h"
12
+ #include " utils/bfloat16.hpp"
12
13
#include " utils/debug_capabilities.h"
13
14
#if defined(__linux__)
14
15
# include < sys/syscall.h> /* Definition of SYS_* constants */
@@ -30,19 +31,44 @@ BlockedMemoryDescPtr IMemory::getDescWithType<BlockedMemoryDesc, 0, 0>() const {
30
31
}
31
32
32
33
namespace {
33
- inline void setSubnormalsToZero (float * data, size_t size) {
34
+ inline void setSubnormalsToZeroAndbf16Saturation (float * data, size_t size, bool ftz, bool bf16saturation ) {
34
35
uint32_t * u32data = reinterpret_cast <uint32_t *>(data);
35
- for (size_t i = 0 ; i < size; ++i) {
36
- if ((u32data[i] & (0xFF << 23 )) == 0 ) {
37
- u32data[i] = 0 ;
36
+ float * floatdata = reinterpret_cast <float *>(data);
37
+ if (ftz && bf16saturation) {
38
+ for (size_t i = 0 ; i < size; ++i) {
39
+ if ((u32data[i] & (0xFF << 23 )) == 0 ) {
40
+ u32data[i] = 0 ;
41
+ } else if (!std::isnan (floatdata[i]) && !std::isinf (floatdata[i])) {
42
+ floatdata[i] = (floatdata[i] < static_cast <float >(std::numeric_limits<ov::bfloat16>::lowest ()))
43
+ ? static_cast <float >(std::numeric_limits<ov::bfloat16>::lowest ())
44
+ : (floatdata[i] > static_cast <float >(std::numeric_limits<ov::bfloat16>::max ()))
45
+ ? static_cast <float >(std::numeric_limits<ov::bfloat16>::max ())
46
+ : floatdata[i];
47
+ }
48
+ }
49
+ } else if (ftz) {
50
+ for (size_t i = 0 ; i < size; ++i) {
51
+ if ((u32data[i] & (0xFF << 23 )) == 0 ) {
52
+ u32data[i] = 0 ;
53
+ }
54
+ }
55
+ } else if (bf16saturation) {
56
+ for (size_t i = 0 ; i < size; ++i) {
57
+ if (!std::isnan (floatdata[i]) && !std::isinf (floatdata[i])) {
58
+ floatdata[i] = (floatdata[i] < static_cast <float >(std::numeric_limits<ov::bfloat16>::lowest ()))
59
+ ? static_cast <float >(std::numeric_limits<ov::bfloat16>::lowest ())
60
+ : (floatdata[i] > static_cast <float >(std::numeric_limits<ov::bfloat16>::max ()))
61
+ ? static_cast <float >(std::numeric_limits<ov::bfloat16>::max ())
62
+ : floatdata[i];
63
+ }
38
64
}
39
65
}
40
66
}
41
67
42
- void transferData (const IMemory& src, const IMemory& dst, bool ftz) {
68
+ void transferData (const IMemory& src, const IMemory& dst, bool ftz, bool bf16saturation ) {
43
69
node::Reorder::reorderData (src, dst);
44
70
45
- if (!ftz) {
71
+ if (!ftz && !bf16saturation ) {
46
72
return ;
47
73
}
48
74
if (src.getDesc ().getPrecision () != ov::element::f32 || dst.getDesc ().getPrecision () != ov::element::f32) {
@@ -62,7 +88,7 @@ void transferData(const IMemory& src, const IMemory& dst, bool ftz) {
62
88
// actual FTZ
63
89
auto * memData = static_cast <float *>(dst.getData ());
64
90
memData += offset;
65
- setSubnormalsToZero (memData, dst.getSize () / sizeof (float ));
91
+ setSubnormalsToZeroAndbf16Saturation (memData, dst.getSize () / sizeof (float ), ftz, bf16saturation );
66
92
}
67
93
68
94
} // namespace
@@ -125,11 +151,11 @@ void Memory::create(MemoryDescPtr desc, const void* data, bool pads_zeroing) {
125
151
}
126
152
}
127
153
128
- void Memory::load (const IMemory& src, bool ftz) const {
154
+ void Memory::load (const IMemory& src, bool ftz, bool bf16saturation ) const {
129
155
if (src.getDesc ().getPrecision () == element::string) {
130
156
OPENVINO_THROW (" [CPU] Memory object cannot load string data." );
131
157
}
132
- transferData (src, *this , ftz);
158
+ transferData (src, *this , ftz, bf16saturation );
133
159
}
134
160
135
161
void Memory::nullify () {
@@ -273,12 +299,12 @@ StringMemory::StringMemory(dnnl::engine engine, MemoryDescPtr desc, const void*
273
299
}
274
300
}
275
301
276
- void StringMemory::load (const IMemory& src, bool ftz) const {
302
+ void StringMemory::load (const IMemory& src, bool ftz, bool bf16saturation ) const {
277
303
if (src.getDesc ().getPrecision () != element::string) {
278
304
OPENVINO_THROW (" [CPU] String memory cannot load a non-string object." );
279
305
}
280
306
281
- transferData (src, *this , false );
307
+ transferData (src, *this , false , false );
282
308
}
283
309
284
310
void * StringMemory::getData () const {
@@ -472,11 +498,11 @@ void StaticMemory::redefineDesc(MemoryDescPtr desc) {
472
498
OPENVINO_THROW (" Unexpected: Memory descriptor may not be modified in StaticMemory object" );
473
499
}
474
500
475
- void StaticMemory::load (const IMemory& src, bool ftz) const {
501
+ void StaticMemory::load (const IMemory& src, bool ftz, bool bf16saturation ) const {
476
502
if (src.getDesc ().getPrecision () == element::string) {
477
503
OPENVINO_THROW (" [CPU] StaticMemory cannot load string data." );
478
504
}
479
- transferData (src, *this , ftz);
505
+ transferData (src, *this , ftz, bf16saturation );
480
506
}
481
507
482
508
MemoryBlockPtr StaticMemory::getMemoryBlock () const {
0 commit comments