Skip to content

Commit 63220c8

Browse files
luweizhou2016 authored and azhai219 committed
[FORK][FEATURE] Support of strided blobs for [de]convolution and simple reorder
Relaxed mb strides check for FP32/BF16 Convolutions
1 parent a7412f3 commit 63220c8

18 files changed

+154
-66
lines changed

src/common/memory_desc_wrapper.hpp

+21
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,12 @@ struct memory_desc_wrapper : public c_compatible {
451451
return memory_desc_matches_tag(*md_, tag);
452452
}
453453

454+
/** returns true if the memory desc corresponds to the given format tag and
455+
* strides.
456+
* @sa memory_desc_matches_tag */
457+
bool matches_tag(format_tag_t tag, const dims_t strides) const {
458+
return memory_desc_matches_tag(*md_, tag, strides);
459+
}
454460
/** returns matching tag (or undef if match is not found)
455461
* XXX: This is a workaround that eventually should go away! */
456462
template <typename... Tags>
@@ -461,6 +467,21 @@ struct memory_desc_wrapper : public c_compatible {
461467
return format_tag::undef;
462468
}
463469

470+
/** returns matching tag (or undef if match is not found) with taking into
471+
* account strides specified outside */
472+
template<typename ...Tags>
473+
dnnl_format_tag_t stride_relaxed_matches_any_of(const dims_t &strides, Tags... tags) const {
474+
for (const auto &tag : {tags...})
475+
if (matches_tag(tag, strides)) return tag;
476+
return format_tag::undef;
477+
}
478+
479+
template<typename ...Tags>
480+
dnnl_format_tag_t mb_stride_relaxed_match(Tags... tags) const {
481+
const dims_t skip_mb_stride{-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
482+
return stride_relaxed_matches_any_of(skip_mb_stride, tags...);
483+
}
484+
464485
/* offset section */
465486

466487
/** returns physical offset by logical one. logical offset is represented by

src/common/type_helpers.hpp

+45-3
Original file line numberDiff line numberDiff line change
@@ -329,10 +329,7 @@ inline bool blocking_desc_is_equal(const memory_desc_t &lhs_md,
329329
&& array_cmp(lhs.inner_idxs, rhs.inner_idxs, lhs.inner_nblks);
330330
if (ignore_strides) return equal;
331331

332-
// Check the strides.
333-
// Note: for dimensions of size `1` the stride doesn't really matter.
334332
for (int d = 0; d < lhs_md.ndims; ++d) {
335-
if (lhs_md.dims[d] == 1 && lhs_md.padded_dims[d] == 1) continue;
336333
equal = equal && lhs.strides[d] == rhs.strides[d];
337334
}
338335

@@ -1177,6 +1174,51 @@ inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag) {
11771174
return types::blocking_desc_is_equal(md, md_gold);
11781175
}
11791176

1177+
/** returns true if memory desc @p md corresponds to the given format tag and
1178+
* strides.
1179+
* In order to align with memory descriptor equality comparisons and hashing,
1180+
* the strides of unit dimensions are ignored.
1181+
* Strides might contain `0` value, indicating the stride must match the one
1182+
* that memory_desc_init_by_tag() returns.
1183+
* Strides might contain `-1` values, that would be ignored during the
1184+
* comparison. For instance, this can be used if a stride along minibatch
1185+
* doesn't matter. */
1186+
inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag,
1187+
const dims_t strides) {
1188+
if (md.format_kind != format_kind::sparse) {
1189+
if (md.format_kind != types::format_tag_to_kind(tag)) return false;
1190+
}
1191+
memory_desc_t md_gold;
1192+
status_t status = memory_desc_init_by_tag(
1193+
md_gold, md.ndims, md.dims, md.data_type, tag);
1194+
if (status != status::success) return false;
1195+
1196+
if (md.format_kind != format_kind::blocked)
1197+
return false; // unimplemented yet
1198+
1199+
const auto &blk = md.format_desc.blocking;
1200+
const auto &blk_gold = md_gold.format_desc.blocking;
1201+
1202+
using utils::array_cmp;
1203+
bool same_blocks = true && blk.inner_nblks == blk_gold.inner_nblks
1204+
&& array_cmp(blk.inner_blks, blk_gold.inner_blks, blk.inner_nblks)
1205+
&& array_cmp(blk.inner_idxs, blk_gold.inner_idxs, blk.inner_nblks);
1206+
1207+
if (!same_blocks) return false;
1208+
1209+
if (strides == nullptr)
1210+
return array_cmp(blk.strides, blk_gold.strides, md.ndims);
1211+
1212+
for (int d = 0; d < md.ndims; ++d) {
1213+
dim_t stride = strides[d];
1214+
if (stride == -1) continue;
1215+
if (stride == 0) stride = blk_gold.strides[d];
1216+
if (blk.strides[d] != stride) return false;
1217+
}
1218+
1219+
return true;
1220+
}
1221+
11801222
/** returns matching tag (or undef if match is not found)
11811223
* XXX: This is a workaround that eventually should go away! */
11821224
template <typename... Tags>

