File tree 2 files changed +16
-7
lines changed
2 files changed +16
-7
lines changed Original file line number Diff line number Diff line change @@ -73,7 +73,8 @@ struct gemm_desc_t {
73
73
// Simplified accessors that comply to GEMM API
74
74
static transpose_t get_trans (const memory_desc_t &md) {
75
75
if (!md.ndims ) return transpose::notrans; // arbitrary
76
- return md.format_desc .blocking .strides [md.ndims - 1 ] != 1
76
+ return md.dims [md.ndims - 1 ] != 1
77
+ && md.format_desc .blocking .strides [md.ndims - 1 ] != 1
77
78
? transpose::trans
78
79
: transpose::notrans;
79
80
}
@@ -116,9 +117,14 @@ struct gemm_desc_t {
116
117
// This assumes that one of the dimensions has strides 1
117
118
static dnnl_dim_t get_ld (const memory_desc_t &md) {
118
119
auto strides = md.format_desc .blocking .strides ;
119
- assert (strides[md.ndims - 1 ] == 1 || strides[md.ndims - 2 ] == 1 );
120
- return strides[md.ndims - 1 ] != 1 ? strides[md.ndims - 1 ]
121
- : strides[md.ndims - 2 ];
120
+ assert (md.dims [md.ndims - 1 ] == 1 || strides[md.ndims - 1 ] == 1
121
+ || md.dims [md.ndims - 2 ] == 1 || strides[md.ndims - 2 ] == 1 );
122
+ switch (get_trans (md)) {
123
+ case transpose::trans:
124
+ return md.dims [md.ndims - 1 ] > 1 ? strides[md.ndims - 1 ] : 1 ;
125
+ default :
126
+ return md.dims [md.ndims - 2 ] > 1 ? strides[md.ndims - 2 ] : 1 ;
127
+ }
122
128
}
123
129
// Leading dimension of A.
124
130
dnnl_dim_t lda () const { return get_ld (b_desc); }
Original file line number Diff line number Diff line change 1
1
/* ******************************************************************************
2
- * Copyright 2019-2023 Intel Corporation
2
+ * Copyright 2019-2024 Intel Corporation
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
@@ -156,8 +156,11 @@ static inline bool is_md_gemm_compatible_plain_format(
156
156
157
157
if (blk_desc.inner_nblks != 0 ) return false ;
158
158
159
- return (blk_desc.strides [md->ndims - 1 ] == 1 )
160
- || (!is_dst && blk_desc.strides [md->ndims - 2 ] == 1 );
159
+ return (md->dims [md->ndims - 1 ] == 1
160
+ || blk_desc.strides [md->ndims - 1 ] == 1 )
161
+ || (!is_dst
162
+ && (md->dims [md->ndims - 2 ] == 1
163
+ || blk_desc.strides [md->ndims - 2 ] == 1 ));
161
164
}
162
165
163
166
} // namespace impl
You can’t perform that action at this time.
0 commit comments