@@ -597,8 +597,11 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
597
597
if (is_wino_desc () || is_rnn_packed_desc ()) return false ;
598
598
599
599
const int ds = dim_start;
600
- if (stride_start == -1 )
600
+ if (stride_start == -1 ) {
601
601
stride_start = ds;
602
+ } else if (stride_start > ndims ()) {
603
+ stride_start = ndims ();
604
+ }
602
605
const auto &blk = blocking_desc ();
603
606
const auto &r_blk = rhs.blocking_desc ();
604
607
@@ -607,7 +610,7 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
607
610
&& format_kind () == rhs.format_kind ()
608
611
&& IMPLICATION (with_data_type, data_type () == rhs.data_type ())
609
612
&& custom_cpm (dims () + ds, rhs.dims () + ds, ndims () - ds)
610
- && custom_cpm (blk.strides + ds , r_blk.strides + ds , ndims () - ds )
613
+ && custom_cpm (blk.strides + stride_start , r_blk.strides + stride_start , ndims () - stride_start )
611
614
&& blk.inner_nblks == r_blk.inner_nblks
612
615
&& array_cmp (blk.inner_blks , r_blk.inner_blks , blk.inner_nblks )
613
616
&& array_cmp (blk.inner_idxs , r_blk.inner_idxs , blk.inner_nblks )
0 commit comments