Skip to content

Commit 317e45e

Browse files
authored
cpu: conv: disable huge shapes to avoid potential overflows (#2827)
1 parent 712dfe1 commit 317e45e

19 files changed

+161
-9
lines changed

src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,11 @@ status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
690690
if (!mayiuse(avx)) return status::unimplemented;
691691
jcp.isa = mayiuse(avx2) ? avx2 : avx;
692692

693+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
694+
// TODO: change data type of jcp fields to size_t
695+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
696+
VERBOSE_BAD_PARAM, "Large size is not supported");
697+
693698
// TODO (Roma): this code is duplicated from the generic kernel; maybe the
694699
// configuration struct could do some stuff below
695700
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;

src/cpu/x64/jit_avx2_conv_kernel_f32.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,12 @@ status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
588588
const primitive_attr_t &attr) {
589589
// disabling verbose dispatch messages for unsupported isa for better readability
590590
if (!mayiuse(avx)) return status::unimplemented;
591+
592+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
593+
// TODO: change data type of jcp fields to size_t
594+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
595+
VERBOSE_BAD_PARAM, "Large size is not supported");
596+
591597
jcp.isa = mayiuse(avx2) ? avx2 : avx;
592598

593599
jcp.nthr = dnnl_get_max_threads();
@@ -1112,6 +1118,11 @@ status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp,
11121118
// disabling verbose dispatch messages for unsupported isa for better readability
11131119
if (!mayiuse(avx2)) return status::unimplemented;
11141120

1121+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
1122+
// TODO: change data type of jcp fields to size_t
1123+
VDISPATCH_CONV_IC(!has_large_size(cd, diff_src_d, weights_d, diff_dst_d),
1124+
VERBOSE_BAD_PARAM, "Large size is not supported");
1125+
11151126
jcp.nthr = dnnl_get_max_threads();
11161127

11171128
const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
@@ -1333,6 +1344,11 @@ status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp,
13331344
// disabling verbose dispatch messages for unsupported isa for better readability
13341345
if (!mayiuse(avx2)) return status::unimplemented;
13351346

1347+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
1348+
// TODO: change data type of jcp fields to size_t
1349+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, diff_weights_d, diff_dst_d),
1350+
VERBOSE_BAD_PARAM, "Large size is not supported");
1351+
13361352
const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
13371353
int ndims = src_d.ndims();
13381354
jcp.ndims = ndims;

src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,11 @@ status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp,
566566
dst_d.data_type()))
567567
return status::unimplemented;
568568

569+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
570+
// TODO: change data type of jcp fields to size_t
571+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
572+
VERBOSE_BAD_PARAM, "Large size is not supported");
573+
569574
jcp.nthr = nthreads;
570575

571576
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;

src/cpu/x64/jit_avx512_common_conv_kernel.cpp

+12-7
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,10 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
792792
if (!everyone_is(data_type::f32, src_d.data_type(), weights_d.data_type(),
793793
dst_d.data_type()))
794794
return status::unimplemented;
795+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
796+
// TODO: change data type of jcp fields to size_t
797+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
798+
VERBOSE_BAD_PARAM, "Large size is not supported");
795799

796800
const int regs = 28;
797801
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
@@ -823,13 +827,6 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
823827
jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
824828
jcp.stride_w = cd.strides[ndims - 3];
825829

826-
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
827-
// TODO: change data type of jcp fields to size_t
828-
VDISPATCH_CONV_IC(!((ndims == 5 && cd.dilates[ndims - 5] > INT_MAX)
829-
|| (ndims >= 4 && cd.dilates[ndims - 4] > INT_MAX)
830-
|| (cd.dilates[ndims - 3] > INT_MAX)),
831-
VERBOSE_BAD_PARAM, "dilates");
832-
833830
jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
834831
jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
835832
jcp.dilate_w = cd.dilates[ndims - 3];
@@ -1859,6 +1856,10 @@ status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
18591856
if (!everyone_is(data_type::f32, diff_dst_d.data_type(),
18601857
weights_d.data_type(), diff_src_d.data_type()))
18611858
return status::unimplemented;
1859+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
1860+
// TODO: change data type of jcp fields to size_t
1861+
VDISPATCH_CONV_IC(!has_large_size(cd, diff_src_d, weights_d, diff_dst_d),
1862+
VERBOSE_BAD_PARAM, "Large size is not supported");
18621863

