@@ -377,8 +377,8 @@ struct memory_desc_wrapper : public c_compatible {
377
377
* following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
378
378
/* TODO: revise */
379
379
bool similar_to (const memory_desc_wrapper &rhs, bool with_padding = true ,
380
- bool with_data_type = true , int dim_start = 0 , int stride_start = - 1 ,
381
- bool use_weak_cmp = false , bool check_off0 = false ) const ;
380
+ bool with_data_type = true , int dim_start = 0 , bool use_weak_cmp = false ,
381
+ bool check_off0 = false , uint64_t stride_mask = 0xffffffffffffffff ) const ;
382
382
383
383
/* * returns true if one memory can be reordered to another */
384
384
bool consistent_with (const memory_desc_wrapper &rhs) const ;
@@ -544,28 +544,39 @@ struct memory_desc_wrapper : public c_compatible {
544
544
};
545
545
546
546
inline bool memory_desc_wrapper::similar_to (const memory_desc_wrapper &rhs,
547
- bool with_padding, bool with_data_type, int dim_start, int stride_start, bool use_weak_cmp, bool check_off0) const {
547
+ bool with_padding, bool with_data_type, int dim_start, bool use_weak_cmp, bool check_off0, uint64_t stride_mask ) const {
548
548
using namespace utils ;
549
549
550
550
if (one_of (format_kind (), format_kind::undef, format_kind::any))
551
551
return false ;
552
552
if (is_wino_desc () || is_rnn_packed_desc ()) return false ;
553
553
554
554
const int ds = dim_start;
555
- if (stride_start == -1 ) {
556
- stride_start = ds;
557
- } else if (stride_start > ndims ()) {
558
- stride_start = ndims ();
559
- }
560
555
const auto &blk = blocking_desc ();
561
556
const auto &r_blk = rhs.blocking_desc ();
562
557
563
558
auto custom_cpm = use_weak_cmp ? array_cmp_weak : array_cmp<dnnl_dim_t >;
559
+ auto cmp_strides = [&]() {
560
+ if (0xffffffffffffffff == stride_mask) {
561
+ return custom_cpm (blk.strides + ds, r_blk.strides + ds, ndims () - ds);
562
+ } else {
563
+ for (int i = 0 ; i < ndims (); ++i) {
564
+ if (stride_mask & (1 << i)) {
565
+ if (blk.strides [i] != r_blk.strides [i]
566
+ && IMPLICATION (use_weak_cmp, (blk.strides [i] != DNNL_RUNTIME_DIM_VAL && r_blk.strides [i] != DNNL_RUNTIME_DIM_VAL))) {
567
+ return false ;
568
+ }
569
+ }
570
+ }
571
+ }
572
+ return true ;
573
+ };
574
+
564
575
return ndims () == rhs.ndims () && dim_start <= ndims () /* guard */
565
576
&& format_kind () == rhs.format_kind ()
566
577
&& IMPLICATION (with_data_type, data_type () == rhs.data_type ())
567
578
&& custom_cpm (dims () + ds, rhs.dims () + ds, ndims () - ds)
568
- && custom_cpm (blk. strides + stride_start, r_blk. strides + stride_start, ndims () - stride_start )
579
+ && cmp_strides ( )
569
580
&& blk.inner_nblks == r_blk.inner_nblks
570
581
&& array_cmp (blk.inner_blks , r_blk.inner_blks , blk.inner_nblks )
571
582
&& array_cmp (blk.inner_idxs , r_blk.inner_idxs , blk.inner_nblks )
0 commit comments