src/cpu/gemm_convolution.cpp

+21-7
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,20 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp(
231231

232232
const conv_gemm_conf_t &jcp = this->pd()->jcp_;
233233

234-
const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
234+
const memory_desc_wrapper src_d(pd()->src_md());
235+
const memory_desc_wrapper dst_d(pd()->dst_md());
236+
237+
const size_t src_mb_stride = src_d.blk_off(1);
238+
const size_t src_grp_stride = src_d.blk_off(0, 1) * jcp.ic;
239+
240+
const size_t dst_mb_stride = dst_d.blk_off(1);
241+
const size_t dst_grp_stride = dst_d.blk_off(0, 1) * jcp.oc;
242+
235243
const size_t weights_oc_size = jcp.ic * jcp.ks;
236244
const size_t weights_g_size = weights_oc_size * jcp.oc;
237245
const bool is_problem_3d = pd()->ndims() == 5;
246+
src += src_d.off_l(0);
247+
dst += dst_d.off_l(0);
238248

239249
assert(IMPLICATION(is_problem_3d,
240250
jcp.os_block == jcp.os && jcp.ic_block == jcp.ic
@@ -254,7 +264,7 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp(
254264
auto inner_ker = [&](int spatial, const im_pos_t &curr, im_pos_t &prev,
255265
im_pos_t &step, const im_pos_t &end) {
256266
const data_t *_src
257-
= src + (curr.n * jcp.ngroups + curr.g) * src_step;
267+
= src + curr.n * src_mb_stride + curr.g * src_grp_stride;
258268
step.oc = nstl::min(
259269
jcp.oc_block, nstl::min(jcp.oc, end.oc) - curr.oc);
260270
step.sp = nstl::min(jcp.os_block,
@@ -275,10 +285,9 @@ status_t gemm_convolution_fwd_t::execute_forward_ncsp(
275285
const data_t one = 1.0;
276286

277287
const dim_t M = jcp.os * jcp.od;
278-
const size_t dst_step = jcp.oc * M;
279288
const dim_t m = step.sp;
280289
const dim_t LDA = jcp.im2col_sz ? m : M;
281-
data_t *_dst = dst + (curr.n * jcp.ngroups + curr.g) * dst_step
290+
data_t *_dst = dst + curr.n * dst_mb_stride + curr.g * dst_grp_stride
282291
+ curr.oc * M + curr.od * jcp.os + curr.sp;
283292
const dim_t K = step.ic * jcp.ks;
284293
const dim_t LDB = jcp.ic * jcp.ks;
@@ -522,8 +531,13 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp(
522531
const conv_gemm_conf_t &jcp = this->pd()->jcp_;
523532

524533
const dim_t M = jcp.os * jcp.od;
525-
const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
526-
const size_t dst_step = (size_t)jcp.oc * M;
534+
const size_t src_step_to_clean = jcp.ic * jcp.ih * jcp.iw * jcp.id;
535+
const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
536+
const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
537+
const size_t src_step = diff_src_d.blk_off(1) / jcp.ngroups;
538+
const size_t dst_step = diff_dst_d.blk_off(1) / jcp.ngroups;
539+
diff_src += diff_src_d.off_l(0);
540+
diff_dst += diff_dst_d.off_l(0);
527541
const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;
528542

529543
const dim_t m = jcp.os_block;
@@ -547,7 +561,7 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp(
547561
if (is_problem_3d && jcp.im2col_sz > 0) {
548562
// jit_gemm_convolution_utils::col2im_3d() assumes that the
549563
// accumulator is initialized by zeroes
550-
for (size_t i = 0; i < src_step; i++)
564+
for (size_t i = 0; i < src_step_to_clean; i++)
551565
_diff_src[i] = (data_t)0;
552566
}
553567

src/cpu/gemm_convolution_utils.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -1080,16 +1080,14 @@ status_t init_conf(conv_gemm_conf_t &jcp,
10801080
CHECK(memory_desc_init_by_tag(src_md, desired_src_tag));
10811081
src_tag = desired_src_tag;
10821082
} else {
1083-
src_tag = memory_desc_matches_one_of_tag(
1084-
src_md, nwc, nhwc, ndhwc, ncw, nchw, ncdhw);
1083+
src_tag = src_d.mb_stride_relaxed_match(nwc, nhwc, ndhwc, ncw, nchw, ncdhw);
10851084
}
10861085

10871086
if (dst_d.format_kind() == format_kind::any) {
10881087
CHECK(memory_desc_init_by_tag(dst_md, desired_dst_tag));
10891088
dst_tag = desired_dst_tag;
10901089
} else {
1091-
dst_tag = memory_desc_matches_one_of_tag(
1092-
dst_md, nwc, nhwc, ndhwc, ncw, nchw, ncdhw);
1090+
dst_tag = dst_d.mb_stride_relaxed_match(nwc, nhwc, ndhwc, ncw, nchw, ncdhw);
10931091
}
10941092

10951093
if (src_tag == format_tag::undef || dst_tag == format_tag::undef)

src/cpu/nchw_pooling.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,18 @@ template <>
3636
status_t nchw_pooling_fwd_t<data_type::f32>::execute_forward(
3737
const exec_ctx_t &ctx) const {
3838
const auto alg = pd()->desc()->alg_kind;
39-
const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
39+
const auto src_ = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
4040
auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
4141
auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);
4242

4343
const memory_desc_wrapper ws_d(pd()->workspace_md());
44+
const memory_desc_wrapper src_d(pd()->src_md());
45+
const memory_desc_wrapper dst_d(pd()->dst_md());
4446
const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
4547

48+
auto src = src_ + src_d.off_l(0);
49+
dst += dst_d.off_l(0);
50+
4651
const dim_t MB = pd()->MB();
4752
const dim_t C = pd()->OC();
4853
const dim_t OD = pd()->OD();

src/cpu/x64/gemm_bf16_convolution.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,12 @@ status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_ncsp(
531531

532532
const conv_gemm_conf_t &jcp = this->pd()->jcp_;
533533

534+
const memory_desc_wrapper src_d(pd()->src_md());
535+
const memory_desc_wrapper dst_d(pd()->dst_md());
536+
537+
src += src_d.off_l(0);
538+
dst += dst_d.off_l(0);
539+
534540
float *bias = nullptr;
535541
if (jcp.with_bias) {
536542
if (pd()->desc()->bias_desc.data_type == data_type::bf16) {

src/cpu/x64/jit_avx2_1x1_conv_kernel_f32.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -762,8 +762,8 @@ status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
762762

763763
const auto dat_tag_nxc = utils::pick(ndims - 3, nwc, nhwc, ndhwc);
764764
const auto dat_tag_nCx8c = utils::pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
765-
jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
766-
jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
765+
jcp.src_tag = src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx8c);
766+
jcp.dst_tag = dst_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx8c);
767767
const bool is_data_layout_nxc
768768
= utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);
769769
const auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;

src/cpu/x64/jit_avx2_conv_kernel_f32.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -651,9 +651,9 @@ status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
651651
: pick(ndims - 3, Owi8o, Ohwi8o, Odhwi8o);
652652

653653
jcp.src_tag
654-
= src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_ncx, dat_tag_nCx8c);
654+
= src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_ncx, dat_tag_nCx8c);
655655
jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag_OIxio, wei_tag_Oxio);
656-
jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
656+
jcp.dst_tag = dst_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx8c);
657657

658658
jcp.typesize_in = types::data_type_size(src_d.data_type());
659659
jcp.typesize_out = types::data_type_size(dst_d.data_type());
@@ -1170,8 +1170,8 @@ status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp,
11701170
? pick(ndims - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i)
11711171
: pick(ndims - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i);
11721172

1173-
jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
1174-
jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
1173+
jcp.src_tag = diff_src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx8c);
1174+
jcp.dst_tag = diff_dst_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx8c);
11751175
jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
11761176

11771177
jcp.typesize_in = types::data_type_size(diff_src_d.data_type());

src/cpu/x64/jit_avx512_common_1x1_conv_kernel.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -638,8 +638,8 @@ status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp,
638638

639639
const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
640640
const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
641-
jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
642-
jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
641+
jcp.src_tag = src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx16c);
642+
jcp.dst_tag = dst_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx16c);
643643
bool is_data_layout_nxc
644644
= utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);
645645
auto required_dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;

