@@ -602,8 +602,11 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
602
602
if (is_wino_desc () || is_rnn_packed_desc ()) return false ;
603
603
604
604
const int ds = dim_start;
605
- if (stride_start == -1 )
605
+ if (stride_start == -1 ) {
606
606
stride_start = ds;
607
+ } else if (stride_start > ndims ()) {
608
+ stride_start = ndims ();
609
+ }
607
610
const auto &blk = blocking_desc ();
608
611
const auto &r_blk = rhs.blocking_desc ();
609
612
@@ -612,7 +615,7 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
612
615
&& format_kind () == rhs.format_kind ()
613
616
&& IMPLICATION (with_data_type, data_type () == rhs.data_type ())
614
617
&& custom_cpm (dims () + ds, rhs.dims () + ds, ndims () - ds)
615
- && custom_cpm (blk.strides + ds , r_blk.strides + ds , ndims () - ds )
618
+ && custom_cpm (blk.strides + stride_start , r_blk.strides + stride_start , ndims () - stride_start )
616
619
&& blk.inner_nblks == r_blk.inner_nblks
617
620
&& array_cmp (blk.inner_blks , r_blk.inner_blks , blk.inner_nblks )
618
621
&& array_cmp (blk.inner_idxs , r_blk.inner_idxs , blk.inner_nblks )
0 commit comments