@@ -467,6 +467,8 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_thr_nspc(
     // threads share work across mini-batch and groups
     const dim_t work_amount = jcp.ngroups * jcp.mb;
 
+    const auto &p = pd()->attr()->post_ops_;
+
     data_t *__restrict col = scratchpad.get<data_t>(key_conv_gemm_col)
             + (ptrdiff_t)ithr * jcp.im2col_sz;
     const bool acc_needed = jcp.ngroups > 1;
@@ -515,6 +517,25 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_thr_nspc(
                 }
             });
         }
+        if (p.len() > 0) {
+            int depthwise_inj_idx = 0;
+            for (int i = 0; i < p.len(); i++) {
+                auto &post_op = p.entry_[i];
+                if (post_op.is_depthwise()) {
+                    auto depthwise_weights = post_op.depthwise.weights_data;
+                    auto depthwise_bias = post_op.depthwise.biases_data;
+                    parallel_nd(static_cast<size_t>(jcp.is) * jcp.id, [&](size_t is) {
+                        data_t *__restrict diff_src_arr
+                                = diff_src + is * diff_src_os_stride;
+                        for (int ic = 0; ic < jcp.ic; ic++) {
+                            diff_src_arr[ic] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(diff_src_arr[ic],
+                                    depthwise_weights + g * jcp.ic + ic, depthwise_bias + g * jcp.ic + ic);
+                        }
+                    });
+                    depthwise_inj_idx++;
+                }
+            }
+        }
         nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
     }
     return status::success;
@@ -547,6 +568,8 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp(
     const dim_t work_amount = (size_t)jcp.ngroups * jcp.mb;
     const bool is_problem_3d = pd()->ndims() == 5;
 
+    const auto &p = pd()->attr()->post_ops_;
+
     std::atomic<status_t> st(status::success);
     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
         data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
@@ -594,6 +617,26 @@ status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp(
                     }
                 }
             }
+            if (p.len() > 0) {
+                int depthwise_inj_idx = 0;
+                for (int i = 0; i < p.len(); i++) {
+                    auto &post_op = p.entry_[i];
+                    if (post_op.is_depthwise()) {
+                        auto depthwise_weights = post_op.depthwise.weights_data;
+                        auto depthwise_bias = post_op.depthwise.biases_data;
+                        parallel_nd(jcp.ic, [&](const int ic) {
+                            for (int id = 0; id < jcp.id; ++id) {
+                                data_t *d_ = _diff_src + ic * jcp.id * jcp.is + id * jcp.is;
+                                for (int iS = 0; iS < jcp.is; ++iS) {
+                                    d_[iS] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(d_[iS],
+                                            depthwise_weights + g * jcp.ic + ic, depthwise_bias + g * jcp.ic + ic);
+                                }
+                            }
+                        });
+                        depthwise_inj_idx++;
+                    }
+                }
+            }
             nd_iterator_step(g, jcp.ngroups, n, jcp.mb);
         }
     });
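
For context, a minimal standalone sketch of what the injected depthwise post-op amounts to per element, assuming compute_scalar() for a depthwise scale-shift post-op reduces to x * weights[c] + bias[c]; the helper name and signature below are illustrative, not part of the library:

#include <cstddef>

using data_t = float;

// Hypothetical helper: applies the assumed per-channel scale/shift that the
// depthwise injector performs on each diff_src element of group g, mirroring
// the channel-strided loop added in the ncsp path above.
static void apply_depthwise_scale_shift(data_t *diff_src, int ic_count,
        std::ptrdiff_t channel_stride, std::ptrdiff_t spatial,
        const data_t *weights, const data_t *bias, int g) {
    for (int ic = 0; ic < ic_count; ++ic) {
        const data_t w = weights[g * ic_count + ic];
        const data_t b = bias[g * ic_count + ic];
        data_t *d = diff_src + ic * channel_stride;
        for (std::ptrdiff_t s = 0; s < spatial; ++s)
            d[s] = d[s] * w + b; // assumed semantics of compute_scalar()
    }
}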