Skip to content

Commit 0306ffd

Browse files
rjourslerkarturov
authored andcommitted
common: fix undefined use of memory descriptor strides
1 parent 05303ea commit 0306ffd

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

src/common/gemm_types.hpp

+10-4
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ struct gemm_desc_t {
7373
// Simplified accessors that comply to GEMM API
7474
static transpose_t get_trans(const memory_desc_t &md) {
7575
if (!md.ndims) return transpose::notrans; // arbitrary
76-
return md.format_desc.blocking.strides[md.ndims - 1] != 1
76+
return md.dims[md.ndims - 1] != 1
77+
&& md.format_desc.blocking.strides[md.ndims - 1] != 1
7778
? transpose::trans
7879
: transpose::notrans;
7980
}
@@ -116,9 +117,14 @@ struct gemm_desc_t {
116117
// This assumes that one of the dimensions has strides 1
117118
static dnnl_dim_t get_ld(const memory_desc_t &md) {
118119
auto strides = md.format_desc.blocking.strides;
119-
assert(strides[md.ndims - 1] == 1 || strides[md.ndims - 2] == 1);
120-
return strides[md.ndims - 1] != 1 ? strides[md.ndims - 1]
121-
: strides[md.ndims - 2];
120+
assert(md.dims[md.ndims - 1] == 1 || strides[md.ndims - 1] == 1
121+
|| md.dims[md.ndims - 2] == 1 || strides[md.ndims - 2] == 1);
122+
switch (get_trans(md)) {
123+
case transpose::trans:
124+
return md.dims[md.ndims - 1] > 1 ? strides[md.ndims - 1] : 1;
125+
default:
126+
return md.dims[md.ndims - 2] > 1 ? strides[md.ndims - 2] : 1;
127+
}
122128
}
123129
// Leading dimension of A.
124130
dnnl_dim_t lda() const { return get_ld(b_desc); }

src/common/gemm_utils.hpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2019-2023 Intel Corporation
2+
* Copyright 2019-2024 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -156,8 +156,11 @@ static inline bool is_md_gemm_compatible_plain_format(
156156

157157
if (blk_desc.inner_nblks != 0) return false;
158158

159-
return (blk_desc.strides[md->ndims - 1] == 1)
160-
|| (!is_dst && blk_desc.strides[md->ndims - 2] == 1);
159+
return (md->dims[md->ndims - 1] == 1
160+
|| blk_desc.strides[md->ndims - 1] == 1)
161+
|| (!is_dst
162+
&& (md->dims[md->ndims - 2] == 1
163+
|| blk_desc.strides[md->ndims - 2] == 1));
161164
}
162165

163166
} // namespace impl

0 commit comments

Comments
 (0)