4
4
5
5
#include " dnnl_extension_utils.h"
6
6
7
- #include " utils/general_utils.h"
8
7
#include < oneapi/dnnl/dnnl.hpp>
9
8
#include " memory_desc/dnnl_blocked_memory_desc.h"
10
- #include " onednn/iml_type_mapper.h"
11
- #include < common/primitive_desc.hpp>
12
9
#include < common/primitive_desc_iface.hpp>
13
10
14
- #include < vector>
15
-
16
11
using namespace dnnl ;
17
12
18
13
namespace ov {
19
14
namespace intel_cpu {
20
15
21
- uint8_t DnnlExtensionUtils::sizeOfDataType (dnnl:: memory::data_type dataType) {
16
+ uint8_t DnnlExtensionUtils::sizeOfDataType (memory::data_type dataType) {
22
17
switch (dataType) {
23
- case dnnl::memory::data_type::f32:
24
- return 4 ;
25
- case dnnl::memory::data_type::s32:
18
+ case memory::data_type::f64:
19
+ case memory::data_type::s64:
20
+ return 8 ;
21
+ case memory::data_type::f32:
22
+ case memory::data_type::s32:
26
23
return 4 ;
27
- case dnnl::memory::data_type::bf16:
24
+ case memory::data_type::bf16:
25
+ case memory::data_type::f16:
28
26
return 2 ;
29
- case dnnl:: memory::data_type::s8:
30
- return 1 ;
31
- case dnnl:: memory::data_type::u8 :
27
+ case memory::data_type::s8:
28
+ case memory::data_type::u8:
29
+ case memory::data_type::bin :
32
30
return 1 ;
33
- case dnnl::memory::data_type::bin:
34
- return 1 ;
35
- case dnnl::memory::data_type::f16:
36
- return 2 ;
37
- case dnnl::memory::data_type::undef:
31
+ case memory::data_type::undef:
38
32
return 0 ;
39
33
default :
40
- IE_THROW () << " Unsupported data type. " ;
34
+ IE_THROW () << " Unsupported data type: " << DataTypeToIEPrecision (dataType) ;
41
35
}
42
36
}
43
37
44
38
memory::data_type DnnlExtensionUtils::IEPrecisionToDataType (const InferenceEngine::Precision& prec) {
45
39
switch (prec) {
40
+ case InferenceEngine::Precision::FP64:
41
+ return memory::data_type::f64;
42
+ case InferenceEngine::Precision::I64:
43
+ return memory::data_type::s64;
46
44
case InferenceEngine::Precision::FP32:
47
45
return memory::data_type::f32;
48
46
case InferenceEngine::Precision::I32:
@@ -68,6 +66,10 @@ memory::data_type DnnlExtensionUtils::IEPrecisionToDataType(const InferenceEngin
68
66
69
67
InferenceEngine::Precision DnnlExtensionUtils::DataTypeToIEPrecision (memory::data_type dataType) {
70
68
switch (dataType) {
69
+ case memory::data_type::f64:
70
+ return InferenceEngine::Precision::FP64;
71
+ case memory::data_type::s64:
72
+ return InferenceEngine::Precision::I64;
71
73
case memory::data_type::f32:
72
74
return InferenceEngine::Precision::FP32;
73
75
case memory::data_type::s32:
@@ -90,11 +92,11 @@ InferenceEngine::Precision DnnlExtensionUtils::DataTypeToIEPrecision(memory::dat
90
92
}
91
93
}
92
94
93
- Dim DnnlExtensionUtils::convertToDim (const dnnl:: memory::dim &dim) {
95
+ Dim DnnlExtensionUtils::convertToDim (const memory::dim &dim) {
94
96
return dim == DNNL_RUNTIME_DIM_VAL ? Shape::UNDEFINED_DIM : static_cast <size_t >(dim);
95
97
}
96
- dnnl:: memory::dim DnnlExtensionUtils::convertToDnnlDim (const Dim &dim) {
97
- return dim == Shape::UNDEFINED_DIM ? DNNL_RUNTIME_DIM_VAL : static_cast <dnnl:: memory::dim>(dim);
98
+ memory::dim DnnlExtensionUtils::convertToDnnlDim (const Dim &dim) {
99
+ return dim == Shape::UNDEFINED_DIM ? DNNL_RUNTIME_DIM_VAL : static_cast <memory::dim>(dim);
98
100
}
99
101
100
102
VectorDims DnnlExtensionUtils::convertToVectorDims (const memory::dims& dims) {
@@ -133,19 +135,19 @@ memory::format_tag DnnlExtensionUtils::GetPlainFormatByRank(size_t rank) {
133
135
}
134
136
}
135
137
136
- DnnlMemoryDescPtr DnnlExtensionUtils::makeDescriptor (const dnnl:: memory::desc &desc) {
138
+ DnnlMemoryDescPtr DnnlExtensionUtils::makeDescriptor (const memory::desc &desc) {
137
139
return makeDescriptor (desc.get ());
138
140
}
139
141
140
142
DnnlMemoryDescPtr DnnlExtensionUtils::makeDescriptor (const_dnnl_memory_desc_t desc) {
141
- if (desc->format_kind == dnnl:: impl::format_kind_t ::dnnl_blocked) {
143
+ if (desc->format_kind == impl::format_kind_t ::dnnl_blocked) {
142
144
return std::shared_ptr<DnnlBlockedMemoryDesc>(new DnnlBlockedMemoryDesc (desc));
143
145
} else {
144
146
return std::shared_ptr<DnnlMemoryDesc>(new DnnlMemoryDesc (desc));
145
147
}
146
148
}
147
149
148
- size_t DnnlExtensionUtils::getMemSizeForDnnlDesc (const dnnl:: memory::desc& desc) {
150
+ size_t DnnlExtensionUtils::getMemSizeForDnnlDesc (const memory::desc& desc) {
149
151
auto tmpDesc = desc;
150
152
151
153
const auto offset0 = tmpDesc.get ()->offset0 ;
@@ -167,8 +169,8 @@ std::shared_ptr<DnnlBlockedMemoryDesc> DnnlExtensionUtils::makeUndefinedDesc(con
167
169
}
168
170
}
169
171
170
- DnnlMemoryDescPtr DnnlExtensionUtils::query_md (const const_dnnl_primitive_desc_t & pd, const dnnl:: query& what, int idx) {
171
- auto query = dnnl:: convert_to_c (what);
172
+ DnnlMemoryDescPtr DnnlExtensionUtils::query_md (const const_dnnl_primitive_desc_t & pd, const query& what, int idx) {
173
+ auto query = convert_to_c (what);
172
174
const auto * cdesc = dnnl_primitive_desc_query_md (pd, query, idx);
173
175
174
176
if (!cdesc)
0 commit comments