@@ -923,6 +923,7 @@ struct memory : public handle<dnnl_memory_t> {
923
923
/// Format kind for sparse tensors.
924
924
sparse = dnnl_format_kind_sparse,
925
925
#endif
926
+ sparsed = dnnl_format_sparse,
926
927
/// A special format kind that indicates that tensor format is opaque.
927
928
opaque = dnnl_format_kind_opaque,
928
929
};
@@ -2753,7 +2754,6 @@ struct memory : public handle<dnnl_memory_t> {
2753
2754
/// A memory descriptor.
2754
2755
struct desc : public handle<dnnl_memory_desc_t> {
2755
2756
using handle<dnnl_memory_desc_t>::handle;
2756
-
2757
2757
friend struct memory;
2758
2758
2759
2759
/// Constructs a zero (empty) memory descriptor. Such a memory
@@ -2944,6 +2944,31 @@ struct memory : public handle<dnnl_memory_t> {
2944
2944
reset(md);
2945
2945
}
2946
2946
2947
+ /// @fork
2948
+ /// Copy constructor for memory::desc
2949
+ /// Ensures deep copy (underlying C structure is copied as well)
2950
+ /// To preserve behavior of 2.x oneDNN versions
2951
+ ///
2952
+ /// @param desc memory descriptor to copy.
2953
+ desc(const memory::desc& adesc) {
2954
+ auto cdesc = adesc.get();
2955
+ dnnl_memory_desc_t cloned_md = nullptr;
2956
+ dnnl_memory_desc_clone(&cloned_md, cdesc);
2957
+
2958
+ reset(cloned_md);
2959
+ }
2960
+
2961
+ desc sparse_desc(const dims &adims, data_type adata_type,
2962
+ bool allow_empty = false) {
2963
+ dnnl_memory_desc_t md = nullptr;
2964
+ dnnl_status_t status = dnnl_memory_desc_create_sparse(&md, dnnl_sparse_encoding_packed,
2965
+ (int)adims.size(), adims.data(), convert_to_c(adata_type));
2966
+
2967
+ if (!allow_empty)
2968
+ error::wrap_c_api(status,
2969
+ "could not construct a memory descriptor with sparse format");
2970
+ return desc(md);
2971
+ }
2947
2972
/// Constructs a memory descriptor for a region inside an area
2948
2973
/// described by this memory descriptor.
2949
2974
//
@@ -3192,9 +3217,9 @@ struct memory : public handle<dnnl_memory_t> {
3192
3217
/// Returns the data type of the memory descriptor.
3193
3218
///
3194
3219
/// @returns The data type.
3195
- memory::data_type get_data_type() const {
3196
- return query_data_type(query::data_type);
3197
- }
3220
+ // memory::data_type get_data_type() const {
3221
+ // return query_data_type(query::data_type);
3222
+ // }
3198
3223
#endif
3199
3224
3200
3225
/// Returns the format kind of the memory descriptor.
@@ -3209,6 +3234,30 @@ struct memory : public handle<dnnl_memory_t> {
3209
3234
: dnnl::memory::format_kind::undef;
3210
3235
}
3211
3236
3237
+ /// Returns the format kind of the memory descriptor.
3238
+ ///
3239
+ /// @returns the format kind.
3240
+ dnnl_sparse_encoding_t get_sparse_encoding() const {
3241
+ dnnl_sparse_encoding_t sparse_encoding;
3242
+ dnnl_status_t status = dnnl_memory_desc_query(
3243
+ get(), dnnl_query_sparse_encoding, &sparse_encoding);
3244
+ return status == dnnl_success
3245
+ ? sparse_encoding
3246
+ : dnnl_sparse_encoding_undef;
3247
+ }
3248
+
3249
+ /// Returns the data type of the memory descriptor.
3250
+ ///
3251
+ /// @returns The data type.
3252
+ memory::data_type get_data_type() const {
3253
+ dnnl_data_type_t data_type;
3254
+ dnnl_status_t status = dnnl_memory_desc_query(
3255
+ get(), dnnl_query_data_type, &data_type);
3256
+ return status == dnnl_success
3257
+ ? static_cast<dnnl::memory::data_type>(data_type)
3258
+ : dnnl::memory::data_type::undef;
3259
+ }
3260
+
3212
3261
/// Returns dimensions of the memory descriptor.
3213
3262
///
3214
3263
/// Potentially expensive due to the data copy involved.
@@ -3386,6 +3435,44 @@ struct memory : public handle<dnnl_memory_t> {
3386
3435
reset(result);
3387
3436
}
3388
3437
#else
3438
+ /// Constructs a memory object.
3439
+ ///
3440
+ /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory
3441
+ /// object will have the underlying buffer set. In this case, the buffer
3442
+ /// will be initialized as if #dnnl::memory::set_data_handle() had been
3443
+ /// called.
3444
+ ///
3445
+ /// @sa memory::set_data_handle()
3446
+ ///
3447
+ /// @param md Memory descriptor.
3448
+ /// @param aengine Engine to store the data on.
3449
+ /// @param handle Handle of the memory buffer to use.
3450
+ /// - A pointer to the user-allocated buffer. In this case the library
3451
+ /// doesn't own the buffer.
3452
+ /// - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to
3453
+ /// allocate the buffer for the memory object. In this case the
3454
+ /// library owns the buffer.
3455
+ /// - #DNNL_MEMORY_NONE to create dnnl::memory without an underlying
3456
+ /// buffer.
3457
+ // memory(const desc &md, const engine &aengine, void *handle) {
3458
+ // dnnl_memory_t result;
3459
+ // error::wrap_c_api(
3460
+ // dnnl_memory_create(&result, md.get(), aengine.get(), handle),
3461
+ // "could not create a memory object");
3462
+ // reset(result);
3463
+ // }
3464
+
3465
+ /// Constructs a memory object.
3466
+ ///
3467
+ /// The underlying buffer(s) for the memory will be allocated by the
3468
+ /// library.
3469
+ ///
3470
+ /// @param md Memory descriptor.
3471
+ /// @param aengine Engine to store the data on.
3472
+ // memory(const desc &md, const engine &aengine)
3473
+ // : memory(md, aengine, DNNL_MEMORY_ALLOCATE) {}
3474
+ #endif
3475
+
3389
3476
/// Constructs a memory object.
3390
3477
///
3391
3478
/// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory
@@ -3413,15 +3500,11 @@ struct memory : public handle<dnnl_memory_t> {
3413
3500
reset(result);
3414
3501
}
3415
3502
3416
- /// Constructs a memory object.
3417
- ///
3418
3503
/// The underlying buffer for the memory will be allocated by the library.
3419
- ///
3420
3504
/// @param md Memory descriptor.
3421
3505
/// @param aengine Engine to store the data on.
3422
3506
memory(const desc &md, const engine &aengine)
3423
3507
: memory(md, aengine, DNNL_MEMORY_ALLOCATE) {}
3424
- #endif
3425
3508
3426
3509
/// Returns the associated memory descriptor.
3427
3510
desc get_desc() const {
0 commit comments