diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/experimental_detectron_generate_proposals_single_image_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/experimental_detectron_generate_proposals_single_image_ref.cl index aa9cd2d3c387e6..5ea7c3be62e0df 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/experimental_detectron_generate_proposals_single_image_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/experimental_detectron_generate_proposals_single_image_ref.cl @@ -102,7 +102,17 @@ inline void FUNC(swap_box)(__global Box* a, __global Box* b) { } inline int FUNC(partition)(__global Box* arr, int l, int h) { - INPUT0_TYPE pivotScore = arr[h].score; + static int static_counter = 0; + static_counter++; + int pivot_idx = l; + if (static_counter%3 == 0) { //cyclic pivot selection rotation + pivot_idx = (l+h)/2; + } + if (static_counter%3 == 1) { + pivot_idx = h; + } + INPUT0_TYPE pivotScore = arr[pivot_idx].score; + FUNC_CALL(swap_box)(&arr[h], &arr[pivot_idx]); int i = (l - 1); for (int j = l; j <= h - 1; j++) { if (arr[j].score > pivotScore) { @@ -129,7 +139,7 @@ inline void FUNC(bubbleSortIterative)(__global Box* arr, int l, int h) { } } -inline void FUNC(quickSortIterative)(__global Box* arr, int l, int h) { +inline void FUNC(quickSelectIterative)(__global Box* arr, int l, int h) { // Create an auxiliary stack const int kStackSize = 100; int stack[kStackSize]; @@ -153,7 +163,7 @@ inline void FUNC(quickSortIterative)(__global Box* arr, int l, int h) { // If there are elements on left side of pivot, // then push left side to stack - if (p - 1 > l) { + if (p - 1 > l && l < PRE_NMS_TOPN) { if (top >= (kStackSize - 1)) { FUNC_CALL(bubbleSortIterative)(arr, l, p - 1); } else { @@ -164,7 +174,7 @@ inline void FUNC(quickSortIterative)(__global Box* arr, int l, int h) { // If there are elements on right side of pivot, // then push right side to stack - if (p + 1 < h) { + if (p + 1 < h && p + 1 < PRE_NMS_TOPN) { if (top >= (kStackSize - 1)) { FUNC_CALL(bubbleSortIterative)(arr, p + 1, h); } else { @@ -179,7 +189,7 @@ inline void FUNC(quickSortIterative)(__global Box* arr, int l, int h) { KERNEL(edgpsi_ref_stage_1)(__global OUTPUT_TYPE* proposals) { __global Box* boxes = (__global Box*)proposals; - FUNC_CALL(quickSortIterative)(boxes, 0, NUM_PROPOSALS-1); + FUNC_CALL(quickSelectIterative)(boxes, 0, NUM_PROPOSALS-1); } #undef Box #endif /* EDGPSI_STAGE_1 */