@@ -207,6 +207,10 @@ inline int dnnl_get_current_num_threads() {
207
207
#define simdlen (x )
208
208
#endif // long simdlen if
209
209
210
+ #if defined(DNNL_ENABLE_ITT_TASKS)
211
+ #include " common/ittnotify.hpp"
212
+ #endif
213
+
210
214
namespace dnnl {
211
215
namespace impl {
212
216
@@ -655,6 +659,171 @@ static inline void parallel_nd(dim_t D0, dim_t D1, dim_t D2, dim_t D3, dim_t D4,
655
659
});
656
660
}
657
661
662
+ template <typename F>
663
+ void parallel_legacy (int nthr, F f) {
664
+ nthr = adjust_num_threads (nthr, INT64_MAX);
665
+ #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ
666
+ assert (nthr == 1 );
667
+ f (0 , 1 );
668
+ #else
669
+ #if defined(DNNL_ENABLE_ITT_TASKS)
670
+ auto task_primitive_kind = itt::primitive_task_get_current_kind ();
671
+ bool itt_enable = itt::get_itt (itt::__itt_task_level_high);
672
+ #endif
673
+ if (nthr == 1 ) {
674
+ f (0 , 1 );
675
+ return ;
676
+ }
677
+ #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
678
+ #pragma omp parallel num_threads(nthr)
679
+ {
680
+ int nthr_ = omp_get_num_threads ();
681
+ int ithr_ = omp_get_thread_num ();
682
+ assert (nthr_ == nthr);
683
+ #if defined(DNNL_ENABLE_ITT_TASKS)
684
+ if (ithr_ && itt_enable) itt::primitive_task_start (task_primitive_kind);
685
+ #endif
686
+ f (ithr_, nthr_);
687
+ #if defined(DNNL_ENABLE_ITT_TASKS)
688
+ if (ithr_ && itt_enable) itt::primitive_task_end ();
689
+ #endif
690
+ }
691
+ #elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB
692
+ tbb::parallel_for (
693
+ 0 , nthr,
694
+ [&](int ithr) {
695
+ #if defined(DNNL_ENABLE_ITT_TASKS)
696
+ bool mark_task = itt::primitive_task_get_current_kind ()
697
+ == primitive_kind::undefined;
698
+ if (mark_task && itt_enable)
699
+ itt::primitive_task_start (task_primitive_kind);
700
+ #endif
701
+ f (ithr, nthr);
702
+ #if defined(DNNL_ENABLE_ITT_TASKS)
703
+ if (mark_task && itt_enable) itt::primitive_task_end ();
704
+ #endif
705
+ },
706
+ tbb::static_partitioner ());
707
+ #elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB_AUTO
708
+ tbb::parallel_for (
709
+ 0 , nthr, [&](int ithr) { f (ithr, nthr); });
710
+ #elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
711
+ using namespace dnnl ::impl::threadpool_utils;
712
+ dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool ();
713
+ if (!tp || dnnl_in_parallel ()) {
714
+ threadpool_utils::deactivate_threadpool ();
715
+ for (int ithr = 0 ; ithr < nthr; ithr++) {
716
+ f (ithr, nthr);
717
+ }
718
+ threadpool_utils::activate_threadpool (tp);
719
+ } else {
720
+ bool async = tp->get_flags ()
721
+ & dnnl::threadpool_interop::threadpool_iface::ASYNCHRONOUS;
722
+ counting_barrier_t b;
723
+ if (async) b.init (nthr);
724
+ tp->parallel_for (nthr, [&, tp](int ithr, int nthr) {
725
+ bool is_master = threadpool_utils::get_active_threadpool () == tp;
726
+ if (!is_master) {
727
+ threadpool_utils::activate_threadpool (tp);
728
+ #if defined(DNNL_ENABLE_ITT_TASKS)
729
+ if (itt_enable) itt::primitive_task_start (task_primitive_kind);
730
+ #endif
731
+ }
732
+ f (ithr, nthr);
733
+ if (!is_master) {
734
+ #if defined(DNNL_ENABLE_ITT_TASKS)
735
+ if (itt_enable) itt::primitive_task_end ();
736
+ #endif
737
+ threadpool_utils::deactivate_threadpool ();
738
+ }
739
+ if (async) b.notify ();
740
+ });
741
+ if (async) b.wait ();
742
+ }
743
+ #endif
744
+ #endif
745
+ }
746
+
747
/// Runs `f(d0)` over the slice of a 1D index space assigned to one thread.
///
/// @param ithr Index of the calling thread.
/// @param nthr Total number of threads sharing the work.
/// @param D0   Extent of the index space.
/// @param f    Callable invoked as `f(d0)` for each index in the slice.
///
/// The slice bounds are whatever balance211(D0, nthr, ithr, ...) writes into
/// its last two arguments; they are iterated as a half-open range.
template <typename T0, typename F>
void for_nd_legacy(const int ithr, const int nthr, const T0 &D0, F f) {
    T0 slice_begin {0};
    T0 slice_end {0};
    balance211(D0, nthr, ithr, slice_begin, slice_end);
    for (T0 idx = slice_begin; idx < slice_end; ++idx) {
        f(idx);
    }
}
754
+
755
+ template <typename T0, typename T1, typename T2, typename T3, typename F>
756
+ void for_nd_legacy (const int ithr, const int nthr, const T0 &D0, const T1 &D1,
757
+ const T2 &D2, const T3 &D3, F f) {
758
+ const size_t work_amount = (size_t )D0 * D1 * D2 * D3;
759
+ if (work_amount == 0 ) return ;
760
+ size_t start {0 }, end {0 };
761
+ balance211 (work_amount, nthr, ithr, start, end);
762
+
763
+ T0 d0 {0 };
764
+ T1 d1 {0 };
765
+ T2 d2 {0 };
766
+ T3 d3 {0 };
767
+ utils::nd_iterator_init (start, d0, D0, d1, D1, d2, D2, d3, D3);
768
+ for (size_t iwork = start; iwork < end; ++iwork) {
769
+ f (d0, d1, d2, d3);
770
+ utils::nd_iterator_step (d0, D0, d1, D1, d2, D2, d3, D3);
771
+ }
772
+ }
773
+
774
+ template <typename T0, typename T1, typename T2, typename T3, typename T4,
775
+ typename T5, typename F>
776
+ void for_nd_legacy (const int ithr, const int nthr, const T0 &D0, const T1 &D1,
777
+ const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) {
778
+ const size_t work_amount = (size_t )D0 * D1 * D2 * D3 * D4 * D5;
779
+ if (work_amount == 0 ) return ;
780
+ size_t start {0 }, end {0 };
781
+ balance211 (work_amount, nthr, ithr, start, end);
782
+
783
+ T0 d0 {0 };
784
+ T1 d1 {0 };
785
+ T2 d2 {0 };
786
+ T3 d3 {0 };
787
+ T4 d4 {0 };
788
+ T5 d5 {0 };
789
+ utils::nd_iterator_init (
790
+ start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
791
+ for (size_t iwork = start; iwork < end; ++iwork) {
792
+ f (d0, d1, d2, d3, d4, d5);
793
+ utils::nd_iterator_step (d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
794
+ }
795
+ }
796
+
797
+ template <typename T0, typename F>
798
+ void parallel_nd_legacy (const T0 &D0, F f) {
799
+ const size_t work_amount = (size_t )D0;
800
+ int nthr = adjust_num_threads (dnnl_get_current_num_threads (), work_amount);
801
+ if (nthr)
802
+ parallel_legacy (nthr, [&](int ithr, int nthr) { for_nd_legacy (ithr, nthr, D0, f); });
803
+ }
804
+
805
+ template <typename T0, typename T1, typename T2, typename T3, typename F>
806
+ void parallel_nd_legacy (const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) {
807
+ const size_t work_amount = (size_t )D0 * D1 * D2 * D3;
808
+ int nthr = adjust_num_threads (dnnl_get_current_num_threads (), work_amount);
809
+ if (nthr)
810
+ parallel_legacy (nthr, [&](int ithr, int nthr) {
811
+ for_nd_legacy (ithr, nthr, D0, D1, D2, D3, f);
812
+ });
813
+ }
814
+
815
+ template <typename T0, typename T1, typename T2, typename T3, typename T4,
816
+ typename T5, typename F>
817
+ void parallel_nd_legacy (const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
818
+ const T4 &D4, const T5 &D5, F f) {
819
+ const size_t work_amount = (size_t )D0 * D1 * D2 * D3 * D4 * D5;
820
+ int nthr = adjust_num_threads (dnnl_get_current_num_threads (), work_amount);
821
+ if (nthr)
822
+ parallel_legacy (nthr, [&](int ithr, int nthr) {
823
+ for_nd_legacy (ithr, nthr, D0, D1, D2, D3, D4, D5, f);
824
+ });
825
+ }
826
+
658
827
} // namespace impl
659
828
} // namespace dnnl
660
829
0 commit comments