@@ -142,6 +142,8 @@ class GemmFusingTest : public ::BaseFusingTest<gemm_test_params> {
142
142
// FP16 cases for the GEMM + permute fusing tests. Each macro expands to the leading
// initializers of a gemm_test_params entry: a pair of 4D shapes, a third 4D shape,
// then the data types and formats.
// NOTE(review): presumably { { input0 }, { input1 } }, { output } - confirm against gemm_test_params.
#define CASE_GEMM_PERMUTES_FUSION_FP16_3 \
    { { 17, 11, 2, 18 }, { 17, 11, 18, 4 } }, { 17, 11, 2, 4 }, \
    data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_PERMUTES_FUSION_FP16_4 \
    { { 3, 2, 10, 12 }, { 3, 2, 12, 20 } }, { 3, 2, 10, 20 }, \
    data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_PERMUTES_FUSION_FP16_5 \
    { { 3, 2, 16, 32 }, { 3, 2, 32, 16 } }, { 3, 2, 16, 16 }, \
    data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
// NOTE(review): FP16_6 has no user in the visible part of this file - confirm it is referenced elsewhere.
#define CASE_GEMM_PERMUTES_FUSION_FP16_6 \
    { { 3, 2, 16, 32 }, { 3, 16, 2, 32 } }, { 3, 2, 2, 32 }, \
    data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
146
+
145
147
// Fixture for the gemm_3in_quantize_i8 parameterized test that follows; adds nothing to GemmFusingTest.
class gemm_3in_quantize_i8 : public GemmFusingTest {};
146
148
TEST_P (gemm_3in_quantize_i8, basic) {
147
149
// TODO: Fix me, refer PR(#15873)
@@ -757,4 +759,40 @@ INSTANTIATE_TEST_SUITE_P(
757
759
gemm_test_params{CASE_GEMM_PERMUTES_FUSION_FP16_3, 3 , 6 , " " , broadcast_kinds::feature/* dummy*/ , eltwise_mode::sum/* dummy*/ , {{0 , 2 , 1 , 3 } /* byfx*/ , {1 , 2 , 3 , 0 } /* xbfy*/ , {0 , 2 , 1 , 3 } /* byfx*/ }},
758
760
}));
759
761
762
+ class permute_gemm_reorder : public GemmFusingTestOneDNN {};
763
+ TEST_P (permute_gemm_reorder, fused_permute_gemm_with_reorder) {
764
+ auto p = GetParam ();
765
+ auto in_lay0 = get_input_layout (p, 0 );
766
+ auto in_lay1 = get_input_layout (p, 1 );
767
+ auto permute_in_lay0 = get_permute_input_shape (in_lay0.get_shape (), p.permute_orders [0 ]);
768
+ auto permute_in_lay1 = get_permute_input_shape (in_lay1.get_shape (), p.permute_orders [1 ]);
769
+ in_lay0.set_partial_shape (permute_in_lay0);
770
+ in_lay1.set_partial_shape (permute_in_lay1);
771
+ create_topologies (
772
+ input_layout (" input0" , in_lay0),
773
+ input_layout (" input1" , in_lay1),
774
+ permute (" permute0" , input_info (" input0" ), p.permute_orders [0 ]),
775
+ reorder (" reorder_permute" , input_info (" permute0" ), p.default_format , data_types::f32),
776
+ permute (" permute1" , input_info (" input1" ), p.permute_orders [1 ]),
777
+ gemm (" gemm_prim" , { input_info (" permute0" ), input_info (" permute1" ) }, data_types::f16),
778
+ reorder (" reorder_bfyx" , input_info (" gemm_prim" ), p.default_format , data_types::f32),
779
+ eltwise (" eltwise" , { input_info (" reorder_permute" ), input_info (" reorder_bfyx" ) }, eltwise_mode::sum, data_types::f32)
780
+ );
781
+
782
+ tolerance = default_tolerance (data_types::f16);
783
+ execute (p, false );
784
+ }
785
+
786
// FP16 cases for permute_gemm_reorder. Each macro expands to the leading initializers of a
// gemm_test_params entry: a pair of 4D shapes, a third 4D shape, then the data types and formats.
// NOTE(review): presumably { { input0 }, { input1 } }, { output } - confirm against gemm_test_params.
#define CASE_PERMUTES_GEMM_FUSION_FP16_1 \
    { { 1, 12, 20, 64 }, { 1, 12, 64, 64 } }, { 1, 12, 20, 64 }, \
    data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_PERMUTES_GEMM_FUSION_FP16_2 \
    { { 3, 2, 10, 12 }, { 3, 2, 12, 1 } }, { 3, 2, 10, 1 }, \
    data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
788
+
789
+ INSTANTIATE_TEST_SUITE_P (
790
+ fusings_gpu, permute_gemm_reorder, ::testing::ValuesIn(std::vector<gemm_test_params>{
791
+ gemm_test_params{CASE_PERMUTES_GEMM_FUSION_FP16_1, 4 , 6 , " " , broadcast_kinds::feature/* dummy*/ , eltwise_mode::sum/* dummy*/ , {{0 , 2 , 1 , 3 } /* byfx*/ , {0 , 2 , 1 , 3 } /* byfx*/ }},
792
+ gemm_test_params{CASE_PERMUTES_GEMM_FUSION_FP16_1, 4 , 6 , " " , broadcast_kinds::feature/* dummy*/ , eltwise_mode::sum/* dummy*/ , {{0 , 2 , 1 , 3 } /* byfx*/ , {1 , 2 , 3 , 0 } /* xbfy*/ }},
793
+ gemm_test_params{CASE_PERMUTES_GEMM_FUSION_FP16_2, 4 , 6 , " " , broadcast_kinds::feature/* dummy*/ , eltwise_mode::sum/* dummy*/ , {{0 , 2 , 1 , 3 } /* byfx*/ , {0 , 2 , 1 , 3 } /* byfx*/ }},
794
+ gemm_test_params{CASE_PERMUTES_GEMM_FUSION_FP16_2, 4 , 6 , " " , broadcast_kinds::feature/* dummy*/ , eltwise_mode::sum/* dummy*/ , {{0 , 2 , 1 , 3 } /* byfx*/ , {1 , 2 , 3 , 0 } /* xbfy*/ }},
795
+ }));
796
+
797
+
760
798
#endif // ENABLE_ONEDNN_FOR_GPU
0 commit comments