@@ -378,8 +378,8 @@ struct memory_desc_wrapper : public c_compatible {
378
378
* following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
379
379
/* TODO: revise */
380
380
bool similar_to (const memory_desc_wrapper &rhs, bool with_padding = true ,
381
- bool with_data_type = true , int dim_start = 0 , int stride_start = - 1 ,
382
- bool use_weak_cmp = false , bool check_off0 = false ) const ;
381
+ bool with_data_type = true , int dim_start = 0 , bool use_weak_cmp = false ,
382
+ bool check_off0 = false , uint64_t stride_mask = 0xffffffffffffffff ) const ;
383
383
384
384
/* * returns true if one memory can be reordered to another */
385
385
bool consistent_with (const memory_desc_wrapper &rhs) const ;
@@ -545,28 +545,39 @@ struct memory_desc_wrapper : public c_compatible {
545
545
};
546
546
547
547
inline bool memory_desc_wrapper::similar_to (const memory_desc_wrapper &rhs,
548
- bool with_padding, bool with_data_type, int dim_start, int stride_start, bool use_weak_cmp, bool check_off0) const {
548
+ bool with_padding, bool with_data_type, int dim_start, bool use_weak_cmp, bool check_off0, uint64_t stride_mask ) const {
549
549
using namespace utils ;
550
550
551
551
if (one_of (format_kind (), format_kind::undef, format_kind::any))
552
552
return false ;
553
553
if (is_wino_desc () || is_rnn_packed_desc ()) return false ;
554
554
555
555
const int ds = dim_start;
556
- if (stride_start == -1 ) {
557
- stride_start = ds;
558
- } else if (stride_start > ndims ()) {
559
- stride_start = ndims ();
560
- }
561
556
const auto &blk = blocking_desc ();
562
557
const auto &r_blk = rhs.blocking_desc ();
563
558
564
559
auto custom_cpm = use_weak_cmp ? array_cmp_weak : array_cmp<dnnl_dim_t >;
560
+ auto cmp_strides = [&]() {
561
+ if (0xffffffffffffffff == stride_mask) {
562
+ return custom_cpm (blk.strides + ds, r_blk.strides + ds, ndims () - ds);
563
+ } else {
564
+ for (int i = 0 ; i < ndims (); ++i) {
565
+ if (stride_mask & (1 << i)) {
566
+ if (blk.strides [i] != r_blk.strides [i]
567
+ && IMPLICATION (use_weak_cmp, (blk.strides [i] != DNNL_RUNTIME_DIM_VAL && r_blk.strides [i] != DNNL_RUNTIME_DIM_VAL))) {
568
+ return false ;
569
+ }
570
+ }
571
+ }
572
+ }
573
+ return true ;
574
+ };
575
+
565
576
return ndims () == rhs.ndims () && dim_start <= ndims () /* guard */
566
577
&& format_kind () == rhs.format_kind ()
567
578
&& IMPLICATION (with_data_type, data_type () == rhs.data_type ())
568
579
&& custom_cpm (dims () + ds, rhs.dims () + ds, ndims () - ds)
569
- && custom_cpm (blk. strides + stride_start, r_blk. strides + stride_start, ndims () - stride_start )
580
+ && cmp_strides ( )
570
581
&& blk.inner_nblks == r_blk.inner_nblks
571
582
&& array_cmp (blk.inner_blks , r_blk.inner_blks , blk.inner_nblks )
572
583
&& array_cmp (blk.inner_idxs , r_blk.inner_idxs , blk.inner_nblks )
0 commit comments