18631864
const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
18641865
int ndims = diff_src_d.ndims();
@@ -3906,6 +3907,10 @@ status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
39063907
if (!utils::everyone_is(data_type::f32, src_d.data_type(),
39073908
diff_weights_d.data_type(), diff_dst_d.data_type()))
39083909
return status::unimplemented;
3910+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
3911+
// TODO: change data type of jcp fields to size_t
3912+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, diff_weights_d, diff_dst_d),
3913+
VERBOSE_BAD_PARAM, "Large size is not supported");
39093914

39103915
const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
39113916
int ndims = src_d.ndims();

src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*******************************************************************************/
1616

1717
#include "common/c_types_map.hpp"
18+
#include "common/convolution_pd.hpp"
1819
#include "common/memory_tracking.hpp"
1920
#include "common/nstl.hpp"
2021
#include "common/type_helpers.hpp"
@@ -952,6 +953,11 @@ status_t jit_avx512_core_amx_1x1_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp,
952953
bool is_1d = ndims == 3;
953954
bool is_3d = ndims == 5;
954955

956+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
957+
// TODO: change data type of jcp fields to size_t
958+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
959+
VERBOSE_BAD_PARAM, "Large size is not supported");
960+
955961
const bool is_bf16_convolution
956962
= everyone_is(true, src_d.data_type() == data_type::bf16,
957963
weights_d.data_type() == data_type::bf16,

src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -2263,6 +2263,11 @@ status_t jit_avx512_core_amx_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp,
22632263
bool is_1d = ndims == 3;
22642264
bool is_3d = ndims == 5;
22652265

2266+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
2267+
// TODO: change data type of jcp fields to size_t
2268+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
2269+
VERBOSE_BAD_PARAM, "Large size is not supported");
2270+
22662271
const bool is_bf16_convolution
22672272
= everyone_is(true, src_d.data_type() == data_type::bf16,
22682273
weights_d.data_type() == data_type::bf16,
@@ -3718,6 +3723,11 @@ status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
37183723
const memory_desc_wrapper diff_dst_d(&diff_dst_md);
37193724
const memory_desc_wrapper bias_d(bias_md);
37203725

3726+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
3727+
// TODO: change data type of jcp fields to size_t
3728+
VDISPATCH_CONV_IC(!has_large_size(cd, diff_src_d, weights_d, diff_dst_d),
3729+
VERBOSE_BAD_PARAM, "Large size is not supported");
3730+
37213731
const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
37223732
int ndims = diff_src_d.ndims();
37233733
bool is_1d = ndims == 3;
@@ -5167,6 +5177,10 @@ status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_conf(
51675177
const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
51685178
int ndims = src_d.ndims();
51695179

5180+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
5181+
// TODO: change data type of jcp fields to size_t
5182+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, diff_weights_d, diff_dst_d),
5183+
VERBOSE_BAD_PARAM, "Large size is not supported");
51705184
VDISPATCH_CONV_IC(mayiuse(avx512_core_amx), VERBOSE_UNSUPPORTED_ISA);
51715185
jcp.isa = avx512_core_amx;
51725186

src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -1196,6 +1196,11 @@ status_t jit_avx512_core_bf16_1x1_conv_kernel::init_conf(
11961196
const int simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
11971197
const int ndims = src_d.ndims();
11981198

1199+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
1200+
// TODO: change data type of jcp fields to size_t
1201+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
1202+
VERBOSE_BAD_PARAM, "Large size is not supported");
1203+
11991204
jcp.nthr = nthreads;
12001205
jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
12011206
: bf16_emulation_t::get_isa();

src/cpu/x64/jit_avx512_core_bf16_conv_kernel.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,11 @@ status_t jit_avx512_core_bf16_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
829829
const memory_desc_wrapper dst_d(&dst_md);
830830
const memory_desc_wrapper bias_d(&bias_md);
831831

832+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
833+
// TODO: change data type of jcp fields to size_t
834+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
835+
VERBOSE_BAD_PARAM, "Large size is not supported");
836+
832837
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
833838
int ndims = src_d.ndims();
834839

