Skip to content

Commit 9aec1a7

Browse files
[GPU] Fix accuracy of gemm_tiled_opt kernel
1 parent f2bd4d3 commit 9aec1a7

File tree

2 files changed

+43
-27
lines changed

2 files changed

+43
-27
lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_tiled_opt.cl

+37-21
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,15 @@ KERNEL(gemm_tiled_opt)(
309309
else
310310
#endif // INDIRECT_INPUT1
311311
{
312-
#if N_IS_ALIGNED_4BYTE
313-
b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
314-
#else
312+
// #if N_IS_ALIGNED_4BYTE
313+
// b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
314+
// #else
315+
// b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
316+
// #endif
317+
#if TILE_N_NOT_DIVISIBLE
315318
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
319+
#else
320+
b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
316321
#endif
317322
b_ptr += input1_offset;
318323
}
@@ -395,11 +400,16 @@ KERNEL(gemm_tiled_opt)(
395400
#if INDIRECT_INPUT0
396401
uint a_idx = FUNC_CALL(get_input0_indirect_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, (y + dot_id), (k * TILE_K + sglid), beam_table);
397402
A_FLOATN a_read = input0[a_idx];
398-
#elif K_IS_ALIGNED_4BYTE
399-
A_FLOATN a_read = BLOCK_READ_A(a_ptr, 0);
400-
#else // K_IS_ALIGNED_4BYTE
403+
// #elif K_IS_ALIGNED_4BYTE
404+
// A_FLOATN a_read = BLOCK_READ_A(a_ptr, 0);
405+
// #else // K_IS_ALIGNED_4BYTE
406+
// A_FLOATN a_read = a_ptr[sglid];
407+
// #endif // K_IS_ALIGNED_4BYTE
408+
#elif TILE_K_NOT_DIVISIBLE
401409
A_FLOATN a_read = a_ptr[sglid];
402-
#endif // K_IS_ALIGNED_4BYTE
410+
#else // TILE_K_NOT_DIVISIBLE
411+
A_FLOATN a_read = BLOCK_READ_A(a_ptr, 0);
412+
#endif // TILE_K_NOT_DIVISIBLE
403413
#endif // IS_DYNAMIC
404414
a_ptr += input0_offset;
405415

@@ -617,11 +627,16 @@ KERNEL(gemm_tiled_opt)(
617627
else
618628
#endif
619629
{
620-
#if N_IS_ALIGNED_4BYTE
621-
b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
622-
#else // N_IS_ALIGNED_4BYTE
630+
// #if N_IS_ALIGNED_4BYTE
631+
// b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
632+
// #else // N_IS_ALIGNED_4BYTE
633+
// b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
634+
// #endif // N_IS_ALIGNED_4BYTE
635+
#if TILE_N_NOT_DIVISIBLE
623636
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
624-
#endif // N_IS_ALIGNED_4BYTE
637+
#else // TILE_N_NOT_DIVISIBLE
638+
b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
639+
#endif // TILE_N_NOT_DIVISIBLE
625640
b_ptr += input1_offset;
626641
}
627642
#elif TRANSPOSE_INPUT1 == TRANSPOSE_OTHER // TRANSPOSE_INPUT1 == 0
@@ -660,23 +675,24 @@ KERNEL(gemm_tiled_opt)(
660675
}
661676
#endif // TRANSPOSE_INPUT1 == TRANSPOSE_Y_LAST
662677

663-
#if !INDIRECT_INPUT0 && K_IS_ALIGNED_4BYTE && (TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST)
664-
a_ptr = input0 + FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, (K_FULL_ITERATIONS * TILE_K));
665-
#endif
678+
// #if !INDIRECT_INPUT0 && K_IS_ALIGNED_4BYTE && (TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST)
679+
// a_ptr = input0 + FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, (K_FULL_ITERATIONS * TILE_K));
680+
// #endif
666681
// Loading leftovers of the matrix A and tile C calculation
667682
unroll_for (uint dot_id = 0; dot_id < tile_m_iterations; dot_id++) {
668683
#if INDIRECT_INPUT0
669684
uint a_idx = FUNC_CALL(get_input0_indirect_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, (y + dot_id), (K_FULL_ITERATIONS * TILE_K + sglid), beam_table);
670-
INPUT0_TYPE a_read = input0[a_idx];
671-
#else // INDIRECT_INPUT0
672-
#if K_IS_ALIGNED_4BYTE && (TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST)
673-
INPUT0_TYPE a_read = BLOCK_READ_A(a_ptr, 0);
674-
a_ptr += input0_offset;
685+
// INPUT0_TYPE a_read = input0[a_idx];
686+
// #else // INDIRECT_INPUT0
687+
// #if K_IS_ALIGNED_4BYTE && (TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST)
688+
// INPUT0_TYPE a_read = BLOCK_READ_A(a_ptr, 0);
689+
// a_ptr += input0_offset;
675690
#else
676691
uint a_idx = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, (y + dot_id), (K_FULL_ITERATIONS * TILE_K + sglid));
692+
#endif //--kelvin
677693
INPUT0_TYPE a_read = input0[a_idx];
678-
#endif
679-
#endif // INDIRECT_INPUT0
694+
// #endif
695+
// #endif // INDIRECT_INPUT0
680696
unroll_for (uint simd_id = 0; simd_id < TILE_K_LEFTOVER; simd_id++) {
681697
c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_read, simd_id)), b_tile[simd_id], c_tile[dot_id]);
682698
}

