Skip to content

Commit 4bd1241

Browse files
dmitry-gorokhovazhai219
authored andcommitted
[FORK][FEATURE] Added 3D DW case support for JIT INT8 Convolutions
1 parent 685bce3 commit 4bd1241

12 files changed

+245
-17
lines changed

include/oneapi/dnnl/dnnl.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,7 @@ struct memory : public handle<dnnl_memory_t> {
14541454
BAcde16b16a = dnnl_BAcde16b16a,
14551455
BAcde16a16b = dnnl_BAcde16a16b,
14561456
aBdec32b = dnnl_aBdec32b,
1457+
Abcdef4a = dnnl_Abcdef4a,
14571458
Abcdef8a = dnnl_Abcdef8a,
14581459
Abcdef16a = dnnl_Abcdef16a,
14591460
Abcdef32a = dnnl_Abcdef32a,
@@ -1702,6 +1703,7 @@ struct memory : public handle<dnnl_memory_t> {
17021703
IOdhw16i16o = dnnl_IOdhw16i16o,
17031704
gIOhw16i16o = dnnl_gIOhw16i16o,
17041705
gOhwi32o = dnnl_gOhwi32o,
1706+
Goidhw4g = dnnl_Goidhw4g,
17051707
Goidhw8g = dnnl_Goidhw8g,
17061708
Goidhw16g = dnnl_Goidhw16g,
17071709
IOw16o16i = dnnl_IOw16o16i,

include/oneapi/dnnl/dnnl_types.h

+2
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ typedef enum {
398398
dnnl_aCBdef16c16b,
399399
dnnl_aBdefc4b,
400400
dnnl_aBdefc8b,
401+
dnnl_Abcdef4a,
401402
dnnl_Abcdef8a,
402403
dnnl_Abcdef16a,
403404
dnnl_Abcdef32a,
@@ -1621,6 +1622,7 @@ typedef enum {
16211622
dnnl_gIOdhw8o16i2o = dnnl_aCBdef8b16c2b,
16221623
dnnl_gOIdhw8o8i = dnnl_aBCdef8b8c,
16231624
dnnl_gOIdhw8o4i = dnnl_aBCdef8b4c,
1625+
dnnl_Goidhw4g = dnnl_Abcdef4a,
16241626
dnnl_Goidhw8g = dnnl_Abcdef8a,
16251627
dnnl_Goidhw16g = dnnl_Abcdef16a,
16261628
dnnl_Goidhw32g = dnnl_Abcdef32a,

src/common/c_types_map.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,7 @@ const format_tag_t ABcd40a32b = dnnl_ABcd40a32b;
677677
const format_tag_t ABcde40a32b = dnnl_ABcde40a32b;
678678
const format_tag_t BAcde16b16a = dnnl_BAcde16b16a;
679679
const format_tag_t aBdec32b = dnnl_aBdec32b;
680+
const format_tag_t Abcdef4a = dnnl_Abcdef4a;
680681
const format_tag_t Abcdef8a = dnnl_Abcdef8a;
681682
const format_tag_t Abcdef16a = dnnl_Abcdef16a;
682683
const format_tag_t Abcdef32a = dnnl_Abcdef32a;
@@ -1175,6 +1176,7 @@ const format_tag_t IOhw16i16o = dnnl_IOhw16i16o;
11751176
const format_tag_t Ohwi32o = dnnl_Ohwi32o;
11761177
const format_tag_t gIOhw16i16o = dnnl_gIOhw16i16o;
11771178
const format_tag_t gOhwi32o = dnnl_gOhwi32o;
1179+
const format_tag_t Goidhw4g = dnnl_Goidhw4g;
11781180
const format_tag_t Goidhw8g = dnnl_Goidhw8g;
11791181
const format_tag_t Goidhw16g = dnnl_Goidhw16g;
11801182
const format_tag_t IOw16o16i = dnnl_IOw16o16i;

src/common/dnnl_debug_autogenerated.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
302302
if (v == dnnl_aCBdef16c16b) return "aCBdef16c16b";
303303
if (v == dnnl_aBdefc4b) return "aBdefc4b";
304304
if (v == dnnl_aBdefc8b) return "aBdefc8b";
305+
if (v == dnnl_Abcdef4a) return "Abcdef4a";
305306
if (v == dnnl_Abcdef8a) return "Abcdef8a";
306307
if (v == dnnl_Abcdef16a) return "Abcdef16a";
307308
if (v == dnnl_Abcdef32a) return "Abcdef32a";
@@ -1398,6 +1399,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
13981399
if (v == dnnl_gIOdhw8o16i2o) return "gIOdhw8o16i2o";
13991400
if (v == dnnl_gOIdhw8o8i) return "gOIdhw8o8i";
14001401
if (v == dnnl_gOIdhw8o4i) return "gOIdhw8o4i";
1402+
if (v == dnnl_Goidhw4g) return "Goidhw4g";
14011403
if (v == dnnl_Goidhw8g) return "Goidhw8g";
14021404
if (v == dnnl_Goidhw16g) return "Goidhw16g";
14031405
if (v == dnnl_Goidhw32g) return "Goidhw32g";

src/common/memory_desc_wrapper.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,7 @@ status_t memory_desc_wrapper::compute_blocking(
586586
C(aBdec32b, {0, 1, 3, 4, 2}, {32}, {1});
587587
C(aCBdef16c16b, {0, 2, 1, 3, 4, 5}, {16, 16}, {2, 1});
588588
C(aCBdef16b16c, {0, 2, 1, 3, 4, 5}, {16, 16}, {1, 2});
589+
C(Abcdef4a, {0, 1, 2, 3, 4, 5}, {4}, {0});
589590
C(Abcdef8a, {0, 1, 2, 3, 4, 5}, {8}, {0});
590591
C(Abcdef16a, {0, 1, 2, 3, 4, 5}, {16}, {0});
591592
C(Abcdef32a, {0, 1, 2, 3, 4, 5}, {32}, {0});

src/common/tag_traits.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,7 @@ DECL_TRAITS(aBCde2b8c8b2c, _BC, _2b8c8b2c, 5);
795795
DECL_TRAITS(aBdec32b, _B, _32b, 5);
796796
DECL_TRAITS(aCBdef16c16b, _BC, _16c16b, 6);
797797
DECL_TRAITS(aCBdef16b16c, _BC, _16b16c, 6);
798+
DECL_TRAITS(Abcdef4a, _A, _4a, 6);
798799
DECL_TRAITS(Abcdef8a, _A, _8a, 6);
799800
DECL_TRAITS(Abcdef16a, _A, _16a, 6);
800801
DECL_TRAITS(aCBd16c16b, _BC, _16c16b, 4);

src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp

+4-7
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ void pick_loop_order(jit_conv_conf_t &jcp, int nthr) {
4444
jcp.loop_order = loop_cwgn;
4545
if (jcp.ngroups > 1) {
4646
jcp.loop_order = loop_ngcw;
47-
if (jcp.mb < nthr)
47+
if (jcp.mb < nthr && jcp.ndims != 5)
4848
jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg;
4949
} else if (jcp.mb >= nthr && jcp.ic_without_padding <= 16) {
5050
jcp.loop_order = loop_ngcw;
@@ -476,7 +476,7 @@ void _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::compute_ker_dw(int ur_w,
476476
};
477477

478478
auto kernel_offset = [this](int ci, int ki) {
479-
return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block);
479+
return jcp.typesize_in * ((ci * jcp.kd * jcp.kh * jcp.kw + ki) * jcp.ch_block);
480480
};
481481

482482
auto compute = [this](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
@@ -1480,10 +1480,6 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
14801480
}
14811481
}
14821482

1483-
if (jcp.is_depthwise && is_3d)
1484-
// NOTE: 3D depthwise is not currently supported here.
1485-
return status::unimplemented;
1486-
14871483
jcp.with_input_zp = !attr.input_zero_points_.has_default_values();
14881484
jcp.with_weights_zp = !attr.weights_zero_points_.has_default_values();
14891485

@@ -1580,7 +1576,8 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
15801576
format_tag_t wei_tag;
15811577
if (jcp.ic_block == 16 || jcp.ch_block == 16) {
15821578
if (is_3d) {
1583-
wei_tag = with_groups ? gOIdhw4i16o4i : OIdhw4i16o4i;
1579+
wei_tag = with_groups ? jcp.is_depthwise ? Goidhw16g : gOIdhw4i16o4i
1580+
: OIdhw4i16o4i;
15841581
} else if (is_1d) {
15851582
wei_tag = with_groups ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i
15861583
: OIw4i16o4i;

src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp

+106
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,112 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
682682
return status::success;
683683
}
684684

685+
status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d_dw(const exec_ctx_t &ctx) const {
686+
auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
687+
auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
688+
auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
689+
auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
690+
691+
const memory_desc_wrapper src_d(pd()->src_md());
692+
const memory_desc_wrapper dst_d(pd()->dst_md());
693+
const memory_desc_wrapper weights_d(pd()->weights_md(0));
694+
const memory_desc_wrapper bias_d(pd()->weights_md(1));
695+
696+
const size_t bia_dt_size
697+
= pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
698+
const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
699+
700+
const auto &jcp = pd()->jcp_;
701+
assert(jcp.ic_block == 1);
702+
assert(jcp.oc_block == 1);
703+
assert(jcp.nb_ic == 1);
704+
assert(jcp.nb_oc == 1);
705+
assert(jcp.nb_oc_blocking == 1);
706+
assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
707+
708+
DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
709+
DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
710+
711+
const float *oscales = adjust_oscales(
712+
ctx.get_scratchpad_grantor(), src_scales, wei_scales);
713+
714+
size_t offset = weights_d.size() - weights_d.additional_buffer_size();
715+
auto w = const_cast<char *>(weights);
716+
int32_t* compensation = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) :
717+
(jcp.with_input_zp) ? pd()->attr()->output_compensations_.shifts_ : 0;
718+
const uint8_t* input_zp = pd()->attr()->input_zero_points_.shifts_;
719+
int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
720+
int group_block = jcp.ch_block;
721+
722+
parallel_nd(jcp.mb, jcp.od, jcp.oh, jcp.nb_ow, nb_groups, [&](int n, int od_s, int oh_s, int owb, int gg) {
723+
auto p = jit_conv_call_s();
724+
725+
size_t src_d_stride = src_d.blk_off(0, 0, 1);
726+
size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
727+
728+
size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
729+
size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
730+
731+
int gb = gg * jcp.nb_ch_blocking;
732+
int g = gb * group_block;
733+
734+
int id_s = -jcp.f_pad + od_s * jcp.stride_d;
735+
736+
int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
737+
int ow_s = owb * jcp.ow_block;
738+
int iw_s = ow_s * jcp.stride_w;
739+
740+
auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size) : 0;
741+
int32_t *compensation_w = (jcp.signed_input || jcp.with_input_zp) ? compensation + g : 0;
742+
743+
auto dst_w = dst + dst_dt_size * dst_d.blk_off(n, g, od_s, oh_s, ow_s);
744+
auto src_w = src + src_d.blk_off(n, g, id_s, ih_s, iw_s);
745+
auto wht_w = weights + wht_blk_off(weights_d, gb, 0);
746+
747+
auto scales = &oscales[jcp.is_oc_scale * g];
748+
749+
int dilate_d = jcp.dilate_d + 1;
750+
int i_f_overflow = nstl::min(jcp.kd, div_up(max(0, -id_s), dilate_d));
751+
int i_back_overflow = nstl::min(jcp.kd,
752+
div_up(max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
753+
dilate_d));
754+
int kd_padding = nstl::max(0, jcp.kd - i_f_overflow - i_back_overflow);
755+
756+
size_t wei_d_stride = (jcp.signed_input || jcp.with_input_zp) ? 0 : i_f_overflow * wht_d_stride;
757+
758+
int dilate_h = jcp.dilate_h + 1;
759+
int i_t_overflow = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h));
760+
int i_b_overflow = nstl::min(jcp.kh,
761+
div_up(max(0, ih_s - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
762+
dilate_h));
763+
int kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);
764+
765+
size_t wei_h_stride = (jcp.signed_input || jcp.with_input_zp) ? 0 : i_t_overflow * wht_h_stride;
766+
p.src = src_w + i_t_overflow * dilate_h * src_h_stride
767+
+ i_f_overflow * dilate_d * src_d_stride;
768+
p.dst = dst_w;
769+
p.filt = wht_w + wei_d_stride + wei_h_stride;
770+
p.bias = bias_w;
771+
p.compensation = compensation_w;
772+
p.oc_blocks = gb;
773+
p.kd_padding = kd_padding;
774+
p.kh_padding = kh_padding;
775+
p.scales = scales;
776+
p.f_overflow = i_f_overflow;
777+
p.back_overflow = i_back_overflow;
778+
p.t_overflow = i_t_overflow;
779+
p.b_overflow = i_b_overflow;
780+
p.owb = owb;
781+
782+
p.oc_off = g * sizeof(float);
783+
if (jcp.with_input_zp)
784+
p.input_zp = input_zp + g;
785+
786+
(*kernel_)(&p);
787+
});
788+
return status::success;
789+
}
790+
685791
} // namespace x64
686792
} // namespace cpu
687793
} // namespace impl

