Skip to content

Commit 7ba14f6

Browse files
committed
gpu:cuda: Fix matmul parameters for inner_product usages
1 parent 7e450f8 commit 7ba14f6

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

src/gpu/nvidia/cudnn_matmul_executor.hpp

+4 -4
Original file line number | Diff line number | Diff line change
@@ -392,12 +392,12 @@ struct cudnn_matmul_lt_exec_t final : public cudnn_matmul_lt_base_exec_t {
392392
memory_tracking::names::key_matmul_dst_in_acc_dt)
393393
: xpu::sycl::interop_memory_arg_t<
394394
::sycl::access::mode::read_write>();
395-
auto arg_block_a_scratch = params->source_size_ != 0
395+
auto arg_block_a_scratch = params->weight_size_ != 0
396396
? CTX_SCRATCH_SYCL_MEMORY(
397397
memory_tracking::names::key_gemm_blocked_a)
398398
: xpu::sycl::interop_memory_arg_t<
399399
::sycl::access::mode::read_write>();
400-
auto arg_block_b_scratch = params->weight_size_ != 0
400+
auto arg_block_b_scratch = params->source_size_ != 0
401401
? CTX_SCRATCH_SYCL_MEMORY(
402402
memory_tracking::names::key_gemm_blocked_b)
403403
: xpu::sycl::interop_memory_arg_t<
@@ -457,10 +457,10 @@ struct cudnn_matmul_lt_runtime_args_exec_t final
457457
matmul_params->reorder_scratch_size_, cuda_stream->queue());
458458

459459
uint8_t *block_a_scratch_ptr
460-
= alloc_ptr(matmul_params->source_size_, cuda_stream->queue());
460+
= alloc_ptr(matmul_params->weight_size_, cuda_stream->queue());
461461

462462
uint8_t *block_b_scratch_ptr
463-
= alloc_ptr(matmul_params->weight_size_, cuda_stream->queue());
463+
= alloc_ptr(matmul_params->source_size_, cuda_stream->queue());
464464

465465
uint8_t *block_c_scratch_ptr
466466
= alloc_ptr(matmul_params->dest_size_, cuda_stream->queue());

src/gpu/nvidia/cudnn_matmul_lt_impl.hpp

+1 -1
Original file line number | Diff line number | Diff line change
@@ -717,7 +717,7 @@ struct cudnn_matmul_lt_impl_t {
717717
}
718718
if (!params->w_blocked_) {
719719
transform_matrix(lt_handle, params, a_layout, a,
720-
blocked_a_layout, block_a_scratch, !params->trans_a_,
720+
blocked_a_layout, block_a_scratch, params->trans_a_,
721721
streamId);
722722
a = block_a_scratch;
723723
}

0 commit comments

Comments (0)