Commit dc44f47

[gpu]: GridSample optimization (#28670)
Optimized GridSample kernel for the bilinear interpolation and zeros padding case only (for now).

### Details:
- The optimized version supports only the bfyx data layout.
- The optimized version achieves up to ~80% of peak memory bandwidth on an A770 for large enough inputs.
- The optimized version is 20-40x faster than the reference implementation on the client model referred to in this PR.

### Tickets:
- *CVS-161002*
1 parent 90dc653 commit dc44f47

8 files changed, +372 -109 lines changed
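For context, here is a minimal reference sketch (not part of the commit) of the operation this kernel optimizes: bilinear GridSample with 'zeros' padding on NCHW/bfyx-like data for a single batch. Function and parameter names are illustrative only; the denormalization and out-of-bounds masking mirror the formulas used in the OpenCL kernel below.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Grid coordinates come in (x, y) pairs in [-1, 1] and are mapped to input pixel space.
static float denormalize(float v, std::size_t range, bool align_corners) {
    return align_corners ? (v + 1.f) * (range - 1) / 2.f
                         : ((v + 1.f) * range - 1.f) / 2.f;
}

// data: C x H x W, grid: H_out x W_out x 2, out: C x H_out x W_out.
void grid_sample_bilinear_zeros(const std::vector<float>& data, std::size_t C, std::size_t H, std::size_t W,
                                const std::vector<float>& grid, std::size_t H_out, std::size_t W_out,
                                bool align_corners, std::vector<float>& out) {
    out.assign(C * H_out * W_out, 0.f);
    auto in_bounds = [](long v, long max) { return v >= 0 && v < max; };
    for (std::size_t h = 0; h < H_out; ++h) {
        for (std::size_t w = 0; w < W_out; ++w) {
            const float x = denormalize(grid[(h * W_out + w) * 2 + 0], W, align_corners);
            const float y = denormalize(grid[(h * W_out + w) * 2 + 1], H, align_corners);
            const long x0 = static_cast<long>(std::floor(x));
            const long y0 = static_cast<long>(std::floor(y));
            const float dx = x - x0;
            const float dy = y - y0;
            for (std::size_t c = 0; c < C; ++c) {
                // 'zeros' padding: out-of-bounds taps contribute 0.
                auto at = [&](long yy, long xx) {
                    return (in_bounds(yy, (long)H) && in_bounds(xx, (long)W)) ? data[(c * H + yy) * W + xx] : 0.f;
                };
                const float q0 = at(y0, x0) * (1 - dx) + at(y0, x0 + 1) * dx;
                const float q1 = at(y0 + 1, x0) * (1 - dx) + at(y0 + 1, x0 + 1) * dx;
                out[(c * H_out + h) * W_out + w] = (1 - dy) * q0 + dy * q1;
            }
        }
    }
}
```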
@@ -0,0 +1,139 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

typedef INPUT0_TYPE data_t;
typedef INPUT1_TYPE grid_t;
typedef OUTPUT_TYPE output_t;

typedef INPUT0_TYPE data_et;
typedef float grid_et;
typedef OUTPUT_TYPE output_et;

#if defined(ALIGN_CORNERS)
#    define rescale_align FUNC(denormalize)
inline grid_et rescale_align(const grid_et value, const size_t range) {
    return (value + 1) * ((grid_et)(range)-1) / 2;
}
#else
#    define rescale_noalign FUNC(denormalize)
inline grid_et rescale_noalign(const grid_et value, const size_t range) {
    return ((value + 1) * (grid_et)(range)-1) / 2;
}
#endif
#define denormalize FUNC_CALL(denormalize)

inline const bool FUNC(is_between)(int val, int min, int max) {
    return (val >= min) && (val < max);
}
#define is_between FUNC_CALL(is_between)

#define PRE_CALC_VALID_OFFSETS_FOR_INPUT_LOAD(x_n, y_n, GLOBAL_OFFSET)                                        \
    const grid_et y_d = denormalize(y_n, INPUT0_SIZE_Y);                                                      \
    const grid_et x_d = denormalize(x_n, INPUT0_SIZE_X);                                                      \
    const int y_topleft = (int)floor(y_d);                                                                    \
    const int x_topleft = (int)floor(x_d);                                                                    \
    const grid_et dy = y_d - y_topleft;                                                                       \
    const grid_et dx = x_d - x_topleft;                                                                       \
                                                                                                              \
    const bool y_topleft_valid = is_between(y_topleft, 0, INPUT0_SIZE_Y);                                     \
    const bool y_topleft_plus_valid = is_between(y_topleft + 1, 0, INPUT0_SIZE_Y);                            \
    const bool x_topleft_valid = is_between(x_topleft, 0, INPUT0_SIZE_X);                                     \
    const bool x_topleft_plus_valid = is_between(x_topleft + 1, 0, INPUT0_SIZE_X);                            \
                                                                                                              \
    const bool v00_valid = y_topleft_valid && x_topleft_valid;                                                \
    const bool v01_valid = y_topleft_valid && x_topleft_plus_valid;                                           \
    const bool v10_valid = y_topleft_plus_valid && x_topleft_valid;                                           \
    const bool v11_valid = y_topleft_plus_valid && x_topleft_plus_valid;                                      \
                                                                                                              \
    const int v00_OFFSET = v00_valid ? (GLOBAL_OFFSET + y_topleft * INPUT0_SIZE_X + x_topleft) : 0;           \
    const int v01_OFFSET = v01_valid ? (GLOBAL_OFFSET + y_topleft * INPUT0_SIZE_X + x_topleft + 1) : 0;       \
    const int v10_OFFSET = v10_valid ? (GLOBAL_OFFSET + (y_topleft + 1) * INPUT0_SIZE_X + x_topleft) : 0;     \
    const int v11_OFFSET = v11_valid ? (GLOBAL_OFFSET + (y_topleft + 1) * INPUT0_SIZE_X + x_topleft + 1) : 0;

// WARNING: These loads may read from a 'wrong' location
// (in the sense that it has nothing to do with the
// sampling point being calculated) - this is done
// intentionally to keep the warp free of syncs
// and to allow multiple such loads to be in flight - if
// the compiler is smart enough.
// Otherwise, if the load is done conditionally, software pipelining
// is hindered by a warp sync caused by warp divergence.
// Tested on an A770 GPU with OpenCL 3.0.
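// Invalid offsets are clamped to 0 so the unconditional loads below stay in
// bounds; the corresponding contributions are zeroed out later in INTERPOLATE,
// which is what implements the 'zeros' padding mode.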
#define LOAD_INPUT(c, C_STRIDE)                             \
    const data_et v00_d = data[v00_OFFSET + c * C_STRIDE];  \
    const data_et v01_d = data[v01_OFFSET + c * C_STRIDE];  \
    const data_et v10_d = data[v10_OFFSET + c * C_STRIDE];  \
    const data_et v11_d = data[v11_OFFSET + c * C_STRIDE];

#define INTERPOLATE()                                       \
    const data_et v00 = v00_valid ? v00_d * (1 - dx) : 0;   \
    const data_et v01 = v01_valid ? v01_d * dx : 0;         \
    const data_et v10 = v10_valid ? v10_d * (1 - dx) : 0;   \
    const data_et v11 = v11_valid ? v11_d * dx : 0;         \
                                                            \
    const data_et q0 = v00 + v01;                           \
    const data_et q1 = v10 + v11;                           \
    const data_et out = dy * q1 + (1 - dy) * q0;

#define STORE(c, GLOBAL_OFFSET, C_STRIDE) output[GLOBAL_OFFSET + c * C_STRIDE] = out;

// ====================================================================
//
// GRID SAMPLE KERNEL
//
// ====================================================================

KERNEL(grid_sample_opt_bilinear_zeros)(const __global data_t* restrict data,
                                       const __global grid_t* restrict grid,
                                       __global output_t* restrict output) {
#if !defined(INTERPOLATION_MODE_BILINEAR)
#    error[clDNN grid_sample_opt_bilinear.cl]: This kernel only supports the bilinear interpolation mode.
#endif

#if !defined(PADDING_MODE_ZEROS)
#    error[clDNN grid_sample_opt_bilinear.cl]: This kernel only supports the zeros padding mode.
#endif

    const int n = get_global_id(0);

    const int LOCAL_GRID_OFFSET_FOR_THIS_BLOCK = GRID_ITEMS_PER_BLOCK * 2 * get_group_id(1);
    const int OUTPUT_C_STRIDE = OUTPUT_SIZE_Y * OUTPUT_SIZE_X;
    const int GLOBAL_GRID_OFFSET_FOR_THIS_BLOCK = n * OUTPUT_C_STRIDE * 2 + LOCAL_GRID_OFFSET_FOR_THIS_BLOCK;
    const int BLOCK_SIZE = get_local_size(1);
    const grid_t* restrict grid_for_this_block = grid + GLOBAL_GRID_OFFSET_FOR_THIS_BLOCK;
    const int GRID_ITEMS_FOR_THIS_BLOCK =
        min(OUTPUT_C_STRIDE * 2 - LOCAL_GRID_OFFSET_FOR_THIS_BLOCK, GRID_ITEMS_PER_BLOCK * 2);

    const int INPUT_C_STRIDE = INPUT0_SIZE_Y * INPUT0_SIZE_X;
    const int GLOBAL_INPUT_OFFSET_THIS_THREAD = n * INPUT0_FEATURE_NUM * INPUT_C_STRIDE;

    // The basic idea is to cache and reuse grid values to get close to
    // the optimal number of loads (and stores).
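    // Each work item reads a single (x, y) grid pair and reuses it for every
    // feature channel in the inner loop below, so grid loads scale with H * W
    // rather than with C * H * W.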
    for (int thisThreadHW = get_local_linear_id() * 2; thisThreadHW < GRID_ITEMS_FOR_THIS_BLOCK;
         thisThreadHW += 2 * BLOCK_SIZE) {
        const int globalThisThreadHW = (thisThreadHW + LOCAL_GRID_OFFSET_FOR_THIS_BLOCK) / 2;
        const int h = globalThisThreadHW / OUTPUT_SIZE_X;
        const int w = globalThisThreadHW % OUTPUT_SIZE_X;
        const int GLOBAL_OUTPUT_OFFSET_THIS_THREAD =
            n * OUTPUT_FEATURE_NUM * OUTPUT_SIZE_Y * OUTPUT_SIZE_X + h * OUTPUT_SIZE_X + w;

        const grid_et x_n = grid_for_this_block[thisThreadHW];
        const grid_et y_n = grid_for_this_block[thisThreadHW + 1];

        PRE_CALC_VALID_OFFSETS_FOR_INPUT_LOAD(x_n, y_n, GLOBAL_INPUT_OFFSET_THIS_THREAD);

#pragma unroll
        for (int c = 0; c < OUTPUT_FEATURE_NUM; ++c) {
            LOAD_INPUT(c, INPUT_C_STRIDE);
            INTERPOLATE();
            STORE(c, GLOBAL_OUTPUT_OFFSET_THIS_THREAD, OUTPUT_C_STRIDE);
        }
    }
}

#undef denormalize
#undef STORE
#undef INTERPOLATE
#undef PRE_CALC_VALID_OFFSETS_FOR_INPUT_LOAD
#undef LOAD_INPUT
@@ -0,0 +1,87 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "grid_sample_kernel_base.hpp"

#include "kernel_selector_utils.h"

namespace kernel_selector {

KernelsData GridSampleKernelBase::GetKernelsData(const Params& params) const {
    if (!Validate(params)) {
        return {};
    }

    auto kernel_data = KernelData::Default<grid_sample_params>(params);
    const auto& kernel_params = dynamic_cast<const grid_sample_params&>(*kernel_data.params);
    const auto dispatch_data = CalcDispatch(kernel_params);
    const auto entry_point = GetEntryPoint(kernelName, kernel_params.layerID, params);
    const auto jit_constants = GetJitConstants(kernel_params);
    const auto jit = CreateJit(kernelName, jit_constants, entry_point);
    auto& kernel = kernel_data.kernels.front();

    FillCLKernelData(kernel, dispatch_data, params.engineInfo, kernelName, jit, entry_point, {}, false, false, 2);

    return {kernel_data};
}

bool GridSampleKernelBase::Validate(const Params& params) const {
    if (params.GetType() != KernelType::GRID_SAMPLE) {
        return false;
    }

    const auto& kernel_params = dynamic_cast<const grid_sample_params&>(params);
    if (kernel_params.inputs.size() != 2) {
        return false;
    }

    return true;
}

JitConstants GridSampleKernelBase::GetJitConstants(const grid_sample_params& kernel_params) const {
    auto jit_constants = MakeBaseParamsJitConstants(kernel_params);

    jit_constants.AddConstants({
        MakeJitConstant("INTERPOLATION_MODE_" + ov::as_string(kernel_params.interpolation_mode), true),
        MakeJitConstant("PADDING_MODE_" + ov::as_string(kernel_params.padding_mode), true),
    });

    if (kernel_params.align_corners) {
        jit_constants.AddConstant(MakeJitConstant("ALIGN_CORNERS", true));
    }

    return jit_constants;
}

}  // namespace kernel_selector

namespace ov {

template <>
ov::EnumNames<kernel_selector::grid_sample_params::InterpolationMode>& ::ov::EnumNames<
    kernel_selector::grid_sample_params::InterpolationMode>::get() {
    static auto enum_names = EnumNames<kernel_selector::grid_sample_params::InterpolationMode>(
        "kernel_selector::grid_sample_params::InterpolationMode",
        {
            {"BILINEAR", kernel_selector::grid_sample_params::InterpolationMode::BILINEAR},
            {"BICUBIC", kernel_selector::grid_sample_params::InterpolationMode::BICUBIC},
            {"NEAREST", kernel_selector::grid_sample_params::InterpolationMode::NEAREST},
        });
    return enum_names;
}

template <>
ov::EnumNames<kernel_selector::grid_sample_params::PaddingMode>& ::ov::EnumNames<
    kernel_selector::grid_sample_params::PaddingMode>::get() {
    static auto enum_names = EnumNames<kernel_selector::grid_sample_params::PaddingMode>(
        "kernel_selector::grid_sample_params::PaddingMode",
        {
            {"ZEROS", kernel_selector::grid_sample_params::PaddingMode::ZEROS},
            {"BORDER", kernel_selector::grid_sample_params::PaddingMode::BORDER},
            {"REFLECTION", kernel_selector::grid_sample_params::PaddingMode::REFLECTION},
        });
    return enum_names;
}

}  // namespace ov
@@ -0,0 +1,45 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "kernel_base_opencl.h"

namespace kernel_selector {

/**
 * GridSample reference kernel parameters.
 */
struct grid_sample_params : public base_params {
    grid_sample_params() : base_params(KernelType::GRID_SAMPLE) {}
    bool align_corners = false;
    enum class InterpolationMode {
        BILINEAR,
        BICUBIC,
        NEAREST,
    } interpolation_mode = InterpolationMode::BILINEAR;
    enum class PaddingMode {
        ZEROS,
        BORDER,
        REFLECTION,
    } padding_mode = PaddingMode::ZEROS;
};

/**
 * GridSampleKernelBase.
 */
class GridSampleKernelBase : public KernelBaseOpenCL {
public:
    using KernelBaseOpenCL::KernelBaseOpenCL;

    KernelsData GetKernelsData(const Params& params) const override;

protected:
    virtual CommonDispatchData CalcDispatch(const grid_sample_params& kernel_params) const = 0;
    virtual JitConstants GetJitConstants(const grid_sample_params& kernel_params) const;

    bool Validate(const Params& params) const override;
};

}  // namespace kernel_selector
@@ -0,0 +1,67 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "grid_sample_kernel_opt_bilinear_zeros.hpp"

#include "kernel_selector_utils.h"

namespace kernel_selector {

constexpr size_t THREADS_PER_BLOCK = 256;
constexpr size_t GRID_ITEMS_PER_BLOCK = THREADS_PER_BLOCK;

CommonDispatchData GridSampleKernelOpt_BilinearZeros::CalcDispatch(const grid_sample_params& kernel_params) const {
    CommonDispatchData dispatch_data;
    const auto& output = kernel_params.outputs.front();

    auto blocks = (output.Y().v * output.X().v + GRID_ITEMS_PER_BLOCK - 1) / GRID_ITEMS_PER_BLOCK;

    dispatch_data.gws = {output.Batch().v, blocks * THREADS_PER_BLOCK, 1};
    dispatch_data.lws = {1, THREADS_PER_BLOCK, 1};

    return dispatch_data;
}

KernelsPriority GridSampleKernelOpt_BilinearZeros::GetKernelsPriority(const Params& /*params*/) const {
    return FORCE_PRIORITY_8;
}

bool GridSampleKernelOpt_BilinearZeros::Validate(const Params& params) const {
    if (!TBase::Validate(params))
        return false;

    const auto& kernel_params = static_cast<const grid_sample_params&>(params);
    if (kernel_params.interpolation_mode != grid_sample_params::InterpolationMode::BILINEAR)
        return false;

    if (kernel_params.padding_mode != grid_sample_params::PaddingMode::ZEROS)
        return false;

    return true;
}

JitConstants GridSampleKernelOpt_BilinearZeros::GetJitConstants(const grid_sample_params& kernel_params) const {
    auto jit_constants = TBase::GetJitConstants(kernel_params);

    jit_constants.AddConstants({
        MakeJitConstant("GRID_ITEMS_PER_BLOCK", GRID_ITEMS_PER_BLOCK)
    });

    return jit_constants;
}

ParamsKey GridSampleKernelOpt_BilinearZeros::GetSupportedKey() const {
    ParamsKey key;
    key.EnableAllInputDataType();
    key.EnableAllOutputDataType();
    key.EnableDifferentTypes();
    key.EnableInputLayout(DataLayout::bfyx);
    key.EnableOutputLayout(DataLayout::bfyx);
    key.EnableTensorOffset();
    key.EnableTensorPitches();
    key.EnableBatching();
    return key;
}

}  // namespace kernel_selector
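As a side note (not part of the commit), the work distribution chosen by CalcDispatch above can be illustrated with a small, self-contained example; the output shape used here is hypothetical and only serves to make the arithmetic concrete.

```cpp
#include <cstddef>
#include <iostream>

int main() {
    // Hypothetical output tensor: batch = 2, H = 100, W = 60 (channels are looped inside the kernel).
    const std::size_t batch = 2, out_y = 100, out_x = 60;

    const std::size_t threads_per_block = 256;     // THREADS_PER_BLOCK
    const std::size_t grid_items_per_block = 256;  // GRID_ITEMS_PER_BLOCK

    // Each work group handles up to 256 (x, y) sampling points of the H*W output plane.
    const std::size_t blocks =
        (out_y * out_x + grid_items_per_block - 1) / grid_items_per_block;  // ceil(6000 / 256) = 24

    // One work group per (batch, block) pair; gws = {batch, blocks * THREADS_PER_BLOCK, 1}.
    std::cout << "gws = {" << batch << ", " << blocks * threads_per_block << ", 1}\n"  // {2, 6144, 1}
              << "lws = {1, " << threads_per_block << ", 1}\n";                        // {1, 256, 1}
}
```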
@@ -0,0 +1,24 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "grid_sample_kernel_base.hpp"

namespace kernel_selector {

class GridSampleKernelOpt_BilinearZeros : public GridSampleKernelBase {
public:
    using TBase = GridSampleKernelBase;
    GridSampleKernelOpt_BilinearZeros() : GridSampleKernelBase("grid_sample_opt_bilinear_zeros") {}

protected:
    ParamsKey GetSupportedKey() const override;
    CommonDispatchData CalcDispatch(const grid_sample_params& kernel_params) const override;
    KernelsPriority GetKernelsPriority(const Params& /*params*/) const override;
    bool Validate(const Params& params) const override;
    JitConstants GetJitConstants(const grid_sample_params& kernel_params) const override;
};

}  // namespace kernel_selector
