Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU] prevent too long sort in experimental detectron generate proposals single image #28422

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,17 @@ inline void FUNC(swap_box)(__global Box* a, __global Box* b) {
}

inline int FUNC(partition)(__global Box* arr, int l, int h) {
INPUT0_TYPE pivotScore = arr[h].score;
static int static_counter = 0;
static_counter++;
dnkurek marked this conversation as resolved.
Show resolved Hide resolved
int pivot_idx = l;
if (static_counter%3 == 0) { //cyclic pivot selection rotation
pivot_idx = (l+h)/2;
}
if (static_counter%3 == 1) {
pivot_idx = h;
}
INPUT0_TYPE pivotScore = arr[pivot_idx].score;
FUNC_CALL(swap_box)(&arr[h], &arr[pivot_idx]);
int i = (l - 1);
for (int j = l; j <= h - 1; j++) {
if (arr[j].score > pivotScore) {
Expand All @@ -129,7 +139,7 @@ inline void FUNC(bubbleSortIterative)(__global Box* arr, int l, int h) {
}
}

inline void FUNC(quickSortIterative)(__global Box* arr, int l, int h) {
inline void FUNC(quickSelectIterative)(__global Box* arr, int l, int h) {
// Create an auxiliary stack
const int kStackSize = 100;
int stack[kStackSize];
Expand All @@ -153,7 +163,7 @@ inline void FUNC(quickSortIterative)(__global Box* arr, int l, int h) {

// If there are elements on left side of pivot,
// then push left side to stack
if (p - 1 > l) {
if (p - 1 > l && l < PRE_NMS_TOPN) {
if (top >= (kStackSize - 1)) {
FUNC_CALL(bubbleSortIterative)(arr, l, p - 1);
} else {
Expand All @@ -164,7 +174,7 @@ inline void FUNC(quickSortIterative)(__global Box* arr, int l, int h) {

// If there are elements on right side of pivot,
// then push right side to stack
if (p + 1 < h) {
if (p + 1 < h && p + 1 < PRE_NMS_TOPN) {
if (top >= (kStackSize - 1)) {
FUNC_CALL(bubbleSortIterative)(arr, p + 1, h);
} else {
Expand All @@ -179,7 +189,7 @@ inline void FUNC(quickSortIterative)(__global Box* arr, int l, int h) {
KERNEL(edgpsi_ref_stage_1)(__global OUTPUT_TYPE* proposals) {
dnkurek marked this conversation as resolved.
Show resolved Hide resolved
__global Box* boxes = (__global Box*)proposals;

FUNC_CALL(quickSortIterative)(boxes, 0, NUM_PROPOSALS-1);
FUNC_CALL(quickSelectIterative)(boxes, 0, NUM_PROPOSALS-1);
}
#undef Box
#endif /* EDGPSI_STAGE_1 */
Expand Down
Loading