Skip to content

Commit fae63f0

Browse files
jianan-guazhai219
authored and committed Dec 3, 2024
[FORK] [FEATURE] cpu: add inner product with sparse packed weights
3.5 squash list: [FORK][FIX] separate _blk_off() and _blk_off_sparse() to fix perf issues; [FORK][FEATURE] cpu: add inner product with sparse packed weights; [FORK][FIX] changed comp_tile_len data type from int16_t to int; [FORK][FEATURE] cpu: add inner product with sparse packed weights
1 parent 96b4ba7 commit fae63f0

30 files changed

+782
-217
lines changed
 

‎include/oneapi/dnnl/dnnl.h

+13-3
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,19 @@ dnnl_status_t DNNL_API dnnl_memory_desc_create_with_packed_encoding(
976976
dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
977977
dnnl_data_type_t data_type, dnnl_dim_t nnz);
978978
#endif
979+
/// Initializes a sparse descriptor.
980+
///
981+
/// @param memory_desc Output memory descriptor.
982+
/// @param encoding Encoding.
983+
/// @param ndims Number of dimensions.
984+
/// @param dims Array of dimensions.
985+
/// @param data_type Elements data type.
986+
/// @returns #dnnl_success on success and a status describing the error
987+
/// otherwise.
988+
dnnl_status_t DNNL_API dnnl_memory_desc_create_sparse(
989+
dnnl_memory_desc_t *memory_desc,
990+
dnnl_sparse_encoding_t encoding, int ndims,
991+
const dnnl_dims_t dims, dnnl_data_type_t data_type);
979992

