Skip to content

Commit 2fa9425

Browse files
gyhintelTaoLv
authored andcommitted
graph: backend: dnnl: fix the dst pointer setting
1 parent 4e37cc4 commit 2fa9425

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

src/graph/backend/dnnl/kernels/mqa_decomp.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,9 @@ status_t mqa_decomp_kernel_t<quantized, dt>::execute_impl(
220220
//reorder3
221221
auto &sub_dst_user_tid = res->mem_map[mqa_cfg_.sub_dst_user.get()][tid];
222222

223+
// matmul2
224+
auto &sub_mm2_dst_tid = res->mem_map[mqa_cfg_.sub_mm2_dst.get()][tid];
225+
223226
const size_t sub_src1_offset
224227
= bo * M1 * K1 * get_mem_dt_size(sub_src1_tid);
225228
const size_t sub_wei1_offset = (bo * MBI * K1 * N1 + bi * N1)
@@ -239,6 +242,14 @@ status_t mqa_decomp_kernel_t<quantized, dt>::execute_impl(
239242
sub_dst_user_tid.set_data_handle(
240243
dst2_user_pointer + sub_dst_user_offset);
241244

245+
// If the last reorder is inplace, it means we don't have to do
246+
// extra reorder, thus we should set matmul's output to the user's
247+
// output directly.
248+
if (mqa_cfg_.sub_reorder3.get_inplace()) {
249+
sub_mm2_dst_tid.set_data_handle(
250+
dst2_user_pointer + sub_dst_user_offset);
251+
}
252+
242253
// in parallel region - these primitives should use single thread.
243254
mqa_cfg_.sub_reorder0.execute(strm, res->sub_reorder0_args[tid]);
244255
mqa_cfg_.sub_reorder1.execute(strm, res->sub_reorder1_args[tid]);

src/graph/backend/dnnl/kernels/mqa_decomp_config.hpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2024 Intel Corporation
2+
* Copyright 2024-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -49,6 +49,8 @@ struct mqa_reorder_t {
4949
return status::success;
5050
}
5151

52+
bool get_inplace() const { return is_inplace_; }
53+
5254
status_t execute(const dnnl::stream &astream,
5355
const std::unordered_map<int, dnnl::memory> &args) const {
5456
// If the src and dst are the same, we just set the src arg to dst

0 commit comments

Comments
 (0)