Skip to content

Commit b0c50e8

Browse files
alexey-varyzgin authored and luweizhou2016 committed
[FIX] [1D] Enlarge support
1 parent ebd6bdb commit b0c50e8

7 files changed

+44
-33
lines changed

src/cpu/reorder/simple_reorder.hpp

+10-6
Original file line number | Diff line number | Diff line change
@@ -1322,8 +1322,10 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
13221322

13231323
template <SIMPLE_REORDER_TEMPL_DECL>
13241324
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1325-
typename utils::enable_if<(tag_i == format_tag::nchw
1326-
&& tag_o == format_tag::nChw16c)
1325+
typename utils::enable_if<((tag_i == format_tag::nchw
1326+
&& tag_o == format_tag::nChw16c) ||
1327+
(tag_i == format_tag::ncw
1328+
&& tag_o == format_tag::nCw16c))
13271329
&& type_i == data_type::f32
13281330
&& type_o == data_type::bf16>::type> {
13291331
static bool is_applicable(const memory_desc_wrapper &input_d,
@@ -1339,23 +1341,25 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
13391341

13401342
static size_t get_scratchpad_size(const memory_desc_wrapper &input_d,
13411343
const memory_desc_wrapper &output_d) {
1344+
constexpr int ndims = tag_traits<tag_i>::ndims;
13421345
const size_t blksize = 16;
1343-
const size_t W = input_d.dims()[3];
1346+
const size_t W = input_d.dims()[ndims - 1];
13441347
return sizeof(float) * blksize * W * dnnl_get_max_threads();
13451348
}
13461349

13471350
static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
13481351
DECLARE_COMMON_PARAMS();
13491352

13501353
const dim_t blksize = 16;
1354+
const dim_t ndims = tag_traits<tag_i>::ndims;
13511355

13521356
const auto &flat_d = input_d;
13531357
const auto &dims = input_d.dims();
13541358
const auto &pdims = output_d.padded_dims();
13551359

13561360
const dim_t C = dims[1];
1357-
const dim_t H = dims[2];
1358-
const dim_t W = dims[3];
1361+
const dim_t H = ndims == 3 ? 1 : dims[ndims - 2];
1362+
const dim_t W = dims[ndims - 1];
13591363

13601364
const dim_t wsp_size = W * blksize;
13611365
float *wspace = scratchpad.template get<float>(
@@ -1368,7 +1372,7 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
13681372
for (c = 0; c < curr_c_block; ++c) {
13691373
const ptrdiff_t flat_off = 0
13701374
+ c * flat_d.blocking_desc().strides[1]
1371-
+ w * flat_d.blocking_desc().strides[3];
1375+
+ w * flat_d.blocking_desc().strides[ndims - 1];
13721376
o[w * blksize + c] = i[flat_off];
13731377
}
13741378
for (/* continue */; c < c_block; ++c) {

src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -1491,7 +1491,7 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
14911491
return status::unimplemented;
14921492
}
14931493

1494-
if (jcp.with_input_zp && jcp.is_depthwise && ndims != 4)
1494+
if (jcp.with_input_zp && jcp.is_depthwise && ndims == 5)
14951495
return status::unimplemented;
14961496

14971497
if (jcp.with_weights_zp)

src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp

+6-4
Original file line number | Diff line number | Diff line change
@@ -97,9 +97,9 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d(
9797
size_t ch_offset = jcp.is_depthwise ? jcp.nb_ch * jcp.ch_block
9898
: jcp.ngroups * jcp.oc;
9999
auto w = const_cast<char *>(weights);
100-
int32_t *compensation = (jcp.signed_input)
101-
? reinterpret_cast<int32_t *>(&w[extra_data_offset])
102-
: nullptr;
100+
int32_t *compensation = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[extra_data_offset]) :
101+
(jcp.with_input_zp) ? pd()->attr()->output_compensations_.shifts_ : nullptr;
102+
const uint8_t* input_zp = pd()->attr()->input_zero_points_.shifts_;
103103
int32_t *zp_compensation = jcp.src_zero_point
104104
? reinterpret_cast<int32_t *>(&w[extra_data_offset])
105105
+ (jcp.signed_input ? ch_offset : 0)
@@ -147,7 +147,7 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d(
147147

148148
p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
149149
: nullptr;
150-
p.compensation = (jcp.signed_input) ? compensation + g_oc : nullptr;
150+
p.compensation = (jcp.signed_input || jcp.with_input_zp) ? compensation + g_oc : nullptr;
151151
p.zp_compensation
152152
= jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
153153
p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
@@ -166,6 +166,8 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d(
166166
p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
167167
p.dst_orig = dst;
168168
p.oc_off = g_oc * sizeof(float);
169+
if (jcp.with_input_zp)
170+
p.input_zp = input_zp + g_ic;
169171

170172
(*kernel_)(&p);
171173

src/cpu/x64/jit_uni_fork_dw_conv_kernel_utils.hpp

+14-14
Original file line number | Diff line number | Diff line change
@@ -104,14 +104,14 @@ status_t jit_uni_fork_dw_conv_fwd_kernel<isa, kernel_dt>::init_conf(
104104
const memory_desc_wrapper bias_d(&bias_md);
105105

106106
const int ndims = src_d.ndims();
107-
// Currently this kernel only supports 2D and 3D convolutions.
108-
if (ndims != 4 && ndims != 5) return status::unimplemented;
109107

110-
const auto blocked_tag = (ndims == 5) ? one_of(isa, avx512_core, avx512_core) ? nCdhw16c : nCdhw8c
111-
: one_of(isa, avx512_core, avx512_core) ? nChw16c : nChw8c;
112-
const auto wei_tag = (ndims == 5) ? one_of(isa, avx512_core, avx512_core) ? Goidhw16g : Goidhw8g
113-
: one_of(isa, avx512_core, avx512_core) ? Goihw16g : Goihw8g;
114-
const auto nxc_tag = (ndims == 5) ? ndhwc : nhwc;
108+
const auto blocked_tag = one_of(isa, avx512_core) ?
109+
pick(ndims - 3, nCw16c, nChw16c, nCdhw16c) :
110+
pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
111+
const auto wei_tag = one_of(isa, avx512_core) ?
112+
pick(ndims - 3, Goiw16g, Goihw16g, Goidhw16g) :
113+
pick(ndims - 3, Goiw8g, Goihw8g, Goidhw8g);
114+
const auto nxc_tag = pick(ndims - 3, nwc, nhwc, ndhwc);
115115

116116
jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
117117

@@ -172,29 +172,29 @@ status_t jit_uni_fork_dw_conv_fwd_kernel<isa, kernel_dt>::init_conf(
172172
jcp.ic = src_d.dims()[1];
173173

174174
jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
175-
jcp.ih = src_d.dims()[ndims - 2];
175+
jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
176176
jcp.iw = src_d.dims()[ndims - 1];
177177
jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
178-
jcp.oh = dst_d.dims()[ndims - 2];
178+
jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
179179
jcp.ow = dst_d.dims()[ndims - 1];
180180

181181
jcp.kd = (ndims == 5) ? weights_d.dims()[3] : 1;
182-
jcp.kh = weights_d.dims()[ndims - 1];
182+
jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[ndims - 1];
183183
jcp.kw = weights_d.dims()[ndims];
184184

185185
jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
186-
jcp.t_pad = cd.padding[0][ndims - 4];
186+
jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
187187
jcp.l_pad = cd.padding[0][ndims - 3];
188188
jcp.back_pad = (ndims == 5) ? cd.padding[1][0] : 0;
189-
jcp.b_pad = cd.padding[1][ndims - 4];
189+
jcp.b_pad = (ndims == 3) ? 0 : cd.padding[1][ndims - 4];
190190
jcp.r_pad = cd.padding[1][ndims - 3];
191191

192192
jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
193-
jcp.stride_h = cd.strides[ndims - 4];
193+
jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
194194
jcp.stride_w = cd.strides[ndims - 3];
195195

196196
jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
197-
jcp.dilate_h = cd.dilates[ndims - 4];
197+
jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
198198
jcp.dilate_w = cd.dilates[ndims - 3];
199199

200200
jcp.loop_order = loop_ngcw;

src/cpu/x64/jit_uni_fork_dw_convolution.cpp

+6-3
Original file line number | Diff line number | Diff line change
@@ -95,9 +95,12 @@ void jit_uni_fork_dw_convolution_fwd_t<isa, src_type, dst_type>::execute_forward
9595
const auto ic_off_idx = is_src_layout_nxc ? ch * jcp.ch_block : ch;
9696
const auto oc_off_idx = is_dst_layout_nxc ? ch * jcp.ch_block : ch;
9797

98-
size_t src_off = (jcp.ndims == 5) ? src_d.blk_off(n, ic_off_idx, id, ih, iw) : src_d.blk_off(n, ic_off_idx, ih, iw);
99-
size_t dst_off = (jcp.ndims == 5) ? dst_d.blk_off(n, oc_off_idx, od, oh, ow) : dst_d.blk_off(n, oc_off_idx, oh, ow);
100-
size_t wei_off = (jcp.ndims == 5) ? weights_d.blk_off(ch, 0, 0, kd, kh, kw) : weights_d.blk_off(ch, 0, 0, kh, kw);
98+
size_t src_off = (jcp.ndims == 3) ? src_d.blk_off(n, ic_off_idx, iw) :
99+
(jcp.ndims == 4) ? src_d.blk_off(n, ic_off_idx, ih, iw) : src_d.blk_off(n, ic_off_idx, id, ih, iw);
100+
size_t dst_off = (jcp.ndims == 3) ? dst_d.blk_off(n, oc_off_idx, ow) :
101+
(jcp.ndims == 4) ? dst_d.blk_off(n, oc_off_idx, oh, ow) : dst_d.blk_off(n, oc_off_idx, od, oh, ow);
102+
size_t wei_off = (jcp.ndims == 3) ? weights_d.blk_off(ch, 0, 0, kw) :
103+
(jcp.ndims == 4) ? weights_d.blk_off(ch, 0, 0, kh, kw) : weights_d.blk_off(ch, 0, 0, kd, kh, kw);
101104

102105
par_conv.src = &src[src_off];
103106
par_conv.dst = &dst[dst_off];

src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -1367,7 +1367,7 @@ status_t jit_uni_x8s8s32x_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
13671367
return status::unimplemented;
13681368
}
13691369

1370-
if (jcp.with_input_zp && jcp.is_depthwise && ndims != 4)
1370+
if (jcp.with_input_zp && jcp.is_depthwise && !utils::one_of(ndims, 3, 4))
13711371
return status::unimplemented;
13721372

13731373
if (jcp.with_weights_zp)

src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp

+6-4
Original file line number | Diff line number | Diff line change
@@ -267,9 +267,9 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_1d(
267267
size_t ch_offset = jcp.is_depthwise ? jcp.nb_ch * jcp.ch_block
268268
: jcp.ngroups * jcp.oc;
269269
auto w = const_cast<char *>(weights);
270-
const int32_t *compensation = (jcp.signed_input)
271-
? reinterpret_cast<int32_t *>(&w[extra_data_offset])
272-
: nullptr;
270+
const int32_t *compensation = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[extra_data_offset]) :
271+
(jcp.with_input_zp) ? pd()->attr()->output_compensations_.shifts_ : nullptr;
272+
const uint8_t* input_zp = pd()->attr()->input_zero_points_.shifts_;
273273
const int32_t *zp_compensation = jcp.src_zero_point
274274
? reinterpret_cast<int32_t *>(&w[extra_data_offset])
275275
+ (jcp.signed_input ? ch_offset : 0)
@@ -316,7 +316,7 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_1d(
316316

317317
p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
318318
: nullptr;
319-
p.compensation = (jcp.signed_input) ? compensation + g_oc : nullptr;
319+
p.compensation = (jcp.signed_input || jcp.with_input_zp) ? compensation + g_oc : nullptr;
320320
p.zp_compensation
321321
= jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
322322
p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
@@ -335,6 +335,8 @@ status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_1d(
335335
p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
336336
p.dst_orig = dst;
337337
p.oc_off = g_oc * sizeof(float);
338+
if (jcp.with_input_zp)
339+
p.input_zp = input_zp + g_ic;
338340

339341
(*kernel_)(&p);
340342

0 commit comments

Comments
 (0)