src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_tiled_opt.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ GemmKernelTiledOpt::GemmTuningData GemmKernelTiledOpt::SetTuningParams(const gem
9696
tuning_data.tile_m_size = tuning_data.simd_size;
9797
}
9898
// Increasing tile_n_size has performance improvement when m_size and n_size are not shallow and n_size is aligned at 32.
99-
if (m_size >= 128 && n_size >= 128 && (n_size % 32 == 0) && tuning_data.simd_size == 16 && params.fused_ops.empty())
100-
tuning_data.tile_n_size = 32;
99+
// if (m_size >= 128 && n_size >= 128 && (n_size % 32 == 0) && tuning_data.simd_size == 16 && params.fused_ops.empty())
100+
// tuning_data.tile_n_size = 32;
101101

102102
GPU_DEBUG_LOG << params.layerID << ": m_size: " << m_size << ", n_size: " << n_size << ", k_size: " << k_size << std::endl;
103103
} else {
@@ -239,17 +239,17 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
239239
auto leftover_m = m_size % tuning_data.tile_m_size;
240240
auto leftover_n = n_size % tuning_data.tile_n_size;
241241
auto leftover_k = k_size % tuning_data.tile_k_size;
242-
auto n_aligned_4byte = (n_size * BytesPerElement(params.inputs[0].GetDType())) % 4 == 0;
243-
auto k_aligned_4byte = (k_size * BytesPerElement(params.inputs[0].GetDType())) % 4 == 0;
242+
// auto n_aligned_4byte = (n_size * BytesPerElement(params.inputs[0].GetDType())) % 4 == 0;
243+
// auto k_aligned_4byte = (k_size * BytesPerElement(params.inputs[0].GetDType())) % 4 == 0;
244244

245245
jit.AddConstants({
246246
MakeJitConstant("M", m_size),
247247
MakeJitConstant("K", k_size),
248248
MakeJitConstant("N", n_size),
249249
MakeJitConstant("K_PADDED_IN0", k_size),
250250
MakeJitConstant("N_PADDED", n_size),
251-
MakeJitConstant("K_IS_ALIGNED_4BYTE", k_aligned_4byte),
252-
MakeJitConstant("N_IS_ALIGNED_4BYTE", n_aligned_4byte),
251+
// MakeJitConstant("K_IS_ALIGNED_4BYTE", k_aligned_4byte),
252+
// MakeJitConstant("N_IS_ALIGNED_4BYTE", n_aligned_4byte),
253253
MakeJitConstant("SIMD_WIDTH", tuning_data.simd_size),
254254
MakeJitConstant("TILE_M", tuning_data.tile_m_size),
255255
MakeJitConstant("TILE_K", tuning_data.tile_k_size),

0 commit comments

Comments
 (0)