Skip to content

Commit 57de44a

Browse files
maxnickluweizhou2016
authored and committed
[FIX] Desc similar_to routine use stride cmp mask
1 parent 7a238fc commit 57de44a

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

src/common/memory_desc_wrapper.hpp

+20-9
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,8 @@ struct memory_desc_wrapper : public c_compatible {
378378
* following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
379379
/* TODO: revise */
380380
bool similar_to(const memory_desc_wrapper &rhs, bool with_padding = true,
381-
bool with_data_type = true, int dim_start = 0, int stride_start = -1,
382-
bool use_weak_cmp = false, bool check_off0 = false) const;
381+
bool with_data_type = true, int dim_start = 0, bool use_weak_cmp = false,
382+
bool check_off0 = false, uint64_t stride_mask = 0xffffffffffffffff) const;
383383

384384
/** returns true if one memory can be reordered to another */
385385
bool consistent_with(const memory_desc_wrapper &rhs) const;
@@ -545,28 +545,39 @@ struct memory_desc_wrapper : public c_compatible {
545545
};
546546

547547
inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
548-
bool with_padding, bool with_data_type, int dim_start, int stride_start, bool use_weak_cmp, bool check_off0) const {
548+
bool with_padding, bool with_data_type, int dim_start, bool use_weak_cmp, bool check_off0, uint64_t stride_mask) const {
549549
using namespace utils;
550550

551551
if (one_of(format_kind(), format_kind::undef, format_kind::any))
552552
return false;
553553
if (is_wino_desc() || is_rnn_packed_desc()) return false;
554554

555555
const int ds = dim_start;
556-
if (stride_start == -1) {
557-
stride_start = ds;
558-
} else if (stride_start > ndims()) {
559-
stride_start = ndims();
560-
}
561556
const auto &blk = blocking_desc();
562557
const auto &r_blk = rhs.blocking_desc();
563558

564559
auto custom_cpm = use_weak_cmp ? array_cmp_weak : array_cmp<dnnl_dim_t>;
560+
auto cmp_strides = [&]() {
561+
if (0xffffffffffffffff == stride_mask) {
562+
return custom_cpm(blk.strides + ds, r_blk.strides + ds, ndims() - ds);
563+
} else {
564+
for (int i = 0; i < ndims(); ++i) {
565+
if (stride_mask & (1 << i)) {
566+
if (blk.strides[i] != r_blk.strides[i]
567+
&& IMPLICATION(use_weak_cmp, (blk.strides[i] != DNNL_RUNTIME_DIM_VAL && r_blk.strides[i] != DNNL_RUNTIME_DIM_VAL))) {
568+
return false;
569+
}
570+
}
571+
}
572+
}
573+
return true;
574+
};
575+
565576
return ndims() == rhs.ndims() && dim_start <= ndims() /* guard */
566577
&& format_kind() == rhs.format_kind()
567578
&& IMPLICATION(with_data_type, data_type() == rhs.data_type())
568579
&& custom_cpm(dims() + ds, rhs.dims() + ds, ndims() - ds)
569-
&& custom_cpm(blk.strides + stride_start, r_blk.strides + stride_start, ndims() - stride_start)
580+
&& cmp_strides()
570581
&& blk.inner_nblks == r_blk.inner_nblks
571582
&& array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
572583
&& array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)

0 commit comments

Comments
 (0)