@@ -553,8 +553,11 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
553
553
if (is_wino_desc () || is_rnn_packed_desc ()) return false ;
554
554
555
555
const int ds = dim_start;
556
- if (stride_start == -1 )
556
+ if (stride_start == -1 ) {
557
557
stride_start = ds;
558
+ } else if (stride_start > ndims ()) {
559
+ stride_start = ndims ();
560
+ }
558
561
const auto &blk = blocking_desc ();
559
562
const auto &r_blk = rhs.blocking_desc ();
560
563
@@ -563,7 +566,7 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
563
566
&& format_kind () == rhs.format_kind ()
564
567
&& IMPLICATION (with_data_type, data_type () == rhs.data_type ())
565
568
&& custom_cpm (dims () + ds, rhs.dims () + ds, ndims () - ds)
566
- && custom_cpm (blk.strides + ds , r_blk.strides + ds , ndims () - ds )
569
+ && custom_cpm (blk.strides + stride_start , r_blk.strides + stride_start , ndims () - stride_start )
567
570
&& blk.inner_nblks == r_blk.inner_nblks
568
571
&& array_cmp (blk.inner_blks , r_blk.inner_blks , blk.inner_nblks )
569
572
&& array_cmp (blk.inner_idxs , r_blk.inner_idxs , blk.inner_nblks )
0 commit comments