Skip to content

Commit 486fd32

Browse files
maxnickluweizhou2016
authored andcommitted
[FIX] Desc similar_to routine use stride cmp mask
1 parent 0d3afe5 commit 486fd32

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

src/common/memory_desc_wrapper.hpp

+20-9
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,8 @@ struct memory_desc_wrapper : public c_compatible {
422422
* following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
423423
/* TODO: revise */
424424
bool similar_to(const memory_desc_wrapper &rhs, bool with_padding = true,
425-
bool with_data_type = true, int dim_start = 0, int stride_start = -1,
426-
bool use_weak_cmp = false, bool check_off0 = false) const;
425+
bool with_data_type = true, int dim_start = 0, bool use_weak_cmp = false,
426+
bool check_off0 = false, uint64_t stride_mask = 0xffffffffffffffff) const;
427427

428428
/** returns true if one memory can be reordered to another */
429429
bool consistent_with(const memory_desc_wrapper &rhs) const;
@@ -589,28 +589,39 @@ struct memory_desc_wrapper : public c_compatible {
589589
};
590590

591591
inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
592-
bool with_padding, bool with_data_type, int dim_start, int stride_start, bool use_weak_cmp, bool check_off0) const {
592+
bool with_padding, bool with_data_type, int dim_start, bool use_weak_cmp, bool check_off0, uint64_t stride_mask) const {
593593
using namespace utils;
594594

595595
if (one_of(format_kind(), format_kind::undef, format_kind::any))
596596
return false;
597597
if (is_wino_desc() || is_rnn_packed_desc()) return false;
598598

599599
const int ds = dim_start;
600-
if (stride_start == -1) {
601-
stride_start = ds;
602-
} else if (stride_start > ndims()) {
603-
stride_start = ndims();
604-
}
605600
const auto &blk = blocking_desc();
606601
const auto &r_blk = rhs.blocking_desc();
607602

608603
auto custom_cpm = use_weak_cmp ? array_cmp_weak : array_cmp<dnnl_dim_t>;
604+
auto cmp_strides = [&]() {
605+
if (0xffffffffffffffff == stride_mask) {
606+
return custom_cpm(blk.strides + ds, r_blk.strides + ds, ndims() - ds);
607+
} else {
608+
for (int i = 0; i < ndims(); ++i) {
609+
if (stride_mask & (1 << i)) {
610+
if (blk.strides[i] != r_blk.strides[i]
611+
&& IMPLICATION(use_weak_cmp, (blk.strides[i] != DNNL_RUNTIME_DIM_VAL && r_blk.strides[i] != DNNL_RUNTIME_DIM_VAL))) {
612+
return false;
613+
}
614+
}
615+
}
616+
}
617+
return true;
618+
};
619+
609620
return ndims() == rhs.ndims() && dim_start <= ndims() /* guard */
610621
&& format_kind() == rhs.format_kind()
611622
&& IMPLICATION(with_data_type, data_type() == rhs.data_type())
612623
&& custom_cpm(dims() + ds, rhs.dims() + ds, ndims() - ds)
613-
&& custom_cpm(blk.strides + stride_start, r_blk.strides + stride_start, ndims() - stride_start)
624+
&& cmp_strides()
614625
&& blk.inner_nblks == r_blk.inner_nblks
615626
&& array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
616627
&& array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)

0 commit comments

Comments
 (0)