Skip to content

Commit 42841f3

Browse files
[GPU] Fix gws of Resample onnx kernel
1 parent 67f2537 commit 42841f3

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

src/plugins/intel_gpu/src/kernel_selector/kernels/resample/resample_kernel_onnx.cpp

+12-9
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,14 @@ DeviceFeaturesKey ResampleKernelOnnx::get_required_device_features_key(const Par
8484
return get_common_subgroups_device_features_key(params);
8585
}
8686

87+
// Returns the vec_size factor used by the ONNX resample kernel:
// 2 when the first input uses the fs_b_yx_fsv32 layout, 1 for every other layout.
// Both SetDefault (gws[1] = Align(CeilDiv(features, vec_size), sub_group_size)) and
// GetJitConstants (FEATURE_SLICE_SIZE = 16 * vec_size) call this helper, so the
// dispatch size and the slice-size constant are derived from the same value.
static size_t get_vec_size(const resample_params &params) {
88+
if (params.inputs[0].GetLayout() == DataLayout::fs_b_yx_fsv32) {
89+
return 2;
90+
} else {
91+
return 1;
92+
}
93+
}
94+
8795
ResampleKernelBase::DispatchData ResampleKernelOnnx::SetDefault(const kernel_selector::resample_params& arg) const {
8896
DispatchData dispatchData;
8997
std::vector<std::vector<Tensor::DataChannelName>> dims_by_gws;
@@ -96,7 +104,7 @@ ResampleKernelBase::DispatchData ResampleKernelOnnx::SetDefault(const kernel_sel
96104
}
97105

98106
dispatchData.gws[0] = CeilDiv(out.X().v, opt_x_block_size) * out.Y().v * out.Z().v;
99-
dispatchData.gws[1] = Align(out.Feature().v, sub_group_size);
107+
dispatchData.gws[1] = Align(CeilDiv(out.Feature().v, get_vec_size(arg)), sub_group_size);
100108
dispatchData.gws[2] = arg.outputs[0].Batch().v;
101109

102110
dispatchData.lws[0] = 1;
@@ -151,14 +159,9 @@ JitConstants ResampleKernelOnnx::GetJitConstants(const resample_params& params)
151159
jit.AddConstant(MakeJitConstant("X_BLOCKS", CeilDiv(params.outputs[0].X().v, opt_x_block_size)));
152160
jit.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", sub_group_size));
153161

154-
size_t vec_size = 0;
155-
if (params.inputs[0].GetLayout() == DataLayout::fs_b_yx_fsv32) {
156-
vec_size = 2;
157-
jit.AddConstant(MakeJitConstant("FEATURE_SLICE_SIZE", 32));
158-
} else {
159-
vec_size = 1;
160-
jit.AddConstant(MakeJitConstant("FEATURE_SLICE_SIZE", 16));
161-
}
162+
size_t vec_size = get_vec_size(params);
163+
jit.AddConstant(MakeJitConstant("FEATURE_SLICE_SIZE", 16 * vec_size));
164+
162165
if (IsThreeSpatialResample(params))
163166
jit.AddConstant(MakeJitConstant("THREE_SPATIAL_RESAMPLE", ""));
164167

src/plugins/intel_gpu/tests/unit/test_cases/resample_gpu_test.cpp

+1
Original file line number | Diff line number | Diff line change
@@ -2371,6 +2371,7 @@ INSTANTIATE_TEST_SUITE_P(resample_opt_smoke_linear_onnx_4d_simple,
23712371
{ data_types::f16, {1, 128, 13, 13}, {1, 128, 26, 26}, 1, resample::InterpolateOp::InterpolateMode::LINEAR_ONNX, 1, format::bs_fs_yx_bsv32_fsv32, format::bs_fs_yx_bsv32_fsv32, {}, {}},
23722372
{ data_types::f16, {1, 128, 13, 13}, {1, 128, 26, 26}, 1, resample::InterpolateOp::InterpolateMode::LINEAR_ONNX, 1, format::bs_fs_yx_bsv16_fsv16, format::bs_fs_yx_bsv16_fsv16, {}, {}},
23732373
{ data_types::f16, {1, 128, 13, 13}, {1, 128, 26, 26}, 1, resample::InterpolateOp::InterpolateMode::LINEAR_ONNX, 1, format::b_fs_yx_fsv16, format::b_fs_yx_fsv32, {}, {}},
2374+
{ data_types::f16, {2, 32, 14, 14}, {2, 32, 28, 28}, 1, resample::InterpolateOp::InterpolateMode::LINEAR_ONNX, 1, format::fs_b_yx_fsv32, format::fs_b_yx_fsv32, {}, {}},
23742375
}
23752376
));
23762377

0 commit comments

Comments
 (0)