@@ -682,6 +682,112 @@ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
682
682
return status::success;
683
683
}
684
684
685
+ status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d_dw (const exec_ctx_t &ctx) const {
686
+ auto src = CTX_IN_MEM (const char *, DNNL_ARG_SRC);
687
+ auto weights = CTX_IN_MEM (const char *, DNNL_ARG_WEIGHTS);
688
+ auto bias = CTX_IN_MEM (const char *, DNNL_ARG_BIAS);
689
+ auto dst = CTX_OUT_MEM (char *, DNNL_ARG_DST);
690
+
691
+ const memory_desc_wrapper src_d (pd ()->src_md ());
692
+ const memory_desc_wrapper dst_d (pd ()->dst_md ());
693
+ const memory_desc_wrapper weights_d (pd ()->weights_md (0 ));
694
+ const memory_desc_wrapper bias_d (pd ()->weights_md (1 ));
695
+
696
+ const size_t bia_dt_size
697
+ = pd ()->with_bias () ? types::data_type_size (bias_d.data_type ()) : 0 ;
698
+ const size_t dst_dt_size = types::data_type_size (dst_d.data_type ());
699
+
700
+ const auto &jcp = pd ()->jcp_ ;
701
+ assert (jcp.ic_block == 1 );
702
+ assert (jcp.oc_block == 1 );
703
+ assert (jcp.nb_ic == 1 );
704
+ assert (jcp.nb_oc == 1 );
705
+ assert (jcp.nb_oc_blocking == 1 );
706
+ assert (jcp.nb_ch % jcp.nb_ch_blocking == 0 );
707
+
708
+ DEFINE_ARG_SCALES_BUFFER (src_scales, DNNL_ARG_SRC);
709
+ DEFINE_ARG_SCALES_BUFFER (wei_scales, DNNL_ARG_WEIGHTS);
710
+
711
+ const float *oscales = adjust_oscales (
712
+ ctx.get_scratchpad_grantor (), src_scales, wei_scales);
713
+
714
+ size_t offset = weights_d.size () - weights_d.additional_buffer_size ();
715
+ auto w = const_cast <char *>(weights);
716
+ int32_t * compensation = (jcp.signed_input ) ? reinterpret_cast <int32_t *>(&w[offset]) :
717
+ (jcp.with_input_zp ) ? pd ()->attr ()->output_compensations_ .shifts_ : 0 ;
718
+ const uint8_t * input_zp = pd ()->attr ()->input_zero_points_ .shifts_ ;
719
+ int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking ;
720
+ int group_block = jcp.ch_block ;
721
+
722
+ parallel_nd (jcp.mb , jcp.od , jcp.oh , jcp.nb_ow , nb_groups, [&](int n, int od_s, int oh_s, int owb, int gg) {
723
+ auto p = jit_conv_call_s ();
724
+
725
+ size_t src_d_stride = src_d.blk_off (0 , 0 , 1 );
726
+ size_t wht_d_stride = wht_blk_off (weights_d, 0 , 0 , 0 , 1 );
727
+
728
+ size_t src_h_stride = src_d.blk_off (0 , 0 , 0 , 1 );
729
+ size_t wht_h_stride = wht_blk_off (weights_d, 0 , 0 , 0 , 0 , 1 );
730
+
731
+ int gb = gg * jcp.nb_ch_blocking ;
732
+ int g = gb * group_block;
733
+
734
+ int id_s = -jcp.f_pad + od_s * jcp.stride_d ;
735
+
736
+ int ih_s = -jcp.t_pad + oh_s * jcp.stride_h ;
737
+ int ow_s = owb * jcp.ow_block ;
738
+ int iw_s = ow_s * jcp.stride_w ;
739
+
740
+ auto bias_w = bias ? bias + (bias_d.blk_off (g) * bia_dt_size) : 0 ;
741
+ int32_t *compensation_w = (jcp.signed_input || jcp.with_input_zp ) ? compensation + g : 0 ;
742
+
743
+ auto dst_w = dst + dst_dt_size * dst_d.blk_off (n, g, od_s, oh_s, ow_s);
744
+ auto src_w = src + src_d.blk_off (n, g, id_s, ih_s, iw_s);
745
+ auto wht_w = weights + wht_blk_off (weights_d, gb, 0 );
746
+
747
+ auto scales = &oscales[jcp.is_oc_scale * g];
748
+
749
+ int dilate_d = jcp.dilate_d + 1 ;
750
+ int i_f_overflow = nstl::min (jcp.kd , div_up (max (0 , -id_s), dilate_d));
751
+ int i_back_overflow = nstl::min (jcp.kd ,
752
+ div_up (max (0 , id_s - jcp.id + (jcp.kd - 1 ) * dilate_d + 1 ),
753
+ dilate_d));
754
+ int kd_padding = nstl::max (0 , jcp.kd - i_f_overflow - i_back_overflow);
755
+
756
+ size_t wei_d_stride = (jcp.signed_input || jcp.with_input_zp ) ? 0 : i_f_overflow * wht_d_stride;
757
+
758
+ int dilate_h = jcp.dilate_h + 1 ;
759
+ int i_t_overflow = nstl::min (jcp.kh , div_up (max (0 , -ih_s), dilate_h));
760
+ int i_b_overflow = nstl::min (jcp.kh ,
761
+ div_up (max (0 , ih_s - jcp.ih + (jcp.kh - 1 ) * dilate_h + 1 ),
762
+ dilate_h));
763
+ int kh_padding = nstl::max (0 , jcp.kh - i_t_overflow - i_b_overflow);
764
+
765
+ size_t wei_h_stride = (jcp.signed_input || jcp.with_input_zp ) ? 0 : i_t_overflow * wht_h_stride;
766
+ p.src = src_w + i_t_overflow * dilate_h * src_h_stride
767
+ + i_f_overflow * dilate_d * src_d_stride;
768
+ p.dst = dst_w;
769
+ p.filt = wht_w + wei_d_stride + wei_h_stride;
770
+ p.bias = bias_w;
771
+ p.compensation = compensation_w;
772
+ p.oc_blocks = gb;
773
+ p.kd_padding = kd_padding;
774
+ p.kh_padding = kh_padding;
775
+ p.scales = scales;
776
+ p.f_overflow = i_f_overflow;
777
+ p.back_overflow = i_back_overflow;
778
+ p.t_overflow = i_t_overflow;
779
+ p.b_overflow = i_b_overflow;
780
+ p.owb = owb;
781
+
782
+ p.oc_off = g * sizeof (float );
783
+ if (jcp.with_input_zp )
784
+ p.input_zp = input_zp + g;
785
+
786
+ (*kernel_)(&p);
787
+ });
788
+ return status::success;
789
+ }
790
+
685
791
} // namespace x64
686
792
} // namespace cpu
687
793
} // namespace impl
0 commit comments