@@ -552,8 +552,11 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
552
552
if (is_wino_desc () || is_rnn_packed_desc ()) return false ;
553
553
554
554
const int ds = dim_start;
555
- if (stride_start == -1 )
555
+ if (stride_start == -1 ) {
556
556
stride_start = ds;
557
+ } else if (stride_start > ndims ()) {
558
+ stride_start = ndims ();
559
+ }
557
560
const auto &blk = blocking_desc ();
558
561
const auto &r_blk = rhs.blocking_desc ();
559
562
@@ -562,7 +565,7 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
562
565
&& format_kind () == rhs.format_kind ()
563
566
&& IMPLICATION (with_data_type, data_type () == rhs.data_type ())
564
567
&& custom_cpm (dims () + ds, rhs.dims () + ds, ndims () - ds)
565
- && custom_cpm (blk.strides + ds , r_blk.strides + ds , ndims () - ds )
568
+ && custom_cpm (blk.strides + stride_start , r_blk.strides + stride_start , ndims () - stride_start )
566
569
&& blk.inner_nblks == r_blk.inner_nblks
567
570
&& array_cmp (blk.inner_blks , r_blk.inner_blks , blk.inner_nblks )
568
571
&& array_cmp (blk.inner_idxs , r_blk.inner_idxs , blk.inner_nblks )
0 commit comments