Skip to content

Commit 99dcc2d

Browse files
maxnickluweizhou2016
authored andcommitted
[FIX] Desc similar_to routine use stride cmp mask
1 parent 44e429e commit 99dcc2d

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

src/common/memory_desc_wrapper.hpp

+20-9
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,8 @@ struct memory_desc_wrapper : public c_compatible {
377377
* following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
378378
/* TODO: revise */
379379
bool similar_to(const memory_desc_wrapper &rhs, bool with_padding = true,
380-
bool with_data_type = true, int dim_start = 0, int stride_start = -1,
381-
bool use_weak_cmp = false, bool check_off0 = false) const;
380+
bool with_data_type = true, int dim_start = 0, bool use_weak_cmp = false,
381+
bool check_off0 = false, uint64_t stride_mask = 0xffffffffffffffff) const;
382382

383383
/** returns true if one memory can be reordered to another */
384384
bool consistent_with(const memory_desc_wrapper &rhs) const;
@@ -544,28 +544,39 @@ struct memory_desc_wrapper : public c_compatible {
544544
};
545545

546546
inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
547-
bool with_padding, bool with_data_type, int dim_start, int stride_start, bool use_weak_cmp, bool check_off0) const {
547+
bool with_padding, bool with_data_type, int dim_start, bool use_weak_cmp, bool check_off0, uint64_t stride_mask) const {
548548
using namespace utils;
549549

550550
if (one_of(format_kind(), format_kind::undef, format_kind::any))
551551
return false;
552552
if (is_wino_desc() || is_rnn_packed_desc()) return false;
553553

554554
const int ds = dim_start;
555-
if (stride_start == -1) {
556-
stride_start = ds;
557-
} else if (stride_start > ndims()) {
558-
stride_start = ndims();
559-
}
560555
const auto &blk = blocking_desc();
561556
const auto &r_blk = rhs.blocking_desc();
562557

563558
auto custom_cpm = use_weak_cmp ? array_cmp_weak : array_cmp<dnnl_dim_t>;
559+
auto cmp_strides = [&]() {
560+
if (0xffffffffffffffff == stride_mask) {
561+
return custom_cpm(blk.strides + ds, r_blk.strides + ds, ndims() - ds);
562+
} else {
563+
for (int i = 0; i < ndims(); ++i) {
564+
if (stride_mask & (1 << i)) {
565+
if (blk.strides[i] != r_blk.strides[i]
566+
&& IMPLICATION(use_weak_cmp, (blk.strides[i] != DNNL_RUNTIME_DIM_VAL && r_blk.strides[i] != DNNL_RUNTIME_DIM_VAL))) {
567+
return false;
568+
}
569+
}
570+
}
571+
}
572+
return true;
573+
};
574+
564575
return ndims() == rhs.ndims() && dim_start <= ndims() /* guard */
565576
&& format_kind() == rhs.format_kind()
566577
&& IMPLICATION(with_data_type, data_type() == rhs.data_type())
567578
&& custom_cpm(dims() + ds, rhs.dims() + ds, ndims() - ds)
568-
&& custom_cpm(blk.strides + stride_start, r_blk.strides + stride_start, ndims() - stride_start)
579+
&& cmp_strides()
569580
&& blk.inner_nblks == r_blk.inner_nblks
570581
&& array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
571582
&& array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)

0 commit comments

Comments
 (0)