@@ -422,8 +422,8 @@ struct memory_desc_wrapper : public c_compatible {
422
422
* following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
423
423
/* TODO: revise */
424
424
/* Parameter notes, based on the part of the definition visible in this change:
 *  - with_data_type: when true, data_type() must equal rhs.data_type().
 *  - dim_start: dims are compared over [dim_start, ndims()); the comparison
 *    also requires dim_start <= ndims().
 *  - use_weak_cmp: selects array_cmp_weak instead of array_cmp<dnnl_dim_t>
 *    for dims/strides.
 *  - stride_mask (new signature, replaces stride_start): with all 64 bits set
 *    (the default) strides are compared over [dim_start, ndims()); otherwise
 *    only strides[i] whose bit i is set in the mask are compared, counting
 *    from dimension 0 (dim_start is not applied in that mode).
 *  - with_padding, check_off0: their use is not visible in this excerpt;
 *    presumably they gate padded-dims and offset0 checks -- TODO confirm. */
bool similar_to (const memory_desc_wrapper &rhs, bool with_padding = true ,
425
- bool with_data_type = true , int dim_start = 0 , int stride_start = - 1 ,
426
- bool use_weak_cmp = false , bool check_off0 = false ) const ;
425
+ bool with_data_type = true , int dim_start = 0 , bool use_weak_cmp = false ,
426
+ bool check_off0 = false , uint64_t stride_mask = 0xffffffffffffffff ) const ;
427
427
428
428
/* * returns true if one memory can be reordered to another */
429
429
bool consistent_with (const memory_desc_wrapper &rhs) const ;
@@ -594,28 +594,39 @@ struct memory_desc_wrapper : public c_compatible {
594
594
};
595
595
596
596
inline bool memory_desc_wrapper::similar_to (const memory_desc_wrapper &rhs,
597
- bool with_padding, bool with_data_type, int dim_start, int stride_start, bool use_weak_cmp, bool check_off0) const {
597
+ bool with_padding, bool with_data_type, int dim_start, bool use_weak_cmp, bool check_off0, uint64_t stride_mask ) const {
598
598
using namespace utils ;
599
599
600
600
if (one_of (format_kind (), format_kind::undef, format_kind::any))
601
601
return false ;
602
602
if (is_wino_desc () || is_rnn_packed_desc ()) return false ;
603
603
604
604
const int ds = dim_start;
605
- if (stride_start == -1 ) {
606
- stride_start = ds;
607
- } else if (stride_start > ndims ()) {
608
- stride_start = ndims ();
609
- }
610
605
const auto &blk = blocking_desc ();
611
606
const auto &r_blk = rhs.blocking_desc ();
612
607
613
608
auto custom_cpm = use_weak_cmp ? array_cmp_weak : array_cmp<dnnl_dim_t >;
609
+ auto cmp_strides = [&]() {
610
+ if (0xffffffffffffffff == stride_mask) {
611
+ return custom_cpm (blk.strides + ds, r_blk.strides + ds, ndims () - ds);
612
+ } else {
613
+ for (int i = 0 ; i < ndims (); ++i) {
614
+ if (stride_mask & (1 << i)) {
615
+ if (blk.strides [i] != r_blk.strides [i]
616
+ && IMPLICATION (use_weak_cmp, (blk.strides [i] != DNNL_RUNTIME_DIM_VAL && r_blk.strides [i] != DNNL_RUNTIME_DIM_VAL))) {
617
+ return false ;
618
+ }
619
+ }
620
+ }
621
+ }
622
+ return true ;
623
+ };
624
+
614
625
return ndims () == rhs.ndims () && dim_start <= ndims () /* guard */
615
626
&& format_kind () == rhs.format_kind ()
616
627
&& IMPLICATION (with_data_type, data_type () == rhs.data_type ())
617
628
&& custom_cpm (dims () + ds, rhs.dims () + ds, ndims () - ds)
618
- && custom_cpm (blk. strides + stride_start, r_blk. strides + stride_start, ndims () - stride_start )
629
+ && cmp_strides ( )
619
630
&& blk.inner_nblks == r_blk.inner_nblks
620
631
&& array_cmp (blk.inner_blks , r_blk.inner_blks , blk.inner_nblks )
621
632
&& array_cmp (blk.inner_idxs , r_blk.inner_idxs , blk.inner_nblks )
0 commit comments