Skip to content

Commit 78d309b

Browse files
antonvorluweizhou2016
authored and committed
[FIX] added some legacy parallel methods to fix perf issues
- gemm conv im2col() - simple concat
1 parent 9182156 commit 78d309b

4 files changed

+187
-8
lines changed

src/common/dnnl_thread.hpp

+169
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,10 @@ inline int dnnl_get_current_num_threads() {
207207
#define simdlen(x)
208208
#endif // long simdlen if
209209

210+
#if defined(DNNL_ENABLE_ITT_TASKS)
211+
#include "common/ittnotify.hpp"
212+
#endif
213+
210214
namespace dnnl {
211215
namespace impl {
212216

@@ -674,6 +678,171 @@ void parallel_nd_in_omp(Args &&...args) {
674678
#endif
675679
}
676680

681+
// Legacy flavor of dnnl::impl::parallel(): invokes f(ithr, nthr) from up to
// `nthr` workers using whichever CPU threading runtime the library was built
// with (SEQ / OMP / TBB / TBB_AUTO / THREADPOOL). Reintroduced to recover
// performance for gemm-convolution im2col and simple concat.
template <typename F>
682+
void parallel_legacy(int nthr, F f) {
683+
// Cap nthr by the currently available parallelism (INT64_MAX: do not
// additionally limit by work amount).
nthr = adjust_num_threads(nthr, INT64_MAX);
684+
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ
685+
// Sequential build: there is only ever one thread.
assert(nthr == 1);
686+
f(0, 1);
687+
#else
688+
#if defined(DNNL_ENABLE_ITT_TASKS)
689+
// Capture the current ITT task kind so spawned workers can mirror the
// profiling annotation of the calling (master) thread.
auto task_primitive_kind = itt::primitive_task_get_current_kind();
690+
bool itt_enable = itt::get_itt(itt::__itt_task_level_high);
691+
#endif
692+
// Fast path: nothing to fork for a single thread.
if (nthr == 1) {
693+
f(0, 1);
694+
return;
695+
}
696+
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
697+
#pragma omp parallel num_threads(nthr)
698+
{
699+
int nthr_ = omp_get_num_threads();
700+
int ithr_ = omp_get_thread_num();
701+
assert(nthr_ == nthr);
702+
#if defined(DNNL_ENABLE_ITT_TASKS)
703+
// The master thread (ithr_ == 0) is already inside an ITT task; only
// the spawned workers are marked here.
if (ithr_ && itt_enable) itt::primitive_task_start(task_primitive_kind);
704+
#endif
705+
f(ithr_, nthr_);
706+
#if defined(DNNL_ENABLE_ITT_TASKS)
707+
if (ithr_ && itt_enable) itt::primitive_task_end();
708+
#endif
709+
}
710+
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB
711+
tbb::parallel_for(
712+
0, nthr,
713+
[&](int ithr) {
714+
#if defined(DNNL_ENABLE_ITT_TASKS)
715+
// Only annotate workers that are not already inside a primitive task.
bool mark_task = itt::primitive_task_get_current_kind()
716+
== primitive_kind::undefined;
717+
if (mark_task && itt_enable)
718+
itt::primitive_task_start(task_primitive_kind);
719+
#endif
720+
f(ithr, nthr);
721+
#if defined(DNNL_ENABLE_ITT_TASKS)
722+
if (mark_task && itt_enable) itt::primitive_task_end();
723+
#endif
724+
},
725+
// static_partitioner: fixed assignment of iterations to workers.
tbb::static_partitioner());
726+
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB_AUTO
727+
tbb::parallel_for(
728+
0, nthr, [&](int ithr) { f(ithr, nthr); });
729+
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
730+
using namespace dnnl::impl::threadpool_utils;
731+
dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool();
732+
// No active threadpool, or already inside a parallel region: run all
// iterations serially on the calling thread.
if (!tp || dnnl_in_parallel()) {
733+
threadpool_utils::deactivate_threadpool();
734+
for (int ithr = 0; ithr < nthr; ithr++) {
735+
f(ithr, nthr);
736+
}
737+
threadpool_utils::activate_threadpool(tp);
738+
} else {
739+
bool async = tp->get_flags()
740+
& dnnl::threadpool_interop::threadpool_iface::ASYNCHRONOUS;
741+
// An asynchronous threadpool returns from parallel_for() immediately,
// so a counting barrier is used to wait for all workers to finish.
counting_barrier_t b;
742+
if (async) b.init(nthr);
743+
tp->parallel_for(nthr, [&, tp](int ithr, int nthr) {
744+
bool is_master = threadpool_utils::get_active_threadpool() == tp;
745+
if (!is_master) {
746+
// Worker threads temporarily adopt the caller's threadpool (and
// ITT task, if profiling is on) for the duration of the callback.
threadpool_utils::activate_threadpool(tp);
747+
#if defined(DNNL_ENABLE_ITT_TASKS)
748+
if (itt_enable) itt::primitive_task_start(task_primitive_kind);
749+
#endif
750+
}
751+
f(ithr, nthr);
752+
if (!is_master) {
753+
#if defined(DNNL_ENABLE_ITT_TASKS)
754+
if (itt_enable) itt::primitive_task_end();
755+
#endif
756+
threadpool_utils::deactivate_threadpool();
757+
}
758+
if (async) b.notify();
759+
});
760+
if (async) b.wait();
761+
}
762+
#endif
763+
#endif
764+
}
765+
766+
// 1D work split: thread `ithr` of `nthr` walks its balance211-assigned
// contiguous sub-range of [0, D0) and calls f for each index in it.
template <typename T0, typename F>
void for_nd_legacy(const int ithr, const int nthr, const T0 &D0, F f) {
    T0 begin {0};
    T0 finish {0};
    balance211(D0, nthr, ithr, begin, finish);
    for (T0 i = begin; i < finish; ++i)
        f(i);
}
773+
774+
template <typename T0, typename T1, typename T2, typename T3, typename F>
775+
void for_nd_legacy(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
776+
const T2 &D2, const T3 &D3, F f) {
777+
const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
778+
if (work_amount == 0) return;
779+
size_t start {0}, end {0};
780+
balance211(work_amount, nthr, ithr, start, end);
781+
782+
T0 d0 {0};
783+
T1 d1 {0};
784+
T2 d2 {0};
785+
T3 d3 {0};
786+
utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3);
787+
for (size_t iwork = start; iwork < end; ++iwork) {
788+
f(d0, d1, d2, d3);
789+
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
790+
}
791+
}
792+
793+
template <typename T0, typename T1, typename T2, typename T3, typename T4,
794+
typename T5, typename F>
795+
void for_nd_legacy(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
796+
const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) {
797+
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
798+
if (work_amount == 0) return;
799+
size_t start {0}, end {0};
800+
balance211(work_amount, nthr, ithr, start, end);
801+
802+
T0 d0 {0};
803+
T1 d1 {0};
804+
T2 d2 {0};
805+
T3 d3 {0};
806+
T4 d4 {0};
807+
T5 d5 {0};
808+
utils::nd_iterator_init(
809+
start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
810+
for (size_t iwork = start; iwork < end; ++iwork) {
811+
f(d0, d1, d2, d3, d4, d5);
812+
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
813+
}
814+
}
815+
816+
template <typename T0, typename F>
817+
void parallel_nd_legacy(const T0 &D0, F f) {
818+
const size_t work_amount = (size_t)D0;
819+
int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount);
820+
if (nthr)
821+
parallel_legacy(nthr, [&](int ithr, int nthr) { for_nd_legacy(ithr, nthr, D0, f); });
822+
}
823+
824+
template <typename T0, typename T1, typename T2, typename T3, typename F>
825+
void parallel_nd_legacy(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) {
826+
const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
827+
int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount);
828+
if (nthr)
829+
parallel_legacy(nthr, [&](int ithr, int nthr) {
830+
for_nd_legacy(ithr, nthr, D0, D1, D2, D3, f);
831+
});
832+
}
833+
834+
template <typename T0, typename T1, typename T2, typename T3, typename T4,
835+
typename T5, typename F>
836+
void parallel_nd_legacy(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
837+
const T4 &D4, const T5 &D5, F f) {
838+
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
839+
int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount);
840+
if (nthr)
841+
parallel_legacy(nthr, [&](int ithr, int nthr) {
842+
for_nd_legacy(ithr, nthr, D0, D1, D2, D3, D4, D5, f);
843+
});
844+
}
845+
677846
} // namespace impl
678847
} // namespace dnnl
679848

