@@ -502,7 +502,8 @@ void subarray_sum(size_t num_arrs, float *output, size_t nelems,
502
502
const size_t blocks_number = nelems / block_size;
503
503
const size_t tail = nelems % block_size;
504
504
505
- PRAGMA_OMP (parallel) {
505
+ PRAGMA_OMP (parallel)
506
+ {
506
507
const int ithr = dnnl_get_thread_num ();
507
508
const int nthr = dnnl_get_num_threads ();
508
509
size_t start {0 }, end {0 };
@@ -583,7 +584,8 @@ void array_sum(size_t num_arrs, float *output, size_t nelems,
583
584
const size_t blocks_number = nelems / block_size;
584
585
const size_t tail = nelems % block_size;
585
586
586
- PRAGMA_OMP (parallel) {
587
+ PRAGMA_OMP (parallel)
588
+ {
587
589
const size_t ithr = dnnl_get_thread_num ();
588
590
const size_t nthr = dnnl_get_num_threads ();
589
591
size_t start {0 }, end {0 };
@@ -672,7 +674,8 @@ void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t::
672
674
0 .179271708683473f , 0 .403361344537815f , 1 .13777777777778f };
673
675
float G_O_3x3_4x4[4 ] = {2 .25f , 0 .625f , 1 .5f , 0 .390625f };
674
676
675
- PRAGMA_OMP (parallel num_threads (nthreads) firstprivate (trans_ker_p, I, T)) {
677
+ PRAGMA_OMP (parallel num_threads (nthreads) firstprivate (trans_ker_p, I, T))
678
+ {
676
679
if (jcp.with_bias ) {
677
680
parallel_nd_in_omp (
678
681
nthreads, jcp.oc / simd_w, [&](int ithr, int ofm) {
@@ -687,69 +690,76 @@ void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t::
687
690
int ithr = dnnl_get_thread_num ();
688
691
for (int ifm1 = 0 ; ifm1 < jcp.nb_ic ; ++ifm1) {
689
692
int first_tblk = 0 ;
690
- PRAGMA_OMP (for )
691
- for (int tblk1 = 0 ; tblk1 < jcp.tile_block ; ++tblk1) {
692
- int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur ;
693
- int img = tile_index / (jcp.itiles * jcp.jtiles );
694
- trans_ker_p.ti = tile_index % jcp.itiles ;
695
- trans_ker_p.tj = (tile_index / jcp.itiles ) % jcp.jtiles ;
696
- trans_ker_p.M = I;
697
- trans_ker_p.T = T;
698
- trans_ker_p.G = G_I_3x3_4x4;
699
- for (int ifm2 = 0 ; ifm2 < jcp.ic_block ; ++ifm2) {
700
- int ifm = ifm1 * jcp.ic_block + ifm2;
701
- trans_ker_p.src = (float *)&(src (img, ifm, 0 , 0 , 0 ));
702
- trans_ker_p.dst = (float *)&(V (ithr, 0 , 0 , ifm2, 0 , 0 , 0 ));
703
- kernel_->src_transform (&trans_ker_p);
704
- }
705
-
706
- for (int ofm1 = 0 ; ofm1 < jcp.nb_oc ; ++ofm1) {
707
- trans_ker_p.G = G_W_3x3_4x4;
708
- for (int ofm2 = 0 ; ofm2 < jcp.oc_block ; ++ofm2) {
709
- int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block ;
710
- trans_ker_p.src = (float *)&(diff_dst (img, ofm, 0 , 0 , 0 ));
711
- trans_ker_p.dst = (float *)&(M (ithr, 0 , 0 , ofm2, 0 , 0 , 0 , 0 ));
712
- if (jcp.with_bias && ifm1 == 0 ) {
713
- trans_ker_p.bias
714
- = (float *)&(diff_bias_prv (ithr, ofm * simd_w));
715
- kernel_->diff_dst_transform_wbias (&trans_ker_p);
716
- } else {
717
- kernel_->diff_dst_transform (&trans_ker_p);
718
- }
719
- }
720
-
721
- for (int oj = 0 ; oj < alpha; ++oj) {
722
- for (int oi = 0 ; oi < alpha; ++oi) {
723
- kernel_->gemm_loop_ker_first_iter (
724
- &(Us (ithr, oj, oi, 0 , 0 , 0 , 0 , 0 )),
725
- &(M (ithr, oj, oi, 0 , 0 , 0 , 0 , 0 )),
726
- &(V (ithr, oj, oi, 0 , 0 , 0 , 0 )));
727
- }
728
- }
729
- trans_ker_p.G = G_O_3x3_4x4;
730
- for (int ofm2 = 0 ; ofm2 < jcp.oc_block ; ++ofm2) {
731
- for (int ofm3 = 0 ; ofm3 < jcp.oc_reg_block ; ++ofm3) {
732
- int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block
733
- + ofm3;
693
+ PRAGMA_OMP (for )
694
+ for (int tblk1 = 0 ; tblk1 < jcp.tile_block ; ++tblk1) {
695
+ int tile_index
696
+ = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur ;
697
+ int img = tile_index / (jcp.itiles * jcp.jtiles );
698
+ trans_ker_p.ti = tile_index % jcp.itiles ;
699
+ trans_ker_p.tj = (tile_index / jcp.itiles ) % jcp.jtiles ;
700
+ trans_ker_p.M = I;
701
+ trans_ker_p.T = T;
702
+ trans_ker_p.G = G_I_3x3_4x4;
734
703
for (int ifm2 = 0 ; ifm2 < jcp.ic_block ; ++ifm2) {
735
704
int ifm = ifm1 * jcp.ic_block + ifm2;
736
- trans_ker_p.src = (float *)&(
737
- Us (ithr, 0 , 0 , ofm2, ifm2, 0 , ofm3, 0 ));
738
- trans_ker_p.dst = (float *)&(
739
- diff_weights_prv (ithr, ofm, ifm, 0 , 0 , 0 , 0 ));
740
- if (first_tblk == 0 ) {
741
- kernel_->diff_weights_transform (&trans_ker_p);
742
- } else {
743
- kernel_->diff_weights_transform_accum (&trans_ker_p);
705
+ trans_ker_p.src = (float *)&(src (img, ifm, 0 , 0 , 0 ));
706
+ trans_ker_p.dst = (float *)&(V (ithr, 0 , 0 , ifm2, 0 , 0 , 0 ));
707
+ kernel_->src_transform (&trans_ker_p);
708
+ }
709
+
710
+ for (int ofm1 = 0 ; ofm1 < jcp.nb_oc ; ++ofm1) {
711
+ trans_ker_p.G = G_W_3x3_4x4;
712
+ for (int ofm2 = 0 ; ofm2 < jcp.oc_block ; ++ofm2) {
713
+ int ofm = (ofm1 * jcp.oc_block + ofm2)
714
+ * jcp.oc_reg_block ;
715
+ trans_ker_p.src
716
+ = (float *)&(diff_dst (img, ofm, 0 , 0 , 0 ));
717
+ trans_ker_p.dst
718
+ = (float *)&(M (ithr, 0 , 0 , ofm2, 0 , 0 , 0 , 0 ));
719
+ if (jcp.with_bias && ifm1 == 0 ) {
720
+ trans_ker_p.bias = (float *)&(
721
+ diff_bias_prv (ithr, ofm * simd_w));
722
+ kernel_->diff_dst_transform_wbias (&trans_ker_p);
723
+ } else {
724
+ kernel_->diff_dst_transform (&trans_ker_p);
725
+ }
726
+ }
727
+
728
+ for (int oj = 0 ; oj < alpha; ++oj) {
729
+ for (int oi = 0 ; oi < alpha; ++oi) {
730
+ kernel_->gemm_loop_ker_first_iter (
731
+ &(Us (ithr, oj, oi, 0 , 0 , 0 , 0 , 0 )),
732
+ &(M (ithr, oj, oi, 0 , 0 , 0 , 0 , 0 )),
733
+ &(V (ithr, oj, oi, 0 , 0 , 0 , 0 )));
734
+ }
735
+ }
736
+ trans_ker_p.G = G_O_3x3_4x4;
737
+ for (int ofm2 = 0 ; ofm2 < jcp.oc_block ; ++ofm2) {
738
+ for (int ofm3 = 0 ; ofm3 < jcp.oc_reg_block ; ++ofm3) {
739
+ int ofm = (ofm1 * jcp.oc_block + ofm2)
740
+ * jcp.oc_reg_block
741
+ + ofm3;
742
+ for (int ifm2 = 0 ; ifm2 < jcp.ic_block ; ++ifm2) {
743
+ int ifm = ifm1 * jcp.ic_block + ifm2;
744
+ trans_ker_p.src = (float *)&(
745
+ Us (ithr, 0 , 0 , ofm2, ifm2, 0 , ofm3, 0 ));
746
+ trans_ker_p.dst = (float *)&(diff_weights_prv (
747
+ ithr, ofm, ifm, 0 , 0 , 0 , 0 ));
748
+ if (first_tblk == 0 ) {
749
+ kernel_->diff_weights_transform (
750
+ &trans_ker_p);
751
+ } else {
752
+ kernel_->diff_weights_transform_accum (
753
+ &trans_ker_p);
754
+ }
755
+ }
756
+ }
744
757
}
745
758
}
759
+ ++first_tblk;
746
760
}
747
761
}
748
762
}
749
- ++first_tblk;
750
- }
751
- }
752
- }
753
763
754
764
// Reduce diff-weights
755
765
{
@@ -826,7 +836,8 @@ void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t::
826
836
float I[alpha][alpha][simd_w];
827
837
float T[alpha][alpha][simd_w];
828
838
829
- PRAGMA_OMP (parallel firstprivate (first_tblk, trans_ker_p, I, T)) {
839
+ PRAGMA_OMP (parallel firstprivate (first_tblk, trans_ker_p, I, T))
840
+ {
830
841
if (jcp.with_bias ) {
831
842
parallel_nd_in_omp (nthreads, jcp.oc , [&](int ithr, int ofm) {
832
843
diff_bias_prv (ithr, ofm) = 0 .0f ;
@@ -923,7 +934,8 @@ void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t::
923
934
}
924
935
925
936
trans_ker_p.G = G_O_3x3_4x4;
926
- PRAGMA_OMP (parallel firstprivate (trans_ker_p)) {
937
+ PRAGMA_OMP (parallel firstprivate (trans_ker_p))
938
+ {
927
939
parallel_nd_in_omp (jcp.nb_ic , jcp.nb_oc , jcp.oc_block , jcp.ic_block ,
928
940
jcp.oc_reg_block ,
929
941
[&](int ifm1, int ofm1, int ofm2, int ifm2, int ofm3) {
0 commit comments