@@ -792,6 +792,10 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
792
792
if (!everyone_is (data_type::f32, src_d.data_type (), weights_d.data_type (),
793
793
dst_d.data_type ()))
794
794
return status::unimplemented;
795
+ // Big int (> INT_MAX) values are unsupported and jcp fields may overflow
796
+ // TODO: change data type of jcp fields to size_t
797
+ VDISPATCH_CONV_IC (!has_large_size (cd, src_d, weights_d, dst_d),
798
+ VERBOSE_BAD_PARAM, " Large size is not supported" );
795
799
796
800
const int regs = 28 ;
797
801
const bool with_groups = weights_d.ndims () == src_d.ndims () + 1 ;
@@ -823,13 +827,6 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
823
827
jcp.stride_h = (ndims == 3 ) ? 1 : cd.strides [ndims - 4 ];
824
828
jcp.stride_w = cd.strides [ndims - 3 ];
825
829
826
- // Big int (> INT_MAX) values are unsupported and jcp fields may overflow
827
- // TODO: change data type of jcp fields to size_t
828
- VDISPATCH_CONV_IC (!((ndims == 5 && cd.dilates [ndims - 5 ] > INT_MAX)
829
- || (ndims >= 4 && cd.dilates [ndims - 4 ] > INT_MAX)
830
- || (cd.dilates [ndims - 3 ] > INT_MAX)),
831
- VERBOSE_BAD_PARAM, " dilates" );
832
-
833
830
jcp.dilate_d = (ndims == 5 ) ? cd.dilates [0 ] : 0 ;
834
831
jcp.dilate_h = (ndims == 3 ) ? 0 : cd.dilates [ndims - 4 ];
835
832
jcp.dilate_w = cd.dilates [ndims - 3 ];
@@ -1859,6 +1856,10 @@ status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
1859
1856
if (!everyone_is (data_type::f32, diff_dst_d.data_type (),
1860
1857
weights_d.data_type (), diff_src_d.data_type ()))
1861
1858
return status::unimplemented;
1859
+ // Big int (> INT_MAX) values are unsupported and jcp fields may overflow
1860
+ // TODO: change data type of jcp fields to size_t
1861
+ VDISPATCH_CONV_IC (!has_large_size (cd, diff_src_d, weights_d, diff_dst_d),
1862
+ VERBOSE_BAD_PARAM, " Large size is not supported" );
1862
1863
1863
1864
const bool with_groups = weights_d.ndims () == diff_src_d.ndims () + 1 ;
1864
1865
int ndims = diff_src_d.ndims ();
@@ -3906,6 +3907,10 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
3906
3907
if (!utils::everyone_is (data_type::f32, src_d.data_type (),
3907
3908
diff_weights_d.data_type (), diff_dst_d.data_type ()))
3908
3909
return status::unimplemented;
3910
+ // Big int (> INT_MAX) values are unsupported and jcp fields may overflow
3911
+ // TODO: change data type of jcp fields to size_t
3912
+ VDISPATCH_CONV_IC (!has_large_size (cd, src_d, diff_weights_d, diff_dst_d),
3913
+ VERBOSE_BAD_PARAM, " Large size is not supported" );
3909
3914
3910
3915
const bool with_groups = diff_weights_d.ndims () == src_d.ndims () + 1 ;
3911
3916
int ndims = src_d.ndims ();
0 commit comments