src/cpu/gemm_convolution_utils.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr,
454454
bool with_input_zp = input_zp != nullptr;
455455

456456
if (sd == 1 && sh == 1 && sw == 1 && dd == 1 && dh == 1 && dw == 1)
457-
parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
457+
parallel_nd_legacy(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
458458
[&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) {
459459
col_dt *__restrict col_loc = col + kd * col_kd_s
460460
+ kh * col_kh_s + kw * col_kw_s + ic * col_ic_s;
@@ -484,7 +484,7 @@ void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr,
484484
}
485485
});
486486
else if (sd == 2 && sh == 2 && sw == 2 && dd == 1 && dh == 1 && dw == 1)
487-
parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
487+
parallel_nd_legacy(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
488488
[&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) {
489489
col_dt *__restrict col_loc = col + kd * col_kd_s
490490
+ kh * col_kh_s + kw * col_kw_s + ic * col_ic_s;
@@ -516,7 +516,7 @@ void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr,
516516
}
517517
});
518518
else
519-
parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
519+
parallel_nd_legacy(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
520520
[&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) {
521521
col_dt *__restrict col_loc = col + kd * col_kd_s
522522
+ kh * col_kh_s + kw * col_kw_s + ic * col_ic_s;
@@ -660,7 +660,7 @@ void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im,
660660
// Generated code is more optimized for stride_w == 1
661661
// because innermost loop is by width
662662
if (sw == 1)
663-
parallel_nd(cb, jcp.kh, jcp.kw, oh_range,
663+
parallel_nd_legacy(cb, jcp.kh, jcp.kw, oh_range,
664664
[&](dim_t ic, dim_t kh, dim_t kw, dim_t ohr) {
665665
const dim_t oh = ohr + oh_begin;
666666
const dim_t ih = oh * sh - tp + kh * dh;
@@ -685,7 +685,7 @@ void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im,
685685
}
686686
});
687687
else
688-
parallel_nd(cb, jcp.kh, jcp.kw, oh_range,
688+
parallel_nd_legacy(cb, jcp.kh, jcp.kw, oh_range,
689689
[&](dim_t ic, dim_t kh, dim_t kw, dim_t ohr) {
690690
const dim_t oh = ohr + oh_begin;
691691
const dim_t ih = oh * sh - tp + kh * dh;
@@ -840,7 +840,7 @@ void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict _im,
840840
}
841841
}
842842
} else {
843-
parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb,
843+
parallel_nd_legacy(jcp.kh, jcp.kw, jcp.ic, hb,
844844
[&](dim_t kh, dim_t kw, dim_t ic, dim_t oh) {
845845
const dim_t hp = tp - kh * dh;
846846
const dim_t ih = (oh + hs) * sh - hp;

src/cpu/gemm_x8s8s32x_convolution.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr,
218218
balance211(work_amount, nthr, ithr, start, end);
219219
nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
220220
const uint8_t shift = jcp.signed_input ? 128 : 0;
221-
parallel_nd(jcp.im2col_sz, [&](ptrdiff_t i) { col[i] = shift; });
221+
parallel_nd_legacy(jcp.im2col_sz, [&](ptrdiff_t i) { col[i] = shift; });
222222

223223
status_t st = status::success;
224224

src/cpu/simple_concat.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,16 @@ status_t simple_concat_t<data_type>::execute(const exec_ctx_t &ctx) const {
7474
// Applies when concat axis is the outermost dimension, e.g. concat_axis = 0
7575
// or concat_axis = 1, and dims[0] = 1;
7676
if (!has_outer_loop) {
77+
// @todo CPU_PLUGIN:
78+
// the following implementation was used to fix some performance issues
79+
// Now after original oneDNN re-designed this piece it seems to be not applicable
80+
// anymore
81+
// for (int a = 0; a < num_arrs; ++a) {
82+
// const data_t *i = &iptrs[a][0];
83+
// data_t *o = &optrs[a][0];
84+
// parallel_nd_legacy(nelems_to_copy[a], [&](dim_t e) { o[e] = i[e]; });
85+
// }
86+
7787
int nthr = dnnl_get_max_threads();
7888
parallel(nthr, [&](int ithr, int nthr) {
7989
for (int a = 0; a < num_arrs; ++a) {
@@ -104,7 +114,7 @@ status_t simple_concat_t<data_type>::execute(const exec_ctx_t &ctx) const {
104114
const auto L1_size = platform::get_per_core_cache_size(1);
105115
UNUSED(L1_size); // for Windows
106116

107-
parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3],
117+
parallel_nd_legacy(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3],
108118
phys_dims[4], num_arrs,
109119
[&](dim_t n0, dim_t n1, dim_t n2, dim_t n3, dim_t n4, dim_t a) {
110120
// check if zero memory

0 commit comments

Comments
 (0)