@@ -422,8 +422,8 @@ struct memory_desc_wrapper : public c_compatible {
422
422
* following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
423
423
/* TODO: revise */
424
424
bool similar_to (const memory_desc_wrapper &rhs, bool with_padding = true ,
425
- bool with_data_type = true , int dim_start = 0 , int stride_start = - 1 ,
426
- bool use_weak_cmp = false , bool check_off0 = false ) const ;
425
+ bool with_data_type = true , int dim_start = 0 , bool use_weak_cmp = false ,
426
+ bool check_off0 = false , uint64_t stride_mask = 0xffffffffffffffff ) const ;
427
427
428
428
/* * returns true if one memory can be reordered to another */
429
429
bool consistent_with (const memory_desc_wrapper &rhs) const ;
@@ -589,28 +589,39 @@ struct memory_desc_wrapper : public c_compatible {
589
589
};
590
590
591
591
inline bool memory_desc_wrapper::similar_to (const memory_desc_wrapper &rhs,
592
- bool with_padding, bool with_data_type, int dim_start, int stride_start, bool use_weak_cmp, bool check_off0) const {
592
+ bool with_padding, bool with_data_type, int dim_start, bool use_weak_cmp, bool check_off0, uint64_t stride_mask ) const {
593
593
using namespace utils ;
594
594
595
595
if (one_of (format_kind (), format_kind::undef, format_kind::any))
596
596
return false ;
597
597
if (is_wino_desc () || is_rnn_packed_desc ()) return false ;
598
598
599
599
const int ds = dim_start;
600
- if (stride_start == -1 ) {
601
- stride_start = ds;
602
- } else if (stride_start > ndims ()) {
603
- stride_start = ndims ();
604
- }
605
600
const auto &blk = blocking_desc ();
606
601
const auto &r_blk = rhs.blocking_desc ();
607
602
608
603
auto custom_cpm = use_weak_cmp ? array_cmp_weak : array_cmp<dnnl_dim_t >;
604
+ auto cmp_strides = [&]() {
605
+ if (0xffffffffffffffff == stride_mask) {
606
+ return custom_cpm (blk.strides + ds, r_blk.strides + ds, ndims () - ds);
607
+ } else {
608
+ for (int i = 0 ; i < ndims (); ++i) {
609
+ if (stride_mask & (1 << i)) {
610
+ if (blk.strides [i] != r_blk.strides [i]
611
+ && IMPLICATION (use_weak_cmp, (blk.strides [i] != DNNL_RUNTIME_DIM_VAL && r_blk.strides [i] != DNNL_RUNTIME_DIM_VAL))) {
612
+ return false ;
613
+ }
614
+ }
615
+ }
616
+ }
617
+ return true ;
618
+ };
619
+
609
620
return ndims () == rhs.ndims () && dim_start <= ndims () /* guard */
610
621
&& format_kind () == rhs.format_kind ()
611
622
&& IMPLICATION (with_data_type, data_type () == rhs.data_type ())
612
623
&& custom_cpm (dims () + ds, rhs.dims () + ds, ndims () - ds)
613
- && custom_cpm (blk. strides + stride_start, r_blk. strides + stride_start, ndims () - stride_start )
624
+ && cmp_strides ( )
614
625
&& blk.inner_nblks == r_blk.inner_nblks
615
626
&& array_cmp (blk.inner_blks , r_blk.inner_blks , blk.inner_nblks )
616
627
&& array_cmp (blk.inner_idxs , r_blk.inner_idxs , blk.inner_nblks )
0 commit comments