Skip to content

Commit 9432b3d

Browse files
authored
[GPU] Bugfix reorder for byfx format (#25782)
+ Reorder returns an OOR error while handling the byfx format from a fused permute parent

### Details:
- *item1*
- *...*

### Tickets:
- CVS-147330

---------

Signed-off-by: Min, Byung-il <byungil.min@intel.com>
1 parent 606d909 commit 9432b3d

File tree

3 files changed

+59
-2
lines changed

3 files changed

+59
-2
lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/reorder_data.cl

+15-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,15 @@ KERNEL (reorder_data)(
2727
#endif
2828
)
2929
{
30+
#if INPUT0_LAYOUT_BYFX
31+
// GWS_FEATURE takes Y for byfx format
32+
const uint b = get_global_id(GWS_BATCH);
33+
const uint y = get_global_id(GWS_FEATURE);
34+
#else
3035
const uint b = get_global_id(GWS_BATCH);
3136
const uint f = get_global_id(GWS_FEATURE);
37+
#endif
38+
3239
#if INPUT0_DIMS == 2
3340
const uint y = 0;
3441
const uint x = 0;
@@ -37,8 +44,14 @@ KERNEL (reorder_data)(
3744
const uint u = 0;
3845
const uint v = 0;
3946
#elif INPUT0_DIMS == 4
40-
const uint y = ((uint)(get_global_id(GWS_YX))) / INPUT0_SIZE_X;
41-
const uint x = ((uint)(get_global_id(GWS_YX))) % INPUT0_SIZE_X;
47+
#if INPUT0_LAYOUT_BYFX
48+
// GWS_YX takes (F and X) axes for byfx format
49+
const uint f = ((uint)(get_global_id(GWS_YX))) / INPUT0_SIZE_X;
50+
const uint x = ((uint)(get_global_id(GWS_YX))) % INPUT0_SIZE_X;
51+
#else
52+
const uint y = ((uint)(get_global_id(GWS_YX))) / INPUT0_SIZE_X;
53+
const uint x = ((uint)(get_global_id(GWS_YX))) % INPUT0_SIZE_X;
54+
#endif
4255
const uint z = 0;
4356
const uint w = 0;
4457
const uint u = 0;

src/plugins/intel_gpu/src/kernel_selector/kernels/reorder/reorder_kernel_base.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,12 @@ ReorderKernelBase::DispatchData ReorderKernelBase::SetDefault(const reorder_para
189189
dispatchData.lws[0] = 1;
190190
dispatchData.lws[1] = 16;
191191
dispatchData.lws[2] = 1;
192+
} else if (input_l == DataLayout::byfx) {
193+
auto first_primary_axis_size = dispatchData.gws[0]; // X axis
194+
auto second_primiary_axis_size = dispatchData.gws[1]; // YF axes
195+
dispatchData.gws[0] = first_primary_axis_size * input.Feature().v; // takes XF axes
196+
dispatchData.gws[1] = second_primiary_axis_size / input.Feature().v; // takes Y axis
197+
dispatchData.lws = {1, 1, 1};
192198
}
193199

194200
return dispatchData;

src/plugins/intel_gpu/tests/unit/fusions/gemm_fusion_test.cpp

+38
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ class GemmFusingTest : public ::BaseFusingTest<gemm_test_params> {
142142
#define CASE_GEMM_PERMUTES_FUSION_FP16_3 { { 17, 11, 2, 18 }, { 17, 11, 18, 4 } }, { 17, 11, 2, 4 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
143143
#define CASE_GEMM_PERMUTES_FUSION_FP16_4 { { 3, 2, 10, 12 }, { 3, 2, 12, 20 } }, { 3, 2, 10, 20 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
144144
#define CASE_GEMM_PERMUTES_FUSION_FP16_5 { { 3, 2, 16, 32 }, { 3, 2, 32, 16} }, { 3, 2, 16, 16 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
145+
#define CASE_GEMM_PERMUTES_FUSION_FP16_6 { { 3, 2, 16, 32 }, { 3, 16, 2, 32} }, { 3, 2, 2, 32 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
146+
145147
class gemm_3in_quantize_i8 : public GemmFusingTest {};
146148
TEST_P(gemm_3in_quantize_i8, basic) {
147149
// TODO: Fix me, refer PR(#15873)
@@ -757,4 +759,40 @@ INSTANTIATE_TEST_SUITE_P(
757759
gemm_test_params{CASE_GEMM_PERMUTES_FUSION_FP16_3, 3, 6, "", broadcast_kinds::feature/*dummy*/, eltwise_mode::sum/*dummy*/, {{0, 2, 1, 3} /*byfx*/, {1, 2, 3, 0} /*xbfy*/, {0, 2, 1, 3} /*byfx*/}},
758760
}));
759761

762+
// Fixture alias for the parameterized test below; adds nothing on top of GemmFusingTestOneDNN.
class permute_gemm_reorder : public GemmFusingTestOneDNN {};
763+
// Builds a topology in which the output of "permute0" is consumed twice:
// once by "gemm_prim" and once by the f32 reorder "reorder_permute".
// The two branches are joined again by an eltwise sum, and the result is
// checked through the fixture's execute().
// Per the commit description, this targets a reorder that reads byfx data
// from a fused permute parent.
TEST_P(permute_gemm_reorder, fused_permute_gemm_with_reorder) {
764+
auto p = GetParam();
765+
auto in_lay0 = get_input_layout(p, 0);
766+
auto in_lay1 = get_input_layout(p, 1);
767+
// Replace each input layout's shape with the one computed by get_permute_input_shape()
// for the corresponding permute order.
// NOTE(review): presumably this is the pre-permute shape, so that the permute
// outputs match the gemm input shapes in the test parameters; confirm against the helper.
auto permute_in_lay0 = get_permute_input_shape(in_lay0.get_shape(), p.permute_orders[0]);
768+
auto permute_in_lay1 = get_permute_input_shape(in_lay1.get_shape(), p.permute_orders[1]);
769+
in_lay0.set_partial_shape(permute_in_lay0);
770+
in_lay1.set_partial_shape(permute_in_lay1);
771+
create_topologies(
772+
input_layout("input0", in_lay0),
773+
input_layout("input1", in_lay1),
774+
permute("permute0", input_info("input0"), p.permute_orders[0]),
775+
// Second consumer of "permute0": reorder to p.default_format / f32.
reorder("reorder_permute", input_info("permute0"), p.default_format, data_types::f32),
776+
permute("permute1", input_info("input1"), p.permute_orders[1]),
777+
gemm("gemm_prim", { input_info("permute0"), input_info("permute1") }, data_types::f16),
778+
reorder("reorder_bfyx", input_info("gemm_prim"), p.default_format, data_types::f32),
779+
// Sum of the reordered permute output and the reordered gemm output.
eltwise("eltwise", { input_info("reorder_permute"), input_info("reorder_bfyx") }, eltwise_mode::sum, data_types::f32)
780+
);
781+
782+
tolerance = default_tolerance(data_types::f16);
783+
// NOTE(review): the meaning of the second argument (false) is defined by the
// fixture's execute(), which is not visible here; confirm before changing it.
execute(p, false);
784+
}
785+
786+
// Parameter presets for permute_gemm_reorder. Each macro expands to the leading
// fields of a gemm_test_params initializer: a pair of shapes, a third shape,
// three f16 data types, and two (format, data type) settings using format::bfyx and f16.
// NOTE(review): the exact field each position maps to is defined by gemm_test_params,
// which is not visible here; confirm before editing the shapes.
#define CASE_PERMUTES_GEMM_FUSION_FP16_1 { { 1, 12, 20, 64 }, { 1, 12, 64, 64 } }, { 1, 12, 20, 64 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
787+
#define CASE_PERMUTES_GEMM_FUSION_FP16_2 { { 3, 2, 10, 12 }, { 3, 2, 12, 1 } }, { 3, 2, 10, 1 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
788+
789+
// Instantiates permute_gemm_reorder for both presets. The first permute order is
// always {0, 2, 1, 3}; the second is either {0, 2, 1, 3} or {1, 2, 3, 0}.
// NOTE(review): the literals 4 and 6 are presumably the expected fused / non-fused
// primitive counts used by the fusing-test fixture; verify against gemm_test_params.
INSTANTIATE_TEST_SUITE_P(
790+
fusings_gpu, permute_gemm_reorder, ::testing::ValuesIn(std::vector<gemm_test_params>{
791+
gemm_test_params{CASE_PERMUTES_GEMM_FUSION_FP16_1, 4, 6, "", broadcast_kinds::feature/*dummy*/, eltwise_mode::sum/*dummy*/, {{0, 2, 1, 3} /*byfx*/, {0, 2, 1, 3} /*byfx*/}},
792+
gemm_test_params{CASE_PERMUTES_GEMM_FUSION_FP16_1, 4, 6, "", broadcast_kinds::feature/*dummy*/, eltwise_mode::sum/*dummy*/, {{0, 2, 1, 3} /*byfx*/, {1, 2, 3, 0} /*xbfy*/}},
793+
gemm_test_params{CASE_PERMUTES_GEMM_FUSION_FP16_2, 4, 6, "", broadcast_kinds::feature/*dummy*/, eltwise_mode::sum/*dummy*/, {{0, 2, 1, 3} /*byfx*/, {0, 2, 1, 3} /*byfx*/}},
794+
gemm_test_params{CASE_PERMUTES_GEMM_FUSION_FP16_2, 4, 6, "", broadcast_kinds::feature/*dummy*/, eltwise_mode::sum/*dummy*/, {{0, 2, 1, 3} /*byfx*/, {1, 2, 3, 0} /*xbfy*/}},
795+
}));
796+
797+
760798
#endif // ENABLE_ONEDNN_FOR_GPU

0 commit comments

Comments
 (0)