Skip to content

Commit 10d508a

Browse files
antonvorxczhai
authored and committed
[FIX] added some legacy parallel methods to fix perf issues
- gemm conv im2col() - simple concat
1 parent ee7f37c commit 10d508a

4 files changed

+187
-8
lines changed

src/common/dnnl_thread.hpp

+169
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,10 @@ inline int dnnl_get_current_num_threads() {
207207
#define simdlen(x)
208208
#endif // long simdlen if
209209

210+
#if defined(DNNL_ENABLE_ITT_TASKS)
211+
#include "common/ittnotify.hpp"
212+
#endif
213+
210214
namespace dnnl {
211215
namespace impl {
212216

@@ -655,6 +659,171 @@ static inline void parallel_nd(dim_t D0, dim_t D1, dim_t D2, dim_t D3, dim_t D4,
655659
});
656660
}
657661

662+
// Legacy parallel driver restored from older oneDNN: invokes f(ithr, nthr)
// on up to `nthr` workers using whichever CPU threading runtime the library
// was built with (SEQ, OMP, TBB, TBB_AUTO, or THREADPOOL).
// @param nthr requested number of threads; clamped by adjust_num_threads().
// @param f    callable taking (int ithr, int nthr); must be thread-safe.
template <typename F>
void parallel_legacy(int nthr, F f) {
    // INT64_MAX work amount => cap only by what the runtime can provide.
    nthr = adjust_num_threads(nthr, INT64_MAX);
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ
    assert(nthr == 1);
    f(0, 1);
#else
#if defined(DNNL_ENABLE_ITT_TASKS)
    // Capture the current primitive kind so worker threads can report the
    // same ITT task as the master.
    auto task_primitive_kind = itt::primitive_task_get_current_kind();
    bool itt_enable = itt::get_itt(itt::__itt_task_level_high);
#endif
    // Single-thread fast path: no parallel region, no ITT bracketing needed.
    if (nthr == 1) {
        f(0, 1);
        return;
    }
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
#pragma omp parallel num_threads(nthr)
    {
        int nthr_ = omp_get_num_threads();
        int ithr_ = omp_get_thread_num();
        assert(nthr_ == nthr);
#if defined(DNNL_ENABLE_ITT_TASKS)
        // Only non-master threads (ithr_ != 0) open an ITT task; the master
        // already runs inside one.
        if (ithr_ && itt_enable) itt::primitive_task_start(task_primitive_kind);
#endif
        f(ithr_, nthr_);
#if defined(DNNL_ENABLE_ITT_TASKS)
        if (ithr_ && itt_enable) itt::primitive_task_end();
#endif
    }
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB
    tbb::parallel_for(
            0, nthr,
            [&](int ithr) {
#if defined(DNNL_ENABLE_ITT_TASKS)
                // TBB may run this lambda on the master thread too; only mark
                // a task where none is currently active.
                bool mark_task = itt::primitive_task_get_current_kind()
                        == primitive_kind::undefined;
                if (mark_task && itt_enable)
                    itt::primitive_task_start(task_primitive_kind);
#endif
                f(ithr, nthr);
#if defined(DNNL_ENABLE_ITT_TASKS)
                if (mark_task && itt_enable) itt::primitive_task_end();
#endif
            },
            // static partitioner => one contiguous chunk per worker.
            tbb::static_partitioner());
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB_AUTO
    tbb::parallel_for(
            0, nthr, [&](int ithr) { f(ithr, nthr); });
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
    using namespace dnnl::impl::threadpool_utils;
    dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool();
    // No attached threadpool, or we are already inside a parallel section:
    // run all iterations serially on this thread (avoids nested parallelism).
    if (!tp || dnnl_in_parallel()) {
        threadpool_utils::deactivate_threadpool();
        for (int ithr = 0; ithr < nthr; ithr++) {
            f(ithr, nthr);
        }
        // NOTE(review): tp may be nullptr here; presumably
        // activate_threadpool(nullptr) is a no-op — confirm.
        threadpool_utils::activate_threadpool(tp);
    } else {
        bool async = tp->get_flags()
                & dnnl::threadpool_interop::threadpool_iface::ASYNCHRONOUS;
        counting_barrier_t b;
        if (async) b.init(nthr);
        tp->parallel_for(nthr, [&, tp](int ithr, int nthr) {
            // Worker threads have no active threadpool set; the master does.
            bool is_master = threadpool_utils::get_active_threadpool() == tp;
            if (!is_master) {
                threadpool_utils::activate_threadpool(tp);
#if defined(DNNL_ENABLE_ITT_TASKS)
                if (itt_enable) itt::primitive_task_start(task_primitive_kind);
#endif
            }
            f(ithr, nthr);
            if (!is_master) {
#if defined(DNNL_ENABLE_ITT_TASKS)
                if (itt_enable) itt::primitive_task_end();
#endif
                threadpool_utils::deactivate_threadpool();
            }
            if (async) b.notify();
        });
        // An asynchronous threadpool returns before the work completes;
        // wait on the counting barrier for all nthr notifications.
        if (async) b.wait();
    }
#endif
#endif
}
746+
747+
// Runs f(d0) over this thread's contiguous share of [0, D0), as assigned by
// the static balance211 partitioning for (ithr, nthr).
template <typename T0, typename F>
void for_nd_legacy(const int ithr, const int nthr, const T0 &D0, F f) {
    T0 first {0};
    T0 last {0};
    balance211(D0, nthr, ithr, first, last);
    for (T0 idx = first; idx < last; ++idx)
        f(idx);
}
754+
755+
template <typename T0, typename T1, typename T2, typename T3, typename F>
756+
void for_nd_legacy(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
757+
const T2 &D2, const T3 &D3, F f) {
758+
const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
759+
if (work_amount == 0) return;
760+
size_t start {0}, end {0};
761+
balance211(work_amount, nthr, ithr, start, end);
762+
763+
T0 d0 {0};
764+
T1 d1 {0};
765+
T2 d2 {0};
766+
T3 d3 {0};
767+
utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3);
768+
for (size_t iwork = start; iwork < end; ++iwork) {
769+
f(d0, d1, d2, d3);
770+
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
771+
}
772+
}
773+
774+
template <typename T0, typename T1, typename T2, typename T3, typename T4,
775+
typename T5, typename F>
776+
void for_nd_legacy(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
777+
const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) {
778+
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
779+
if (work_amount == 0) return;
780+
size_t start {0}, end {0};
781+
balance211(work_amount, nthr, ithr, start, end);
782+
783+
T0 d0 {0};
784+
T1 d1 {0};
785+
T2 d2 {0};
786+
T3 d3 {0};
787+
T4 d4 {0};
788+
T5 d5 {0};
789+
utils::nd_iterator_init(
790+
start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
791+
for (size_t iwork = start; iwork < end; ++iwork) {
792+
f(d0, d1, d2, d3, d4, d5);
793+
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
794+
}
795+
}
796+
797+
template <typename T0, typename F>
798+
void parallel_nd_legacy(const T0 &D0, F f) {
799+
const size_t work_amount = (size_t)D0;
800+
int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount);
801+
if (nthr)
802+
parallel_legacy(nthr, [&](int ithr, int nthr) { for_nd_legacy(ithr, nthr, D0, f); });
803+
}
804+
805+
template <typename T0, typename T1, typename T2, typename T3, typename F>
806+
void parallel_nd_legacy(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) {
807+
const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
808+
int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount);
809+
if (nthr)
810+
parallel_legacy(nthr, [&](int ithr, int nthr) {
811+
for_nd_legacy(ithr, nthr, D0, D1, D2, D3, f);
812+
});
813+
}
814+
815+
template <typename T0, typename T1, typename T2, typename T3, typename T4,
816+
typename T5, typename F>
817+
void parallel_nd_legacy(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
818+
const T4 &D4, const T5 &D5, F f) {
819+
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
820+
int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount);
821+
if (nthr)
822+
parallel_legacy(nthr, [&](int ithr, int nthr) {
823+
for_nd_legacy(ithr, nthr, D0, D1, D2, D3, D4, D5, f);
824+
});
825+
}
826+
658827
} // namespace impl
659828
} // namespace dnnl
660829

