Skip to content

Commit 0b3b333

Browse files
maxnickazhai219
authored andcommitted
[FIX] Desc similar_to routine use stride cmp mask
1 parent 8f69351 commit 0b3b333

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

src/common/memory_desc_wrapper.hpp

+20-9
Original file line numberDiff line numberDiff line change
@@ -444,8 +444,8 @@ struct memory_desc_wrapper : public c_compatible {
444444
* following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
445445
/* TODO: revise */
446446
bool similar_to(const memory_desc_wrapper &rhs, bool with_padding = true,
447-
bool with_data_type = true, int dim_start = 0, int stride_start = -1,
448-
bool use_weak_cmp = false, bool check_off0 = false) const;
447+
bool with_data_type = true, int dim_start = 0, bool use_weak_cmp = false,
448+
bool check_off0 = false, uint64_t stride_mask = 0xffffffffffffffff) const;
449449

450450
/** returns true if one memory can be reordered to another */
451451
bool consistent_with(const memory_desc_wrapper &rhs) const;
@@ -617,7 +617,7 @@ struct memory_desc_wrapper : public c_compatible {
617617
};
618618

619619
inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
620-
bool with_padding, bool with_data_type, int dim_start, int stride_start, bool use_weak_cmp, bool check_off0) const {
620+
bool with_padding, bool with_data_type, int dim_start, bool use_weak_cmp, bool check_off0, uint64_t stride_mask) const {
621621
using namespace utils;
622622

623623
if (one_of(format_kind(), format_kind::undef, format_kind::any))
@@ -626,20 +626,31 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
626626
return false;
627627

628628
const int ds = dim_start;
629-
if (stride_start == -1) {
630-
stride_start = ds;
631-
} else if (stride_start > ndims()) {
632-
stride_start = ndims();
633-
}
634629
const auto &blk = blocking_desc();
635630
const auto &r_blk = rhs.blocking_desc();
636631

637632
auto custom_cpm = use_weak_cmp ? array_cmp_weak : array_cmp<dnnl_dim_t>;
633+
auto cmp_strides = [&]() {
634+
if (0xffffffffffffffff == stride_mask) {
635+
return custom_cpm(blk.strides + ds, r_blk.strides + ds, ndims() - ds);
636+
} else {
637+
for (int i = 0; i < ndims(); ++i) {
638+
if (stride_mask & (1 << i)) {
639+
if (blk.strides[i] != r_blk.strides[i]
640+
&& IMPLICATION(use_weak_cmp, (blk.strides[i] != DNNL_RUNTIME_DIM_VAL && r_blk.strides[i] != DNNL_RUNTIME_DIM_VAL))) {
641+
return false;
642+
}
643+
}
644+
}
645+
}
646+
return true;
647+
};
648+
638649
return ndims() == rhs.ndims() && dim_start <= ndims() /* guard */
639650
&& format_kind() == rhs.format_kind()
640651
&& IMPLICATION(with_data_type, data_type() == rhs.data_type())
641652
&& custom_cpm(dims() + ds, rhs.dims() + ds, ndims() - ds)
642-
&& custom_cpm(blk.strides + stride_start, r_blk.strides + stride_start, ndims() - stride_start)
653+
&& cmp_strides()
643654
&& blk.inner_nblks == r_blk.inner_nblks
644655
&& array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
645656
&& array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)

0 commit comments

Comments
 (0)