@@ -1540,6 +1545,11 @@ status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(jit_conv_conf_t &jcp,
15401545
const memory_desc_wrapper weights_d(&weights_md);
15411546
const memory_desc_wrapper diff_dst_d(&diff_dst_md);
15421547

1548+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
1549+
// TODO: change data type of jcp fields to size_t
1550+
VDISPATCH_CONV_IC(!has_large_size(cd, diff_src_d, weights_d, diff_dst_d),
1551+
VERBOSE_BAD_PARAM, "Large size is not supported");
1552+
15431553
const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
15441554
int ndims = diff_src_d.ndims();
15451555

@@ -4134,6 +4144,11 @@ status_t jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf(
41344144
const memory_desc_wrapper diff_dst_d(&diff_dst_md);
41354145
const memory_desc_wrapper diff_bias_d(&diff_bias_md);
41364146

4147+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
4148+
// TODO: change data type of jcp fields to size_t
4149+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, diff_weights_d, diff_dst_d),
4150+
VERBOSE_BAD_PARAM, "Large size is not supported");
4151+
41374152
const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
41384153
int ndims = src_d.ndims();
41394154

src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <assert.h>
1818

1919
#include "common/c_types_map.hpp"
20+
#include "common/convolution_pd.hpp"
2021
#include "common/memory.hpp"
2122
#include "common/memory_tracking.hpp"
2223
#include "common/nstl.hpp"
@@ -838,6 +839,11 @@ status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(
838839
const memory_desc_wrapper dst_d(&dst_md);
839840
const memory_desc_wrapper bias_d(&bias_md);
840841

842+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
843+
// TODO: change data type of jcp fields to size_t
844+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
845+
VERBOSE_BAD_PARAM, "Large size is not supported");
846+
841847
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
842848
if (!one_of(src_d.data_type(), data_type::u8, data_type::s8)
843849
|| weights_d.data_type() != data_type::s8

src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*******************************************************************************/
1616

1717
#include "common/c_types_map.hpp"
18+
#include "common/convolution_pd.hpp"
1819
#include "common/memory.hpp"
1920
#include "common/memory_tracking.hpp"
2021
#include "common/nstl.hpp"
@@ -1368,6 +1369,11 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
13681369
const memory_desc_wrapper dst_d(&dst_md);
13691370
const memory_desc_wrapper bias_d(&bias_md);
13701371

1372+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
1373+
// TODO: change data type of jcp fields to size_t
1374+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
1375+
VERBOSE_BAD_PARAM, "Large size is not supported");
1376+
13711377
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
13721378
const int ndims = src_d.ndims();
13731379
const bool is_1d = ndims == 3;

src/cpu/x64/jit_brdgmm_dw_conv.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,11 @@ status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) {
162162
const memory_desc_wrapper dst_d(&dst_md_);
163163
const memory_desc_wrapper bias_d(&bias_md_);
164164

165+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
166+
// TODO: change data type of jcp fields to size_t
167+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
168+
VERBOSE_BAD_PARAM, "Large size is not supported");
169+
165170
const int ndims = src_d.ndims();
166171
const bool is_3d = ndims == 5;
167172
// Currently this kernel only supports 2D and 3D convolutions.

src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -1415,6 +1415,11 @@ status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
14151415
const memory_desc_wrapper diff_src_d(&diff_src_md);
14161416
const memory_desc_wrapper bias_d(&bias_md);
14171417

1418+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
1419+
// TODO: change data type of jcp fields to size_t
1420+
VDISPATCH_CONV_IC(!has_large_size(cd, diff_src_d, weights_d, diff_dst_d),
1421+
VERBOSE_BAD_PARAM, "Large size is not supported");
1422+
14181423
const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
14191424
int ndims = diff_src_d.ndims();
14201425

src/cpu/x64/jit_brgemm_conv_utils.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -1662,6 +1662,11 @@ status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
16621662
const memory_desc_wrapper dst_d(&dst_md);
16631663
const memory_desc_wrapper bias_d(&bias_md);
16641664

1665+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
1666+
// TODO: change data type of jcp fields to size_t
1667+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
1668+
VERBOSE_BAD_PARAM, "Large size is not supported");
1669+
16651670
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
16661671
int ndims = src_d.ndims();
16671672

src/cpu/x64/jit_primitive_conf.hpp

+19
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,25 @@ inline status_t init_tag(format_tag_t &tag, const memory_desc_wrapper &mdw,
276276
return tag == tag_value ? status::success : status::unimplemented;
277277
}
278278