src/cpu/gemm_convolution_utils.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr,
454454
bool with_input_zp = input_zp != nullptr;
455455

456456
if (sd == 1 && sh == 1 && sw == 1 && dd == 1 && dh == 1 && dw == 1)
457-
parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
457+
parallel_nd_legacy(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
458458
[&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) {
459459
col_dt *__restrict col_loc = col + kd * col_kd_s
460460
+ kh * col_kh_s + kw * col_kw_s + ic * col_ic_s;
@@ -484,7 +484,7 @@ void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr,
484484
}
485485
});
486486
else if (sd == 2 && sh == 2 && sw == 2 && dd == 1 && dh == 1 && dw == 1)
487-
parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
487+
parallel_nd_legacy(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
488488
[&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) {
489489
col_dt *__restrict col_loc = col + kd * col_kd_s
490490
+ kh * col_kh_s + kw * col_kw_s + ic * col_ic_s;
@@ -516,7 +516,7 @@ void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr,
516516
}
517517
});
518518
else
519-
parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
519+
parallel_nd_legacy(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
520520
[&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) {
521521
col_dt *__restrict col_loc = col + kd * col_kd_s
522522
+ kh * col_kh_s + kw * col_kw_s + ic * col_ic_s;
@@ -660,7 +660,7 @@ void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im,
660660
// Generated code is more optimized for stride_w == 1
661661
// because innermost loop is by width
662662
if (sw == 1)
663-
parallel_nd(cb, jcp.kh, jcp.kw, oh_range,
663+
parallel_nd_legacy(cb, jcp.kh, jcp.kw, oh_range,
664664
[&](dim_t ic, dim_t kh, dim_t kw, dim_t ohr) {
665665
const dim_t oh = ohr + oh_begin;
666666
const dim_t ih = oh * sh - tp + kh * dh;
@@ -685,7 +685,7 @@ void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im,
685685
}
686686
});
687687
else
688-
parallel_nd(cb, jcp.kh, jcp.kw, oh_range,
688+
parallel_nd_legacy(cb, jcp.kh, jcp.kw, oh_range,
689689
[&](dim_t ic, dim_t kh, dim_t kw, dim_t ohr) {
690690
const dim_t oh = ohr + oh_begin;
691691
const dim_t ih = oh * sh - tp + kh * dh;
@@ -840,7 +840,7 @@ void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict _im,
840840
}
841841
}
842842
} else {
843-
parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb,
843+
parallel_nd_legacy(jcp.kh, jcp.kw, jcp.ic, hb,
844844
[&](dim_t kh, dim_t kw, dim_t ic, dim_t oh) {
845845
const dim_t hp = tp - kh * dh;
846846
const dim_t ih = (oh + hs) * sh - hp;

src/cpu/gemm_x8s8s32x_convolution.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr,
220220
balance211(work_amount, nthr, ithr, start, end);
221221
nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
222222
const uint8_t shift = jcp.signed_input ? 128 : 0;
223-
parallel_nd(jcp.im2col_sz, [&](ptrdiff_t i) { col[i] = shift; });
223+
parallel_nd_legacy(jcp.im2col_sz, [&](ptrdiff_t i) { col[i] = shift; });
224224

225225
status_t st = status::success;
226226

src/cpu/simple_concat.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,16 @@ status_t simple_concat_t<data_type>::execute(const exec_ctx_t &ctx) const {
7474
// Applies when concat axis is the outermost dimension, e.g. concat_axis = 0
7575
// or concat_axis = 1, and dims[0] = 1;
7676
if (!has_outer_loop) {
77+
// @todo CPU_PLUGIN:
78+
// the following implementation was used to fix some performance issues
79+
// Now that upstream oneDNN has re-designed this piece, it seems to be not applicable
80+
// anymore
81+
// for (int a = 0; a < num_arrs; ++a) {
82+
// const data_t *i = &iptrs[a][0];
83+
// data_t *o = &optrs[a][0];
84+
// parallel_nd_legacy(nelems_to_copy[a], [&](dim_t e) { o[e] = i[e]; });
85+
// }
86+
7787
int nthr = dnnl_get_max_threads();
7888
parallel(nthr, [&](int ithr, int nthr) {
7989
for (int a = 0; a < num_arrs; ++a) {
@@ -104,7 +114,7 @@ status_t simple_concat_t<data_type>::execute(const exec_ctx_t &ctx) const {
104114
const auto L1_size = platform::get_per_core_cache_size(1);
105115
UNUSED(L1_size); // for Windows
106116

107-
parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3],
117+
parallel_nd_legacy(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3],
108118
phys_dims[4], num_arrs,
109119
[&](dim_t n0, dim_t n1, dim_t n2, dim_t n3, dim_t n4, dim_t a) {
110120
// check if zero memory

0 commit comments

Comments
 (0)