src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp

+7-2
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,12 @@ struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public primitive_t {
123123
return execute_forward_2d_dw(ctx);
124124
else
125125
return execute_forward_2d(ctx);
126-
else if (_pd->ndims() == 5)
127-
return execute_forward_3d(ctx);
126+
else if (_pd->ndims() == 5) {
127+
if (_pd->jcp_.is_depthwise)
128+
return execute_forward_3d_dw(ctx);
129+
else
130+
return execute_forward_3d(ctx);
131+
}
128132
return status::unimplemented;
129133
}
130134

@@ -133,6 +137,7 @@ struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public primitive_t {
133137
status_t execute_forward_2d(const exec_ctx_t &ctx) const;
134138
status_t execute_forward_2d_dw(const exec_ctx_t &ctx) const;
135139
status_t execute_forward_3d(const exec_ctx_t &ctx) const;
140+
status_t execute_forward_3d_dw(const exec_ctx_t &ctx) const;
136141
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
137142
const float *adjust_oscales(const memory_tracking::grantor_t &scratchpad,
138143
const float *src_scales, const float *wei_scales) const;

src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ void pick_loop_order(jit_conv_conf_t &jcp) {
4545
jcp.loop_order = loop_cwgn;
4646
if (jcp.ngroups > 1) {
4747
jcp.loop_order = loop_ngcw;
48-
if (jcp.mb < jcp.nthr)
48+
if (jcp.mb < jcp.nthr && jcp.ndims != 5)
4949
jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg;
5050
} else if (jcp.mb >= jcp.nthr && jcp.ic_without_padding <= 8) {
5151
jcp.loop_order = loop_ngcw;
@@ -418,7 +418,7 @@ void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::compute_ker_dw(int ur_w, int pad_l,
418418
};
419419

420420
auto kernel_offset = [this](int ci, int ki) {
421-
return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block);
421+
return jcp.typesize_in * ((ci * jcp.kd * jcp.kh * jcp.kw + ki) * jcp.ch_block);
422422
};
423423

424424
auto compute = [this](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
@@ -1356,9 +1356,6 @@ status_t jit_uni_x8s8s32x_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
13561356
VERBOSE_UNSUPPORTED_FEATURE,
13571357
"fused depthwise convolution does not support zero-point");
13581358

1359-
VDISPATCH_CONV_IC(!(is_3d && jcp.is_depthwise), VERBOSE_UNSUPPORTED_FEATURE,
1360-
"unsupported depthwise implementation for 3D convolution");
1361-
13621359
jcp.with_input_zp = !attr.input_zero_points_.has_default_values();
13631360
jcp.with_weights_zp = !attr.weights_zero_points_.has_default_values();
13641361

@@ -1429,7 +1426,8 @@ status_t jit_uni_x8s8s32x_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
14291426
wei_tag = with_groups ? jcp.is_depthwise ? Goihw8g : gOIhw2i8o4i
14301427
: OIhw2i8o4i;
14311428
} else {
1432-
wei_tag = with_groups ? gOIdhw2i8o4i : OIdhw2i8o4i;
1429+
wei_tag = with_groups ? jcp.is_depthwise ? Goidhw8g : gOIdhw2i8o4i
1430+
: OIdhw2i8o4i;
14331431
}
14341432
} else {
14351433
if (is_avx2) {
@@ -1444,7 +1442,9 @@ status_t jit_uni_x8s8s32x_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
14441442
? jcp.is_depthwise ? Goihw4g : gOIhw4o4i
14451443
: OIhw4o4i;
14461444
} else {
1447-
wei_tag = with_groups ? gOIdhw4o4i : OIdhw4o4i;
1445+
wei_tag = with_groups
1446+
? jcp.is_depthwise ? Goidhw4g : gOIdhw4o4i
1447+
: OIdhw4o4i;
14481448
}
14491449
}
14501450
}

0 commit comments

Comments
 (0)