@@ -207,6 +207,10 @@ inline int dnnl_get_current_num_threads() {
207
207
#define simdlen (x )
208
208
#endif // long simdlen if
209
209
210
+ #if defined(DNNL_ENABLE_ITT_TASKS)
211
+ #include " common/ittnotify.hpp"
212
+ #endif
213
+
210
214
namespace dnnl {
211
215
namespace impl {
212
216
@@ -674,6 +678,171 @@ void parallel_nd_in_omp(Args &&...args) {
674
678
#endif
675
679
}
676
680
681
+ template <typename F>
682
+ void parallel_legacy (int nthr, F f) {
683
+ nthr = adjust_num_threads (nthr, INT64_MAX);
684
+ #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ
685
+ assert (nthr == 1 );
686
+ f (0 , 1 );
687
+ #else
688
+ #if defined(DNNL_ENABLE_ITT_TASKS)
689
+ auto task_primitive_kind = itt::primitive_task_get_current_kind ();
690
+ bool itt_enable = itt::get_itt (itt::__itt_task_level_high);
691
+ #endif
692
+ if (nthr == 1 ) {
693
+ f (0 , 1 );
694
+ return ;
695
+ }
696
+ #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
697
+ #pragma omp parallel num_threads(nthr)
698
+ {
699
+ int nthr_ = omp_get_num_threads ();
700
+ int ithr_ = omp_get_thread_num ();
701
+ assert (nthr_ == nthr);
702
+ #if defined(DNNL_ENABLE_ITT_TASKS)
703
+ if (ithr_ && itt_enable) itt::primitive_task_start (task_primitive_kind);
704
+ #endif
705
+ f (ithr_, nthr_);
706
+ #if defined(DNNL_ENABLE_ITT_TASKS)
707
+ if (ithr_ && itt_enable) itt::primitive_task_end ();
708
+ #endif
709
+ }
710
+ #elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB
711
+ tbb::parallel_for (
712
+ 0 , nthr,
713
+ [&](int ithr) {
714
+ #if defined(DNNL_ENABLE_ITT_TASKS)
715
+ bool mark_task = itt::primitive_task_get_current_kind ()
716
+ == primitive_kind::undefined;
717
+ if (mark_task && itt_enable)
718
+ itt::primitive_task_start (task_primitive_kind);
719
+ #endif
720
+ f (ithr, nthr);
721
+ #if defined(DNNL_ENABLE_ITT_TASKS)
722
+ if (mark_task && itt_enable) itt::primitive_task_end ();
723
+ #endif
724
+ },
725
+ tbb::static_partitioner ());
726
+ #elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB_AUTO
727
+ tbb::parallel_for (
728
+ 0 , nthr, [&](int ithr) { f (ithr, nthr); });
729
+ #elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
730
+ using namespace dnnl ::impl::threadpool_utils;
731
+ dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool ();
732
+ if (!tp || dnnl_in_parallel ()) {
733
+ threadpool_utils::deactivate_threadpool ();
734
+ for (int ithr = 0 ; ithr < nthr; ithr++) {
735
+ f (ithr, nthr);
736
+ }
737
+ threadpool_utils::activate_threadpool (tp);
738
+ } else {
739
+ bool async = tp->get_flags ()
740
+ & dnnl::threadpool_interop::threadpool_iface::ASYNCHRONOUS;
741
+ counting_barrier_t b;
742
+ if (async) b.init (nthr);
743
+ tp->parallel_for (nthr, [&, tp](int ithr, int nthr) {
744
+ bool is_master = threadpool_utils::get_active_threadpool () == tp;
745
+ if (!is_master) {
746
+ threadpool_utils::activate_threadpool (tp);
747
+ #if defined(DNNL_ENABLE_ITT_TASKS)
748
+ if (itt_enable) itt::primitive_task_start (task_primitive_kind);
749
+ #endif
750
+ }
751
+ f (ithr, nthr);
752
+ if (!is_master) {
753
+ #if defined(DNNL_ENABLE_ITT_TASKS)
754
+ if (itt_enable) itt::primitive_task_end ();
755
+ #endif
756
+ threadpool_utils::deactivate_threadpool ();
757
+ }
758
+ if (async) b.notify ();
759
+ });
760
+ if (async) b.wait ();
761
+ }
762
+ #endif
763
+ #endif
764
+ }
765
+
766
// One worker's share of a 1-D iteration space.
//
// ithr, nthr  This worker's index and the total worker count; forwarded to
//             balance211 together with D0 to obtain the half-open range
//             [first, last) handled here.
// D0          Extent of the iteration space.
// f           Callable invoked as f(index) for each index in the range, in
//             increasing order.
template <typename T0, typename F>
void for_nd_legacy(const int ithr, const int nthr, const T0 &D0, F f) {
    T0 first {0};
    T0 last {0};
    balance211(D0, nthr, ithr, first, last);

    T0 idx = first;
    while (idx < last) {
        f(idx);
        ++idx;
    }
}
773
+
774
+ template <typename T0, typename T1, typename T2, typename T3, typename F>
775
+ void for_nd_legacy (const int ithr, const int nthr, const T0 &D0, const T1 &D1,
776
+ const T2 &D2, const T3 &D3, F f) {
777
+ const size_t work_amount = (size_t )D0 * D1 * D2 * D3;
778
+ if (work_amount == 0 ) return ;
779
+ size_t start {0 }, end {0 };
780
+ balance211 (work_amount, nthr, ithr, start, end);
781
+
782
+ T0 d0 {0 };
783
+ T1 d1 {0 };
784
+ T2 d2 {0 };
785
+ T3 d3 {0 };
786
+ utils::nd_iterator_init (start, d0, D0, d1, D1, d2, D2, d3, D3);
787
+ for (size_t iwork = start; iwork < end; ++iwork) {
788
+ f (d0, d1, d2, d3);
789
+ utils::nd_iterator_step (d0, D0, d1, D1, d2, D2, d3, D3);
790
+ }
791
+ }
792
+
793
+ template <typename T0, typename T1, typename T2, typename T3, typename T4,
794
+ typename T5, typename F>
795
+ void for_nd_legacy (const int ithr, const int nthr, const T0 &D0, const T1 &D1,
796
+ const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) {
797
+ const size_t work_amount = (size_t )D0 * D1 * D2 * D3 * D4 * D5;
798
+ if (work_amount == 0 ) return ;
799
+ size_t start {0 }, end {0 };
800
+ balance211 (work_amount, nthr, ithr, start, end);
801
+
802
+ T0 d0 {0 };
803
+ T1 d1 {0 };
804
+ T2 d2 {0 };
805
+ T3 d3 {0 };
806
+ T4 d4 {0 };
807
+ T5 d5 {0 };
808
+ utils::nd_iterator_init (
809
+ start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
810
+ for (size_t iwork = start; iwork < end; ++iwork) {
811
+ f (d0, d1, d2, d3, d4, d5);
812
+ utils::nd_iterator_step (d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
813
+ }
814
+ }
815
+
816
+ template <typename T0, typename F>
817
+ void parallel_nd_legacy (const T0 &D0, F f) {
818
+ const size_t work_amount = (size_t )D0;
819
+ int nthr = adjust_num_threads (dnnl_get_current_num_threads (), work_amount);
820
+ if (nthr)
821
+ parallel_legacy (nthr, [&](int ithr, int nthr) { for_nd_legacy (ithr, nthr, D0, f); });
822
+ }
823
+
824
+ template <typename T0, typename T1, typename T2, typename T3, typename F>
825
+ void parallel_nd_legacy (const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) {
826
+ const size_t work_amount = (size_t )D0 * D1 * D2 * D3;
827
+ int nthr = adjust_num_threads (dnnl_get_current_num_threads (), work_amount);
828
+ if (nthr)
829
+ parallel_legacy (nthr, [&](int ithr, int nthr) {
830
+ for_nd_legacy (ithr, nthr, D0, D1, D2, D3, f);
831
+ });
832
+ }
833
+
834
+ template <typename T0, typename T1, typename T2, typename T3, typename T4,
835
+ typename T5, typename F>
836
+ void parallel_nd_legacy (const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
837
+ const T4 &D4, const T5 &D5, F f) {
838
+ const size_t work_amount = (size_t )D0 * D1 * D2 * D3 * D4 * D5;
839
+ int nthr = adjust_num_threads (dnnl_get_current_num_threads (), work_amount);
840
+ if (nthr)
841
+ parallel_legacy (nthr, [&](int ithr, int nthr) {
842
+ for_nd_legacy (ithr, nthr, D0, D1, D2, D3, D4, D5, f);
843
+ });
844
+ }
845
+
677
846
} // namespace impl
678
847
} // namespace dnnl
679
848
0 commit comments