Skip to content

Commit b69ecc3

Browse files
luweizhou2016azhai219
authored andcommitted
[FIX] Extend AMX deconv to support oscale+eltwise+eltwise post ops.
1 parent 49a5dfa commit b69ecc3

3 files changed

+33
-13
lines changed

src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp

+22-7
Original file line numberDiff line numberDiff line change
@@ -3270,7 +3270,7 @@ void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_xf16(
32703270
}
32713271

32723272
const int eltwise_ind = p.find(primitive_kind::eltwise);
3273-
if (eltwise_ind != -1) eltwise_injector_->compute_vector(zmm_out.getIdx());
3273+
if (eltwise_ind != -1) idx_to_eltwise_injector_.at(eltwise_ind).compute_vector(zmm_out.getIdx());
32743274

32753275
const Ymm ymm_out = Ymm(zmm_out.getIdx());
32763276
const Ymm ymm_out_k = ymm_mask(ymm_out, mask_flag, true);
@@ -3330,7 +3330,7 @@ void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8(
33303330
if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias);
33313331

33323332
/* Do post-ops */
3333-
if (maybe_eltwise(0)) eltwise_injector_->compute_vector(zmm_out.getIdx());
3333+
if (maybe_eltwise(0)) idx_to_eltwise_injector_.at(0).compute_vector(zmm_out.getIdx());
33343334
if (p_sum_scale) { // post_op: sum
33353335
cvt2ps(jcp.dsrc_dt, zmm_prev_dst, addr, mask_flag);
33363336
if (*p_sum_zp != 0) {
@@ -3342,7 +3342,12 @@ void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8(
33423342
else
33433343
vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
33443344
}
3345-
if (maybe_eltwise(1)) eltwise_injector_->compute_vector(zmm_out.getIdx());
3345+
3346+
if (maybe_eltwise(1)) idx_to_eltwise_injector_.at(1).compute_vector(zmm_out.getIdx());
3347+
for (auto i = 2; i < jcp.post_ops.len(); i++) {
3348+
if (idx_to_eltwise_injector_.count(i) != 0)
3349+
idx_to_eltwise_injector_.at(i).compute_vector(zmm_out.getIdx());
3350+
}
33463351

33473352
if (jcp.dst_scale) { vmulps(zmm_out_msk, zmm_out, zmm_dst_scale); }
33483353

@@ -3672,7 +3677,10 @@ void jit_avx512_core_amx_bwd_data_kernel_t::generate() {
36723677

36733678
postamble();
36743679

3675-
if (jcp.with_eltwise) eltwise_injector_->prepare_table();
3680+
if (jcp.with_eltwise) {
3681+
for (auto &elt_injector : idx_to_eltwise_injector_)
3682+
elt_injector.second.prepare_table();
3683+
}
36763684
}
36773685

36783686
bool jit_avx512_core_amx_bwd_data_kernel_t::post_ops_ok(
@@ -3690,6 +3698,12 @@ bool jit_avx512_core_amx_bwd_data_kernel_t::post_ops_ok(
36903698
else
36913699
return p.contain(sum, idx);
36923700
};
3701+
//Add more element-wise post-ops supported for int8 deconv.
3702+
bool all_eltwise = jcp.is_int8_deconvolution;
3703+
for (auto i = 0; i < p.len(); i++)
3704+
all_eltwise &= is_eltwise(i);
3705+
if (all_eltwise)
3706+
return true;
36933707

36943708
switch (p.len()) {
36953709
case 0: return true;
@@ -3775,6 +3789,7 @@ status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
37753789

37763790
jcp = zero<decltype(jcp)>();
37773791
jcp.isa = is_f16 ? avx512_core_amx_fp16 : avx512_core_amx;
3792+
jcp.is_int8_deconvolution = is_int8_deconvolution;
37783793
jcp.ndims = ndims;
37793794
jcp.prop_kind = cd.prop_kind;
37803795
jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
@@ -3862,7 +3877,7 @@ status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
38623877
VERBOSE_UNSUPPORTED_TAG_S, "src");
38633878

38643879
jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
3865-
assert(IMPLICATION(is_int8_deconvolution, jcp.is_nspc));
3880+
assert(IMPLICATION(jcp.is_int8_deconvolution, jcp.is_nspc));
38663881

38673882
// TODO: remove all support for nChw16c from this implementation
38683883
VDISPATCH_CONV_IC(jcp.is_nspc, VERBOSE_UNSUPPORTED_TAG_S, "src");
@@ -3901,7 +3916,7 @@ status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
39013916
const auto &p = attr.post_ops_;
39023917
const int eltwise_ind = p.find(primitive_kind::eltwise);
39033918
jcp.with_eltwise = eltwise_ind != -1;
3904-
if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;
3919+
jcp.post_ops = p;
39053920

39063921
auto set_or_check_wei_format = [&]() {
39073922
using namespace format_tag;
@@ -3914,7 +3929,7 @@ status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
39143929
wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16i16o2i,
39153930
gOIw16i16o2i, OIhw16i16o2i, gOIhw16i16o2i, OIdhw16i16o2i,
39163931
gOIdhw16i16o2i);
3917-
else if (is_int8_deconvolution)
3932+
else if (jcp.is_int8_deconvolution)
39183933
wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16i16o4i,
39193934
gOIw16i16o4i, OIhw16i16o4i, gOIhw16i16o4i, OIdhw16i16o4i,
39203935
gOIdhw16i16o4i);

src/cpu/x64/jit_avx512_core_amx_conv_kernel.hpp

+10-6
Original file line numberDiff line numberDiff line change
@@ -482,12 +482,15 @@ struct jit_avx512_core_amx_bwd_data_kernel_t : public jit_generator {
482482
: jit_generator(jit_name(), avx512_core_amx)
483483
, jcp(ajcp)
484484
, attr_(attr)
485-
, eltwise_injector_(nullptr)
486485
, bwd_data_copy_kernel_(nullptr) {
487-
if (jcp.with_eltwise)
488-
eltwise_injector_
489-
= utils::make_unique<jit_uni_eltwise_injector<avx512_core>>(
490-
this, jcp.eltwise);
486+
if (jcp.with_eltwise) {
487+
for (int i = 0; i < jcp.post_ops.len(); i++) {
488+
const auto post_op = jcp.post_ops.entry_[i];
489+
if (post_op.is_eltwise())
490+
idx_to_eltwise_injector_.emplace(i,
491+
jit_uni_eltwise_injector<avx512_core>(this, post_op.eltwise));
492+
}
493+
}
491494
bwd_data_copy_kernel_ = utils::make_unique<
492495
jit_avx512_core_amx_bwd_data_copy_kernel_t>(jcp);
493496
}
@@ -518,7 +521,8 @@ struct jit_avx512_core_amx_bwd_data_kernel_t : public jit_generator {
518521
}
519522

520523
private:
521-
std::unique_ptr<jit_uni_eltwise_injector<avx512_core>> eltwise_injector_;
524+
std::map<int, jit_uni_eltwise_injector<avx512_core>>
525+
idx_to_eltwise_injector_;
522526
std::unique_ptr<jit_avx512_core_amx_bwd_data_copy_kernel_t>
523527
bwd_data_copy_kernel_;
524528

src/cpu/x64/jit_primitive_conf.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ struct jit_conv_conf_t {
267267

268268
int dw_conv_oh, dw_conv_ow;
269269
data_type_t dw_conv_dst_dt;
270+
bool is_int8_deconvolution;
270271
};
271272

272273
// calculates filter size taking into account dilation

0 commit comments

Comments
 (0)