Skip to content

Commit 8278c83

Browse files
maxnickluweizhou2016
authored andcommitted
[FIX] Desc similar_to routine use stride cmp mask
1 parent 19ce5f6 commit 8278c83

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

src/common/memory_desc_wrapper.hpp

+20-9
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,8 @@ struct memory_desc_wrapper : public c_compatible {
422422
* following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
423423
/* TODO: revise */
424424
bool similar_to(const memory_desc_wrapper &rhs, bool with_padding = true,
425-
bool with_data_type = true, int dim_start = 0, int stride_start = -1,
426-
bool use_weak_cmp = false, bool check_off0 = false) const;
425+
bool with_data_type = true, int dim_start = 0, bool use_weak_cmp = false,
426+
bool check_off0 = false, uint64_t stride_mask = 0xffffffffffffffff) const;
427427

428428
/** returns true if one memory can be reordered to another */
429429
bool consistent_with(const memory_desc_wrapper &rhs) const;
@@ -594,28 +594,39 @@ struct memory_desc_wrapper : public c_compatible {
594594
};
595595

596596
inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
597-
bool with_padding, bool with_data_type, int dim_start, int stride_start, bool use_weak_cmp, bool check_off0) const {
597+
bool with_padding, bool with_data_type, int dim_start, bool use_weak_cmp, bool check_off0, uint64_t stride_mask) const {
598598
using namespace utils;
599599

600600
if (one_of(format_kind(), format_kind::undef, format_kind::any))
601601
return false;
602602
if (is_wino_desc() || is_rnn_packed_desc()) return false;
603603

604604
const int ds = dim_start;
605-
if (stride_start == -1) {
606-
stride_start = ds;
607-
} else if (stride_start > ndims()) {
608-
stride_start = ndims();
609-
}
610605
const auto &blk = blocking_desc();
611606
const auto &r_blk = rhs.blocking_desc();
612607

613608
auto custom_cpm = use_weak_cmp ? array_cmp_weak : array_cmp<dnnl_dim_t>;
609+
auto cmp_strides = [&]() {
610+
if (0xffffffffffffffff == stride_mask) {
611+
return custom_cpm(blk.strides + ds, r_blk.strides + ds, ndims() - ds);
612+
} else {
613+
for (int i = 0; i < ndims(); ++i) {
614+
if (stride_mask & (1 << i)) {
615+
if (blk.strides[i] != r_blk.strides[i]
616+
&& IMPLICATION(use_weak_cmp, (blk.strides[i] != DNNL_RUNTIME_DIM_VAL && r_blk.strides[i] != DNNL_RUNTIME_DIM_VAL))) {
617+
return false;
618+
}
619+
}
620+
}
621+
}
622+
return true;
623+
};
624+
614625
return ndims() == rhs.ndims() && dim_start <= ndims() /* guard */
615626
&& format_kind() == rhs.format_kind()
616627
&& IMPLICATION(with_data_type, data_type() == rhs.data_type())
617628
&& custom_cpm(dims() + ds, rhs.dims() + ds, ndims() - ds)
618-
&& custom_cpm(blk.strides + stride_start, r_blk.strides + stride_start, ndims() - stride_start)
629+
&& cmp_strides()
619630
&& blk.inner_nblks == r_blk.inner_nblks
620631
&& array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
621632
&& array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)

0 commit comments

Comments
 (0)