Skip to content

Commit 399f68f

Browse files
[GPU] Use non-transposed output shape for gemm onednn
1 parent a69e0c2 commit 399f68f

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.cpp

+19-1
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,25 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
272272

273273
dnnl::memory::desc in0_md = get_input_memory_desc(in0_dims, in0_dt, in0_fmt, in0_strides);
274274
dnnl::memory::desc in1_md = get_input_memory_desc(in1_dims, in1_dt, in1_fmt, in1_strides);
275-
dnnl::memory::desc out_md(out_dims, out_dt, out_fmt);
275+
276+
// Get non transposed output dims
277+
auto output_order = prim->output_transpose_order;
278+
bool b_output_transposed = false;
279+
for (int64_t i = 0; i < static_cast<int64_t>(output_order.size()); i++) {
280+
if (output_order[i] != i) {
281+
b_output_transposed = true;
282+
break;
283+
}
284+
}
285+
286+
dnnl::memory::dims non_transposed_out_dims = out_dims;
287+
if (b_output_transposed) {
288+
for (int64_t i = 0; i < static_cast<int64_t>(output_order.size()); i++) {
289+
non_transposed_out_dims[i] = out_dims[output_order[i]];
290+
}
291+
}
292+
293+
dnnl::memory::desc out_md(non_transposed_out_dims, out_dt, out_fmt);
276294

277295
if (gemm_with_bias) {
278296
dnnl::memory::desc bias_md(bias_dims, bias_dt, bias_fmt);

src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.hpp

-8
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,6 @@ struct GemmImplementationManager : public ImplementationManager {
7272
if (gemm_prim->indirect_a || gemm_prim->indirect_b)
7373
return false;
7474

75-
// Keep this condition until gemm_onednn supports transposed order of output
76-
const int64_t OTO_SIZE = static_cast<int64_t>(gemm_prim->output_transpose_order.size());
77-
if (OTO_SIZE > 0 &&
78-
!(gemm_prim->output_transpose_order[OTO_SIZE - 2] == (OTO_SIZE - 2) &&
79-
gemm_prim->output_transpose_order[OTO_SIZE - 1] == (OTO_SIZE - 1))) {
80-
return false;
81-
}
82-
8375
return true;
8476
}
8577

src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/transpose_matmul_fusion.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -187,4 +187,4 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulTransposeFusion, MatMulTransposeFusionOnGPU
187187
TEST_P(MatMulTransposeFusionOnGPU, CompareWithRefs){
188188
run();
189189
};
190-
} // namespace
190+
} // namespace

0 commit comments

Comments
 (0)