@@ -392,12 +392,12 @@ struct cudnn_matmul_lt_exec_t final : public cudnn_matmul_lt_base_exec_t {
392
392
memory_tracking::names::key_matmul_dst_in_acc_dt)
393
393
: xpu::sycl::interop_memory_arg_t <
394
394
::sycl::access ::mode::read_write>();
395
- auto arg_block_a_scratch = params->source_size_ != 0
395
+ auto arg_block_a_scratch = params->weight_size_ != 0
396
396
? CTX_SCRATCH_SYCL_MEMORY (
397
397
memory_tracking::names::key_gemm_blocked_a)
398
398
: xpu::sycl::interop_memory_arg_t <
399
399
::sycl::access ::mode::read_write>();
400
- auto arg_block_b_scratch = params->weight_size_ != 0
400
+ auto arg_block_b_scratch = params->source_size_ != 0
401
401
? CTX_SCRATCH_SYCL_MEMORY (
402
402
memory_tracking::names::key_gemm_blocked_b)
403
403
: xpu::sycl::interop_memory_arg_t <
@@ -457,10 +457,10 @@ struct cudnn_matmul_lt_runtime_args_exec_t final
457
457
matmul_params->reorder_scratch_size_ , cuda_stream->queue ());
458
458
459
459
uint8_t *block_a_scratch_ptr
460
- = alloc_ptr (matmul_params->source_size_ , cuda_stream->queue ());
460
+ = alloc_ptr (matmul_params->weight_size_ , cuda_stream->queue ());
461
461
462
462
uint8_t *block_b_scratch_ptr
463
- = alloc_ptr (matmul_params->weight_size_ , cuda_stream->queue ());
463
+ = alloc_ptr (matmul_params->source_size_ , cuda_stream->queue ());
464
464
465
465
uint8_t *block_c_scratch_ptr
466
466
= alloc_ptr (matmul_params->dest_size_ , cuda_stream->queue ());
0 commit comments