279+
inline bool has_large_size(const convolution_desc_t &cd,
280+
const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
281+
const memory_desc_wrapper &dst_d) {
282+
auto is_large = [](const dim_t val) { return val > INT_MAX; };
283+
284+
if (utils::one_of(true, is_large(src_d.nelems()),
285+
is_large(weights_d.nelems()), is_large(dst_d.nelems())))
286+
return true;
287+
288+
const int ndims = src_d.ndims();
289+
for (int d = 3; d <= ndims; d++) {
290+
if (utils::one_of(true, is_large(cd.strides[ndims - d]),
291+
is_large(cd.padding[0][ndims - d]),
292+
is_large(cd.dilates[ndims - d])))
293+
return true;
294+
}
295+
return false;
296+
}
297+
279298
struct jit_conv_call_s {
280299
const void *src; /* hack, non-const for backward_data */
281300
const void *dst; /* hack, non-const for forward */

src/cpu/x64/jit_sse41_1x1_conv_kernel_f32.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,11 @@ status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
554554
// disabling verbose dispatch messages for unsupported isa for better readability
555555
if (!mayiuse(sse41)) return status::unimplemented;
556556

557+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
558+
// TODO: change data type of jcp fields to size_t
559+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
560+
VERBOSE_BAD_PARAM, "Large size is not supported");
561+
557562
// TODO (Roma): this code is duplicated from the generic kernel; maybe the
558563
// configuration struct could do some stuff below
559564
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;

src/cpu/x64/jit_sse41_conv_kernel_f32.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2017-2024 Intel Corporation
2+
* Copyright 2017-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -386,6 +386,11 @@ status_t jit_sse41_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
386386
// disabling verbose dispatch messages for unsupported isa for better readability
387387
if (!mayiuse(sse41)) return status::unimplemented;
388388

389+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
390+
// TODO: change data type of jcp fields to size_t
391+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
392+
VERBOSE_BAD_PARAM, "Large size is not supported");
393+
389394
jcp.nthr = nthreads;
390395

391396
jcp.prop_kind = cd.prop_kind;

src/cpu/x64/jit_uni_dw_conv_kernel_utils.cpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2021-2024 Intel Corporation
2+
* Copyright 2021-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -39,6 +39,11 @@ status_t jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_conf(
3939
const memory_desc_wrapper dst_d(&dst_md);
4040
const memory_desc_wrapper bias_d(&bias_md);
4141

42+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
43+
// TODO: change data type of jcp fields to size_t
44+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
45+
VERBOSE_BAD_PARAM, "Large size is not supported");
46+
4247
const int ndims = src_d.ndims();
4348
// Currently this kernel only supports 2D convolutions.
4449
VDISPATCH_CONV_IC(ndims == 4, "kernel supports only 2D convolutions");
@@ -279,6 +284,11 @@ status_t jit_uni_dw_conv_bwd_data_kernel<isa, kernel_dt>::init_conf(
279284
const memory_desc_wrapper weights_d(&weights_md);
280285
const memory_desc_wrapper diff_dst_d(&diff_dst_md);
281286

287+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
288+
// TODO: change data type of jcp fields to size_t
289+
VDISPATCH_CONV_IC(!has_large_size(cd, diff_src_d, weights_d, diff_dst_d),
290+
VERBOSE_BAD_PARAM, "Large size is not supported");
291+
282292
jcp.dsrc_dt = cd.diff_src_desc.data_type;
283293
const bool is_bf16 = diff_dst_d.data_type() == bf16;
284294
jcp.isa = (is_bf16 && mayiuse(avx512_core_bf16)) ? avx512_core_bf16 : isa;
@@ -452,6 +462,11 @@ status_t jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::init_conf(
452462
const memory_desc_wrapper diff_bias_d(&diff_bias_md);
453463
const memory_desc_wrapper diff_dst_d(&diff_dst_md);
454464

465+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
466+
// TODO: change data type of jcp fields to size_t
467+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, diff_weights_d, diff_dst_d),
468+
VERBOSE_BAD_PARAM, "Large size is not supported");
469+
455470
jcp.dwei_dt = cd.diff_weights_desc.data_type;
456471
const int ndims = src_d.ndims();
457472
const bool is_bf16 = src_d.data_type() == data_type::bf16;

src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,11 @@ status_t jit_uni_x8s8s32x_1x1_conv_kernel<isa>::init_conf(
640640
// disabling verbose dispatch messages for unsupported isa for better readability
641641
if (!mayiuse(isa)) return status::unimplemented;
642642

643+
// Big int (> INT_MAX) values are unsupported and jcp fields may overflow
644+
// TODO: change data type of jcp fields to size_t
645+
VDISPATCH_CONV_IC(!has_large_size(cd, src_d, weights_d, dst_d),
646+
VERBOSE_BAD_PARAM, "Large size is not supported");
647+
643648
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
644649
const bool dt_not_ok
645650
= !one_of(src_d.data_type(), data_type::u8, data_type::s8)

0 commit comments

Comments
 (0)