src/cpu/x64/jit_avx512_common_conv_kernel.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -845,9 +845,9 @@ status_t jit_avx512_common_conv_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
845845
const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
846846
const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
847847
const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
848-
auto curr_src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c,
848+
auto curr_src_tag = src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx16c,
849849
dat_tag_nCx8c, dat_tag_nCx4c, dat_tag_ncx);
850-
auto curr_dst_tag = dst_d.matches_one_of_tag(
850+
auto curr_dst_tag = dst_d.mb_stride_relaxed_match(
851851
dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
852852
bool is_data_layout_nxc = IMPLICATION(curr_src_tag != dat_tag_nxc,
853853
src_d.format_kind() == format_kind::any)
@@ -1913,9 +1913,9 @@ status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
19131913
const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
19141914
const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
19151915
const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
1916-
auto curr_src_tag = diff_src_d.matches_one_of_tag(
1916+
auto curr_src_tag = diff_src_d.mb_stride_relaxed_match(
19171917
dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
1918-
auto curr_dst_tag = diff_dst_d.matches_one_of_tag(
1918+
auto curr_dst_tag = diff_dst_d.mb_stride_relaxed_match(
19191919
dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
19201920
bool is_data_layout_nxc
19211921
= IMPLICATION(curr_src_tag != dat_tag_nxc,

src/cpu/x64/jit_avx512_common_convolution.cpp

+17-17
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ void jit_avx512_common_convolution_fwd_t<src_type, wei_type,
215215
start_copy = start;
216216

217217
auto par_conv = jit_conv_call_s();
218-
size_t src_c_stride = src_d.blk_off(0, 1);
218+
size_t src_c_stride = src_d.blk_off(0, 1) - src_d.off_l(0);
219219
size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
220220

221221
for (int icb_l2 = 0; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
@@ -338,9 +338,9 @@ void jit_avx512_common_convolution_fwd_t<src_type, wei_type,
338338
start_copy = start;
339339

340340
auto par_conv = jit_conv_call_s();
341-
size_t src_h_stride = src_d.blk_off(0, 0, 1);
342-
size_t src_c_stride = src_d.blk_off(0, 1);
343-
size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
341+
size_t src_h_stride = src_d.blk_off(0, 0, 1) - src_d.off_l(0);
342+
size_t src_c_stride = src_d.blk_off(0, 1) - src_d.off_l(0);
343+
size_t dst_h_stride = dst_d.blk_off(0, 0, 1) - dst_d.off_l(0);
344344
size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
345345
size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
346346

@@ -495,10 +495,10 @@ void jit_avx512_common_convolution_fwd_t<src_type, wei_type,
495495
start_copy = start;
496496

497497
auto par_conv = jit_conv_call_s();
498-
size_t src_d_stride = src_d.blk_off(0, 0, 1);
499-
size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
500-
size_t src_c_stride = src_d.blk_off(0, 1);
501-
size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
498+
size_t src_d_stride = src_d.blk_off(0, 0, 1) - src_d.off_l(0);
499+
size_t src_h_stride = src_d.blk_off(0, 0, 0, 1) - src_d.off_l(0);
500+
size_t src_c_stride = src_d.blk_off(0, 1) - src_d.off_l(0);
501+
size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1) - dst_d.off_l(0);
502502
size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
503503
size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
504504
size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
@@ -653,7 +653,7 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
653653
start_copy = start;
654654

655655
auto par_conv = jit_conv_call_s();
656-
size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
656+
size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1) - diff_dst_d.off_l(0);
657657
size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);
658658

659659
for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
@@ -762,9 +762,9 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
762762
start_copy = start;
763763

764764
auto par_conv = jit_conv_call_s();
765-
size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1);
766-
size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1);
767-
size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
765+
size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1) - diff_src_d.off_l(0);
766+
size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1) - diff_dst_d.off_l(0);
767+
size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1) - diff_dst_d.off_l(0);
768768
size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
769769
size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);
770770