980993
/// Creates a memory descriptor for a region inside an area
981994
/// described by an existing memory descriptor.
@@ -1230,7 +1243,6 @@ size_t DNNL_API dnnl_memory_desc_get_size(const_dnnl_memory_desc_t memory_desc);
12301243
size_t DNNL_API dnnl_memory_desc_get_size_v2(
12311244
const_dnnl_memory_desc_t memory_desc, int index);
12321245
#endif
1233-
12341246
/// Returns the size of data type.
12351247
///
12361248
/// @param data_type Data type.
@@ -1283,7 +1295,6 @@ dnnl_status_t DNNL_API dnnl_memory_create_v2(dnnl_memory_t *memory,
12831295
const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
12841296
int nhandles, void **handles);
12851297
#endif
1286-
12871298
/// Returns the memory descriptor for a memory object.
12881299
///
12891300
/// @param memory Memory object.
@@ -1395,7 +1406,6 @@ dnnl_status_t DNNL_API dnnl_memory_unmap_data(
13951406
dnnl_status_t DNNL_API dnnl_memory_unmap_data_v2(
13961407
const_dnnl_memory_t memory, void *mapped_ptr, int index);
13971408
#endif
1398-
13991409
/// Returns memory object's data handle.
14001410
///
14011411
/// @param memory Memory object.

‎include/oneapi/dnnl/dnnl.hpp

+91-8
Original file line numberDiff line numberDiff line change
@@ -923,6 +923,7 @@ struct memory : public handle<dnnl_memory_t> {
923923
/// Format kind for sparse tensors.
924924
sparse = dnnl_format_kind_sparse,
925925
#endif
926+
sparsed = dnnl_format_sparse,
926927
/// A special format kind that indicates that tensor format is opaque.
927928
opaque = dnnl_format_kind_opaque,
928929
};
@@ -2753,7 +2754,6 @@ struct memory : public handle<dnnl_memory_t> {
27532754
/// A memory descriptor.
27542755
struct desc : public handle<dnnl_memory_desc_t> {
27552756
using handle<dnnl_memory_desc_t>::handle;
2756-
27572757
friend struct memory;
27582758

27592759
/// Constructs a zero (empty) memory descriptor. Such a memory
@@ -2944,6 +2944,31 @@ struct memory : public handle<dnnl_memory_t> {
29442944
reset(md);
29452945
}
29462946

2947+
/// @fork
2948+
/// Copy constructor for memory::desc
2949+
/// Ensures deep copy (underlying C structure is copied as well)
2950+
/// To preserve behavior of 2.x oneDNN versions
2951+
///
2952+
/// @param desc memory descriptor to copy.
2953+
desc(const memory::desc& adesc) {
2954+
auto cdesc = adesc.get();
2955+
dnnl_memory_desc_t cloned_md = nullptr;
2956+
dnnl_memory_desc_clone(&cloned_md, cdesc);
2957+
2958+
reset(cloned_md);
2959+
}
2960+
2961+
/// @fork
/// Constructs a memory descriptor that uses the packed sparse encoding.
///
/// The result is built only from the arguments; no state of this
/// descriptor is read, so the method is callable on const objects.
///
/// @param adims Tensor dimensions.
/// @param adata_type Data type of the tensor elements.
/// @param allow_empty When true, a failing status from
///     dnnl_memory_desc_create_sparse() is not turned into an exception
///     and the returned object wraps whatever handle the call left
///     behind (the handle is initialized to null before the call).
/// @returns A memory descriptor wrapping the created C handle.
desc sparse_desc(const dims &adims, data_type adata_type,
        bool allow_empty = false) const {
    dnnl_memory_desc_t md = nullptr;
    const dnnl_status_t status = dnnl_memory_desc_create_sparse(&md,
            dnnl_sparse_encoding_packed, static_cast<int>(adims.size()),
            adims.data(), convert_to_c(adata_type));

    if (!allow_empty)
        error::wrap_c_api(status,
                "could not construct a memory descriptor with sparse format");
    return desc(md);
}
29472972
/// Constructs a memory descriptor for a region inside an area
29482973
/// described by this memory descriptor.
29492974
//
@@ -3192,9 +3217,9 @@ struct memory : public handle<dnnl_memory_t> {
31923217
/// Returns the data type of the memory descriptor.
31933218
///
31943219
/// @returns The data type.
3195-
memory::data_type get_data_type() const {
3196-
return query_data_type(query::data_type);
3197-
}
3220+
// memory::data_type get_data_type() const {
3221+
// return query_data_type(query::data_type);
3222+
// }
31983223
#endif
31993224

32003225
/// Returns the format kind of the memory descriptor.
@@ -3209,6 +3234,30 @@ struct memory : public handle<dnnl_memory_t> {
32093234
: dnnl::memory::format_kind::undef;
32103235
}
32113236

3237+
/// Returns the sparse encoding of the memory descriptor.
3238+
///
3239+
/// @returns The sparse encoding, or #dnnl_sparse_encoding_undef if the query fails.
3240+
dnnl_sparse_encoding_t get_sparse_encoding() const {
3241+
dnnl_sparse_encoding_t sparse_encoding;
3242+
dnnl_status_t status = dnnl_memory_desc_query(
3243+
get(), dnnl_query_sparse_encoding, &sparse_encoding);
3244+
return status == dnnl_success
3245+
? sparse_encoding
3246+
: dnnl_sparse_encoding_undef;
3247+
}
3248+
3249+
/// Returns the data type of the memory descriptor.
3250+
///
3251+
/// @returns The data type.
3252+
memory::data_type get_data_type() const {
3253+
dnnl_data_type_t data_type;
3254+
dnnl_status_t status = dnnl_memory_desc_query(
3255+
get(), dnnl_query_data_type, &data_type);
3256+
return status == dnnl_success
3257+
? static_cast<dnnl::memory::data_type>(data_type)
3258+
: dnnl::memory::data_type::undef;
3259+
}
3260+
32123261
/// Returns dimensions of the memory descriptor.
32133262
///
32143263
/// Potentially expensive due to the data copy involved.
@@ -3386,6 +3435,44 @@ struct memory : public handle<dnnl_memory_t> {
33863435
reset(result);
33873436
}
33883437
#else
3438+
/// Constructs a memory object.
3439+
///
3440+
/// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory
3441+
/// object will have the underlying buffer set. In this case, the buffer
3442+
/// will be initialized as if #dnnl::memory::set_data_handle() had been
3443+
/// called.
3444+
///
3445+
/// @sa memory::set_data_handle()
3446+
///
3447+
/// @param md Memory descriptor.
3448+
/// @param aengine Engine to store the data on.
3449+
/// @param handle Handle of the memory buffer to use.
3450+
/// - A pointer to the user-allocated buffer. In this case the library
3451+
/// doesn't own the buffer.
3452+
/// - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to
3453+
/// allocate the buffer for the memory object. In this case the
3454+
/// library owns the buffer.
3455+
/// - #DNNL_MEMORY_NONE to create dnnl::memory without an underlying
3456+
/// buffer.
3457+
// memory(const desc &md, const engine &aengine, void *handle) {
3458+
// dnnl_memory_t result;
3459+
// error::wrap_c_api(
3460+
// dnnl_memory_create(&result, md.get(), aengine.get(), handle),
3461+
// "could not create a memory object");
3462+
// reset(result);
3463+
// }
3464+
3465+
/// Constructs a memory object.
3466+
///
3467+
/// The underlying buffer(s) for the memory will be allocated by the
3468+
/// library.
3469+
///
3470+
/// @param md Memory descriptor.
3471+
/// @param aengine Engine to store the data on.
3472+
// memory(const desc &md, const engine &aengine)
3473+
// : memory(md, aengine, DNNL_MEMORY_ALLOCATE) {}
3474+
#endif
3475+
33893476
/// Constructs a memory object.
33903477
///
33913478
/// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory
@@ -3413,15 +3500,11 @@ struct memory : public handle<dnnl_memory_t> {
34133500
reset(result);
34143501
}
34153502

3416-
/// Constructs a memory object.
3417-
///
34183503
/// The underlying buffer for the memory will be allocated by the library.
3419-
///
34203504
/// @param md Memory descriptor.
34213505
/// @param aengine Engine to store the data on.
34223506
memory(const desc &md, const engine &aengine)
34233507
: memory(md, aengine, DNNL_MEMORY_ALLOCATE) {}
3424-
#endif
34253508

34263509
/// Returns the associated memory descriptor.
34273510
desc get_desc() const {

‎include/oneapi/dnnl/dnnl_debug.h

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ const char DNNL_API *dnnl_fmt_tag2str(dnnl_format_tag_t v);
4444
const char DNNL_API *dnnl_prop_kind2str(dnnl_prop_kind_t v);
4545
const char DNNL_API *dnnl_prim_kind2str(dnnl_primitive_kind_t v);
4646
const char DNNL_API *dnnl_alg_kind2str(dnnl_alg_kind_t v);
47+
const char DNNL_API *dnnl_sparse_encoding2str(dnnl_sparse_encoding_t v);
4748
const char DNNL_API *dnnl_rnn_flags2str(dnnl_rnn_flags_t v);
4849
const char DNNL_API *dnnl_rnn_direction2str(dnnl_rnn_direction_t v);
4950
const char DNNL_API *dnnl_scratchpad_mode2str(dnnl_scratchpad_mode_t v);

‎include/oneapi/dnnl/dnnl_types.h

+50
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ typedef enum {
5555
/// Format kind for sparse tensors.
5656
dnnl_format_kind_sparse,
5757
#endif
58+
/// Format for sparse data.
59+
dnnl_format_sparse,
5860
/// Parameter to allow internal only format kinds without undefined
5961
/// behavior. This parameter is chosen to be valid for so long as
6062
/// sizeof(int) >= 2.
@@ -2322,6 +2324,52 @@ typedef struct dnnl_memory_desc *dnnl_memory_desc_t;
23222324
/// A memory descriptor handle.
23232325
typedef const struct dnnl_memory_desc *const_dnnl_memory_desc_t;
23242326

2327+
/// Sparse encodings.
2328+
typedef enum {
2329+
dnnl_sparse_encoding_undef = 0,
2330+
dnnl_sparse_encoding_any,
2331+
dnnl_sparse_encoding_packed,
2332+
dnnl_sparse_encoding_csr,
2333+
dnnl_sparse_encoding_coo,
2334+
} dnnl_sparse_encoding_t;
2335+
2336+
/* typedef struct dnnl_sparse_desc *dnnl_sparse_desc_t; */
2337+
/* typedef const struct dnnl_sparse_desc *const_dnnl_sparse_desc_t; */
2338+
2339+
/// Flags for memory special features
2340+
typedef enum {
    /// No flags set.
    dnnl_memory_extra_flag_none = 0x0U,
    /// Indicates the weights have an additional buffer, that depends on the
    /// @p compensation_mask.
    ///
    /// For instance, in 4D case with the compensation mask equals (1 << 0)
    /// the additional buffer would consist of OC values:
    /// O[oc : 0,OC] =
    ///  -128 * SUM(ic : 0,IC; kh : 0,KH; kw : 0,KW){ weights(oc, ic, kh, kw) }
    dnnl_memory_extra_flag_compensation_conv_s8s8 = 0x1U,
    dnnl_memory_extra_flag_scale_adjust = 0x2U,
    dnnl_memory_extra_flag_rnn_u8s8_compensation = 0x4U,
    /// Alias of #dnnl_memory_extra_flag_rnn_u8s8_compensation.
    dnnl_memory_extra_flag_gpu_rnn_u8s8_compensation
    = dnnl_memory_extra_flag_rnn_u8s8_compensation,
    dnnl_memory_extra_flag_compensation_conv_asymmetric_src = 0x8U,
    // The other flags are single-bit masks (0x1, 0x2, 0x4, 0x8); this one
    // must be the next free bit. The previous value 0x16 (0b10110) also
    // contained the scale_adjust and rnn_u8s8_compensation bits, so a
    // bitwise test for either of those succeeded whenever this flag was set.
    dnnl_memory_extra_flag_rnn_s8s8_compensation = 0x10U,
} dnnl_memory_extra_flags_t;
2357+
2358+
/// Description of extra information stored in memory
2359+
typedef struct {
2360+
/// The flags contain arbitrary extra information, such as compensation.
2361+
/// @sa dnnl_memory_extra_flags_t
2362+
uint64_t flags;
2363+
/// Compensation mask
2364+
int compensation_mask;
2365+
/// Scale applied to the data
2366+
float scale_adjust;
2367+
/// Compensation mask for asymmetric quantization
2368+
int asymm_compensation_mask;
2369+
/// For future backwards compatibility
2370+
char reserved[60];
2371+
} dnnl_memory_extra_desc_t;
2372+
23252373
/// @struct dnnl_memory
23262374
/// An opaque structure to describe a memory.
23272375
struct dnnl_memory;
@@ -2836,6 +2884,8 @@ typedef enum {
28362884
dnnl_query_num_handles_s32, ///< Number of buffers required for a memory
28372885
/// descriptor
28382886
#endif
2887+
dnnl_query_sparse_encoding,
2888+
28392889
// Max value to prevent UB for internal use only dnnl_query_t
28402890
dnnl_query_max = 0x7fff,
28412891
} dnnl_query_t;

‎scripts/generate_dnnl_debug.py

+4
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def header_benchdnn(body):
129129
#ifdef DNNL_EXPERIMENTAL_SPARSE
130130
const char *sparse_encoding2str(dnnl_sparse_encoding_t encoding);
131131
#endif
132+
const char *sparse_encoding2str(dnnl_sparse_encoding_t encoding);
132133
133134
/* engine kind */
134135
const char *engine_kind2str(dnnl_engine_kind_t kind);
@@ -183,6 +184,9 @@ def source_benchdnn(body):
183184
return dnnl_sparse_encoding2str(encoding);
184185
}
185186
#endif
187+
const char *sparse_encoding2str(dnnl_sparse_encoding_t encoding) {
188+
return dnnl_sparse_encoding2str(encoding);
189+
}
186190
187191
const char *engine_kind2str(dnnl_engine_kind_t kind) {
188192
return dnnl_engine_kind2str(kind);

‎src/common/c_types_map.hpp

+22-11
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,13 @@ const sparse_encoding_t packed = dnnl_packed;
217217
} // namespace sparse_encoding
218218
#else
219219
// Declare dummy values to avoid guarding internal implementation.
220-
using sparse_encoding_t = int;
221-
namespace sparse_encoding {
222-
const sparse_encoding_t undef = 0;
223-
const sparse_encoding_t csr = 1;
224-
const sparse_encoding_t packed = 2;
225-
const sparse_encoding_t coo = 3;
226-
} // namespace sparse_encoding
220+
// using sparse_encoding_t = int;
221+
// namespace sparse_encoding {
222+
// const sparse_encoding_t undef = 0;
223+
// const sparse_encoding_t csr = 1;
224+
// const sparse_encoding_t packed = 2;
225+
// const sparse_encoding_t coo = 3;
226+
// } // namespace sparse_encoding
227227
#endif
228228

229229
using format_kind_t = dnnl_format_kind_t;
@@ -235,14 +235,15 @@ const format_kind_t opaque = dnnl_format_kind_opaque;
235235
#ifdef DNNL_EXPERIMENTAL_SPARSE
236236
const format_kind_t sparse = dnnl_format_kind_sparse;
237237
#else
238-
const format_kind_t sparse = static_cast<format_kind_t>(4);
238+
// const format_kind_t sparse = static_cast<format_kind_t>(4);
239239
#endif
240240

241241
// Internal only format kinds.
242242
const format_kind_t internal_only_start = (format_kind_t)(1 << 8);
243243
const format_kind_t wino = internal_only_start;
244244
const format_kind_t rnn_packed = (format_kind_t)(internal_only_start + 1);
245245
const format_kind_t cublaslt_blocked = (format_kind_t)(internal_only_start + 2);
246+
const format_kind_t sparse = dnnl_format_sparse;
246247
} // namespace format_kind
247248

248249
#ifdef DNNL_EXPERIMENTAL_PROFILING
@@ -1915,6 +1916,15 @@ const rnn_flags_t diff_weights_overwrite
19151916
= dnnl_rnn_flags_diff_weights_overwrite;
19161917
} // namespace rnn_flags
19171918

1919+
using sparse_encoding_t = dnnl_sparse_encoding_t;
1920+
namespace sparse_encoding {
1921+
const sparse_encoding_t undef = dnnl_sparse_encoding_undef;
1922+
const sparse_encoding_t any = dnnl_sparse_encoding_any;
1923+
const sparse_encoding_t packed = dnnl_sparse_encoding_packed;
1924+
const sparse_encoding_t csr = dnnl_sparse_encoding_csr;
1925+
const sparse_encoding_t coo = dnnl_sparse_encoding_coo;
1926+
} // namespace sparse_encoding
1927+
19181928
using engine_kind_t = dnnl_engine_kind_t;
19191929
namespace engine_kind {
19201930
const engine_kind_t any_engine = dnnl_any_engine;
@@ -2051,15 +2061,16 @@ const query_t sparse_encoding = dnnl_query_sparse_encoding;
20512061
const query_t nnz_s64 = dnnl_query_nnz_s64;
20522062
const query_t num_handles_s32 = dnnl_query_num_handles_s32;
20532063
#else
2054-
const query_t sparse_encoding = static_cast<query_t>(266);
2055-
const query_t nnz_s64 = static_cast<query_t>(267);
2056-
const query_t num_handles_s32 = static_cast<query_t>(268);
2064+
// const query_t sparse_encoding = static_cast<query_t>(266);
2065+
// const query_t nnz_s64 = static_cast<query_t>(267);
2066+
// const query_t num_handles_s32 = static_cast<query_t>(268);
20572067
#endif
20582068

20592069
// Internal only query kinds.
20602070
const query_t internal_only_start = (query_t)(1 << 12);
20612071
const query_t zero_pad_d = internal_only_start;
20622072
const query_t preferred_gpu_threads_per_eu = (query_t)(internal_only_start + 1);
2073+
const query_t sparse_encoding = dnnl_query_sparse_encoding;
20632074
} // namespace query
20642075

20652076
using rnn_direction_t = dnnl_rnn_direction_t;

‎src/common/dnnl_debug.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ const char *dnnl_fmt_kind2str(dnnl_format_kind_t v) {
5454
|| v == format_kind::cublaslt_blocked)
5555
return "opaque";
5656
if (v == dnnl_format_kind_max) return "max";
57+
if (v == dnnl_format_sparse) return "format_sparse";
5758
assert(!"unknown fmt_kind");
5859
return "unknown fmt_kind";
5960
}

‎src/common/dnnl_debug_autogenerated.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -1854,6 +1854,14 @@ const char *dnnl_alg_kind2str(dnnl_alg_kind_t v) {
18541854
return "unknown alg_kind";
18551855
}
18561856

1857+
// Maps a sparse encoding value to its debug name. Every enumerator of
// dnnl_sparse_encoding_t is covered; any other value trips the assert in
// debug builds and yields "unknown sparse_encoding" otherwise.
const char *dnnl_sparse_encoding2str(dnnl_sparse_encoding_t v) {
    if (v == dnnl_sparse_encoding_undef) return "undef";
    if (v == dnnl_sparse_encoding_any) return "any";
    if (v == dnnl_sparse_encoding_packed) return "sparse_encoding_packed";
    if (v == dnnl_sparse_encoding_csr) return "sparse_encoding_csr";
    if (v == dnnl_sparse_encoding_coo) return "sparse_encoding_coo";
    assert(!"unknown sparse_encoding");
    return "unknown sparse_encoding";
}
1864+
18571865
const char *dnnl_rnn_flags2str(dnnl_rnn_flags_t v) {
18581866
if (v == dnnl_rnn_flags_undef) return "undef";
18591867
if (v == dnnl_rnn_flags_diff_weights_overwrite) return "rnn_flags_diff_weights_overwrite";

0 commit comments

Comments
 (0)