@@ -3270,7 +3270,7 @@ void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_xf16(
3270
3270
}
3271
3271
3272
3272
const int eltwise_ind = p.find (primitive_kind::eltwise);
3273
- if (eltwise_ind != -1 ) eltwise_injector_-> compute_vector (zmm_out.getIdx ());
3273
+ if (eltwise_ind != -1 ) idx_to_eltwise_injector_. at (eltwise_ind). compute_vector (zmm_out.getIdx ());
3274
3274
3275
3275
const Ymm ymm_out = Ymm (zmm_out.getIdx ());
3276
3276
const Ymm ymm_out_k = ymm_mask (ymm_out, mask_flag, true );
@@ -3330,7 +3330,7 @@ void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8(
3330
3330
if (jcp.with_bias ) vaddps (zmm_out, zmm_out, zmm_bias);
3331
3331
3332
3332
/* Do post-ops */
3333
- if (maybe_eltwise (0 )) eltwise_injector_-> compute_vector (zmm_out.getIdx ());
3333
+ if (maybe_eltwise (0 )) idx_to_eltwise_injector_. at ( 0 ). compute_vector (zmm_out.getIdx ());
3334
3334
if (p_sum_scale) { // post_op: sum
3335
3335
cvt2ps (jcp.dsrc_dt , zmm_prev_dst, addr, mask_flag);
3336
3336
if (*p_sum_zp != 0 ) {
@@ -3342,7 +3342,12 @@ void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8(
3342
3342
else
3343
3343
vfmadd231ps (zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
3344
3344
}
3345
- if (maybe_eltwise (1 )) eltwise_injector_->compute_vector (zmm_out.getIdx ());
3345
+
3346
+ if (maybe_eltwise (1 )) idx_to_eltwise_injector_.at (1 ).compute_vector (zmm_out.getIdx ());
3347
+ for (auto i = 2 ; i < jcp.post_ops .len (); i++) {
3348
+ if (idx_to_eltwise_injector_.count (i) != 0 )
3349
+ idx_to_eltwise_injector_.at (i).compute_vector (zmm_out.getIdx ());
3350
+ }
3346
3351
3347
3352
if (jcp.dst_scale ) { vmulps (zmm_out_msk, zmm_out, zmm_dst_scale); }
3348
3353
@@ -3672,7 +3677,10 @@ void jit_avx512_core_amx_bwd_data_kernel_t::generate() {
3672
3677
3673
3678
postamble ();
3674
3679
3675
- if (jcp.with_eltwise ) eltwise_injector_->prepare_table ();
3680
+ if (jcp.with_eltwise ) {
3681
+ for (auto &elt_injector : idx_to_eltwise_injector_)
3682
+ elt_injector.second .prepare_table ();
3683
+ }
3676
3684
}
3677
3685
3678
3686
bool jit_avx512_core_amx_bwd_data_kernel_t::post_ops_ok (
@@ -3690,6 +3698,12 @@ bool jit_avx512_core_amx_bwd_data_kernel_t::post_ops_ok(
3690
3698
else
3691
3699
return p.contain (sum, idx);
3692
3700
};
3701
+ // Add more element-wise post-ops supported for int8 deconv.
3702
+ bool all_eltwise = jcp.is_int8_deconvolution ;
3703
+ for (auto i = 0 ; i < p.len (); i++)
3704
+ all_eltwise &= is_eltwise (i);
3705
+ if (all_eltwise)
3706
+ return true ;
3693
3707
3694
3708
switch (p.len ()) {
3695
3709
case 0 : return true ;
@@ -3775,6 +3789,7 @@ status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
3775
3789
3776
3790
jcp = zero<decltype (jcp)>();
3777
3791
jcp.isa = is_f16 ? avx512_core_amx_fp16 : avx512_core_amx;
3792
+ jcp.is_int8_deconvolution = is_int8_deconvolution;
3778
3793
jcp.ndims = ndims;
3779
3794
jcp.prop_kind = cd.prop_kind ;
3780
3795
jcp.ngroups = with_groups ? weights_d.dims ()[0 ] : 1 ;
@@ -3862,7 +3877,7 @@ status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
3862
3877
VERBOSE_UNSUPPORTED_TAG_S, " src" );
3863
3878
3864
3879
jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
3865
- assert (IMPLICATION (is_int8_deconvolution, jcp.is_nspc ));
3880
+ assert (IMPLICATION (jcp. is_int8_deconvolution , jcp.is_nspc ));
3866
3881
3867
3882
// TODO: remove all support for nChw16c from this implementation
3868
3883
VDISPATCH_CONV_IC (jcp.is_nspc , VERBOSE_UNSUPPORTED_TAG_S, " src" );
@@ -3901,7 +3916,7 @@ status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
3901
3916
const auto &p = attr.post_ops_ ;
3902
3917
const int eltwise_ind = p.find (primitive_kind::eltwise);
3903
3918
jcp.with_eltwise = eltwise_ind != -1 ;
3904
- if ( jcp.with_eltwise ) jcp. eltwise = p. entry_ [eltwise_ind]. eltwise ;
3919
+ jcp.post_ops = p;
3905
3920
3906
3921
auto set_or_check_wei_format = [&]() {
3907
3922
using namespace format_tag ;
@@ -3914,7 +3929,7 @@ status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
3914
3929
wei_tag = pick (with_groups + 2 * (ndims - 3 ), OIw16i16o2i,
3915
3930
gOIw16i16o2i , OIhw16i16o2i, gOIhw16i16o2i , OIdhw16i16o2i,
3916
3931
gOIdhw16i16o2i );
3917
- else if (is_int8_deconvolution)
3932
+ else if (jcp. is_int8_deconvolution )
3918
3933
wei_tag = pick (with_groups + 2 * (ndims - 3 ), OIw16i16o4i,
3919
3934
gOIw16i16o4i , OIhw16i16o4i, gOIhw16i16o4i , OIdhw16i16o4i,
3920
3935
gOIdhw16i16o4i );
0 commit comments