@@ -186,8 +186,15 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
186
186
if (ret) {
187
187
tag = convert_data_format (transposed_format);
188
188
dnnl::memory::dims original_dims = dims;
189
- for (size_t i = 0 ; i < original_dims.size (); ++i) {
190
- dims[i] = original_dims[order[i]];
189
+ if (is_input) {
190
+ for (size_t i = 0 ; i < original_dims.size (); ++i) {
191
+ dims[i] = original_dims[order[i]];
192
+ }
193
+ } else {
194
+ // Get non-transposed dims for output dims
195
+ for (size_t i = 0 ; i < original_dims.size (); ++i) {
196
+ dims[order[i]] = original_dims[i];
197
+ }
191
198
}
192
199
} else {
193
200
std::ostringstream ostream;
@@ -272,25 +279,7 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
272
279
273
280
dnnl::memory::desc in0_md = get_input_memory_desc (in0_dims, in0_dt, in0_fmt, in0_strides);
274
281
dnnl::memory::desc in1_md = get_input_memory_desc (in1_dims, in1_dt, in1_fmt, in1_strides);
275
-
276
- // Get non transposed output dims
277
- auto output_order = prim->output_transpose_order ;
278
- bool b_output_transposed = false ;
279
- for (int64_t i = 0 ; i < static_cast <int64_t >(output_order.size ()); i++) {
280
- if (output_order[i] != i) {
281
- b_output_transposed = true ;
282
- break ;
283
- }
284
- }
285
-
286
- dnnl::memory::dims non_transposed_out_dims = out_dims;
287
- if (b_output_transposed) {
288
- for (int64_t i = 0 ; i < static_cast <int64_t >(output_order.size ()); i++) {
289
- non_transposed_out_dims[i] = out_dims[output_order[i]];
290
- }
291
- }
292
-
293
- dnnl::memory::desc out_md (non_transposed_out_dims, out_dt, out_fmt);
282
+ dnnl::memory::desc out_md (out_dims, out_dt, out_fmt);
294
283
295
284
if (gemm_with_bias) {
296
285
dnnl::memory::desc bias_md (bias_dims, bias_dt, bias_fmt);
0 commit comments