@@ -923,11 +923,11 @@ void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
923923
start_copy = start;
924924

925925
auto par_conv = jit_conv_call_s();
926-
size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1);
927-
size_t diff_src_d_stride = diff_src_d.blk_off(0, 0, 1);
928-
size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1);
929-
size_t diff_dst_d_stride = diff_dst_d.blk_off(0, 0, 1);
930-
size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
926+
size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1) - diff_src_d.off_l(0);
927+
size_t diff_src_d_stride = diff_src_d.blk_off(0, 0, 1) - diff_src_d.off_l(0);
928+
size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1) - diff_dst_d.off_l(0);
929+
size_t diff_dst_d_stride = diff_dst_d.blk_off(0, 0, 1) - diff_dst_d.off_l(0);
930+
size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1) - diff_dst_d.off_l(0);
931931
size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
932932
size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
933933
size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);

src/cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1280,8 +1280,8 @@ status_t jit_avx512_core_bf16_1x1_conv_kernel::init_conf(
12801280
using namespace format_tag;
12811281
const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
12821282
const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
1283-
jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
1284-
jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
1283+
jcp.src_tag = src_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx16c);
1284+
jcp.dst_tag = dst_d.mb_stride_relaxed_match(dat_tag_nxc, dat_tag_nCx16c);
12851285
bool is_data_layout_nxc
12861286
= utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);
12871287
auto required_dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;

src/cpu/x64/jit_avx512_core_bf16_1x1_convolution.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ void jit_avx512_core_bf16_1x1_convolution_fwd_t<dst_type>::execute_forward_thr(
179179
const bool is_src_layout_nxc = utils::one_of(
180180
jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
181181

182+
auto start_off = dst_d.off_l(0);
183+
182184
auto step = [](int default_step, int remaining, int tail_step) {
183185
assert(default_step <= tail_step);
184186
return remaining < tail_step ? remaining : default_step;
@@ -267,7 +269,7 @@ void jit_avx512_core_bf16_1x1_convolution_fwd_t<dst_type>::execute_forward_thr(
267269
: rnd_up((jcp.load_dim / grp_count), jcp.load_block);
268270
const size_t str_size = jcp.bcast_dim * max_load_per_thread;
269271
p.store_buffer = store_buffer + ithr * str_size
270-
+ data_blk_off(dst_d, 0, 0, od, oh, ow);
272+
+ data_blk_off(dst_d, 0, 0, od, oh, ow) - start_off;
271273

272274
p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
273275
p.dst_orig = static_cast<const char *>(p.output_data)

0 commit comments

Comments
 (0)