Skip to content

Commit c9c888f

Browse files
[GPU] Get non-transposed output shape for gemm onednn
1 parent a29d4c1 commit c9c888f

File tree

1 file changed

+10
-21
lines changed

1 file changed

+10
-21
lines changed

src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.cpp

+10-21
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,15 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
186186
if (ret) {
187187
tag = convert_data_format(transposed_format);
188188
dnnl::memory::dims original_dims = dims;
189-
for (size_t i = 0; i < original_dims.size(); ++i) {
190-
dims[i] = original_dims[order[i]];
189+
if (is_input) {
190+
for (size_t i = 0; i < original_dims.size(); ++i) {
191+
dims[i] = original_dims[order[i]];
192+
}
193+
} else {
194+
// Get non-transposed dims for output dims
195+
for (size_t i = 0; i < original_dims.size(); ++i) {
196+
dims[order[i]] = original_dims[i];
197+
}
191198
}
192199
} else {
193200
std::ostringstream ostream;
@@ -272,25 +279,7 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
272279

273280
dnnl::memory::desc in0_md = get_input_memory_desc(in0_dims, in0_dt, in0_fmt, in0_strides);
274281
dnnl::memory::desc in1_md = get_input_memory_desc(in1_dims, in1_dt, in1_fmt, in1_strides);
275-
276-
// Get non transposed output dims
277-
auto output_order = prim->output_transpose_order;
278-
bool b_output_transposed = false;
279-
for (int64_t i = 0; i < static_cast<int64_t>(output_order.size()); i++) {
280-
if (output_order[i] != i) {
281-
b_output_transposed = true;
282-
break;
283-
}
284-
}
285-
286-
dnnl::memory::dims non_transposed_out_dims = out_dims;
287-
if (b_output_transposed) {
288-
for (int64_t i = 0; i < static_cast<int64_t>(output_order.size()); i++) {
289-
non_transposed_out_dims[i] = out_dims[output_order[i]];
290-
}
291-
}
292-
293-
dnnl::memory::desc out_md(non_transposed_out_dims, out_dt, out_fmt);
282+
dnnl::memory::desc out_md(out_dims, out_dt, out_fmt);
294283

295284
if (gemm_with_bias) {
296285
dnnl::memory::desc bias_md(bias_dims, bias_dt, bias_fmt);

0 commit comments

Comments
 (0)