@@ -444,8 +444,8 @@ struct memory_desc_wrapper : public c_compatible {
444
444
* following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
445
445
/* TODO: revise */
446
446
bool similar_to (const memory_desc_wrapper &rhs, bool with_padding = true ,
447
- bool with_data_type = true , int dim_start = 0 , int stride_start = - 1 ,
448
- bool use_weak_cmp = false , bool check_off0 = false ) const ;
447
+ bool with_data_type = true , int dim_start = 0 , bool use_weak_cmp = false ,
448
+ bool check_off0 = false , uint64_t stride_mask = 0xffffffffffffffff ) const ;
449
449
450
450
/* * returns true if one memory can be reordered to another */
451
451
bool consistent_with (const memory_desc_wrapper &rhs) const ;
@@ -617,7 +617,7 @@ struct memory_desc_wrapper : public c_compatible {
617
617
};
618
618
619
619
inline bool memory_desc_wrapper::similar_to (const memory_desc_wrapper &rhs,
620
- bool with_padding, bool with_data_type, int dim_start, int stride_start, bool use_weak_cmp, bool check_off0) const {
620
+ bool with_padding, bool with_data_type, int dim_start, bool use_weak_cmp, bool check_off0, uint64_t stride_mask ) const {
621
621
using namespace utils ;
622
622
623
623
if (one_of (format_kind (), format_kind::undef, format_kind::any))
@@ -626,20 +626,31 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
626
626
return false ;
627
627
628
628
const int ds = dim_start;
629
- if (stride_start == -1 ) {
630
- stride_start = ds;
631
- } else if (stride_start > ndims ()) {
632
- stride_start = ndims ();
633
- }
634
629
const auto &blk = blocking_desc ();
635
630
const auto &r_blk = rhs.blocking_desc ();
636
631
637
632
auto custom_cpm = use_weak_cmp ? array_cmp_weak : array_cmp<dnnl_dim_t >;
633
+ auto cmp_strides = [&]() {
634
+ if (0xffffffffffffffff == stride_mask) {
635
+ return custom_cpm (blk.strides + ds, r_blk.strides + ds, ndims () - ds);
636
+ } else {
637
+ for (int i = 0 ; i < ndims (); ++i) {
638
+ if (stride_mask & (1 << i)) {
639
+ if (blk.strides [i] != r_blk.strides [i]
640
+ && IMPLICATION (use_weak_cmp, (blk.strides [i] != DNNL_RUNTIME_DIM_VAL && r_blk.strides [i] != DNNL_RUNTIME_DIM_VAL))) {
641
+ return false ;
642
+ }
643
+ }
644
+ }
645
+ }
646
+ return true ;
647
+ };
648
+
638
649
return ndims () == rhs.ndims () && dim_start <= ndims () /* guard */
639
650
&& format_kind () == rhs.format_kind ()
640
651
&& IMPLICATION (with_data_type, data_type () == rhs.data_type ())
641
652
&& custom_cpm (dims () + ds, rhs.dims () + ds, ndims () - ds)
642
- && custom_cpm (blk. strides + stride_start, r_blk. strides + stride_start, ndims () - stride_start )
653
+ && cmp_strides ( )
643
654
&& blk.inner_nblks == r_blk.inner_nblks
644
655
&& array_cmp (blk.inner_blks , r_blk.inner_blks , blk.inner_nblks )
645
656
&& array_cmp (blk.inner_idxs , r_blk.inner_idxs , blk.inner_nblks )
0 commit comments