@@ -96,3 +96,96 @@ TEST_P(TransposeMatMulFusionOnGPU, CompareWithRefs){
96
96
};
97
97
98
98
} // namespace
99
+
100
+
101
+ // =================================================================================
102
+ // Transpose + MatMul + Transpose pattern fusion (TransposeMatMulTransposeMatcher)
103
+ // =================================================================================
104
+ namespace ov {
105
+ namespace test {
106
+
107
+ using MatMulTransposeFusionParams = std::tuple<ov::PartialShape, // input A shapes
108
+ ov::PartialShape, // input B shapes
109
+ ov::PartialShape>;
110
+ class MatMulTransposeFusionOnGPU : public testing ::WithParamInterface<MatMulTransposeFusionParams>,
111
+ virtual public ov::test::SubgraphBaseTest {
112
+ public:
113
+ static std::string getTestCaseName (testing::TestParamInfo<MatMulTransposeFusionParams> obj) {
114
+ ov::PartialShape input0;
115
+ ov::PartialShape input1;
116
+ ov::PartialShape input2;
117
+
118
+ std::tie (input0, input1, input2) = obj.param ;
119
+
120
+ std::ostringstream result;
121
+ result << " device=(" << std::string (utils::DEVICE_GPU) << " )_" ;
122
+ result << ov::test::utils::partialShape2str ({input0}) << " _" ;
123
+ result << ov::test::utils::partialShape2str ({input1}) << " _" ;
124
+ result << ov::test::utils::partialShape2str ({input2}) << " _" ;
125
+ return result.str ();
126
+ }
127
+ protected:
128
+ void SetUp () override {
129
+ targetDevice = ov::test::utils::DEVICE_GPU;
130
+
131
+ ov::PartialShape shape1;
132
+ ov::PartialShape shape2;
133
+ ov::PartialShape shape3;
134
+
135
+ std::tie (shape1, shape2, shape3) = GetParam ();
136
+
137
+ InputShape input_shape1 = {shape1, {shape1.get_shape ()}};
138
+ InputShape input_shape2 = {shape2, {shape2.get_shape ()}};
139
+ InputShape input_shape3 = {shape3, {shape3.get_shape ()}};
140
+ init_input_shapes ({input_shape1, input_shape2, input_shape3});
141
+
142
+ const auto param1 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, shape1);
143
+ const auto param2 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, shape2);
144
+ const auto param3 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, shape3);
145
+
146
+ auto input2_shape = shape2.get_shape ();
147
+
148
+ // input0
149
+ const auto input0_order = ov::op::v0::Constant::create (ov::element::i32, Shape{4 }, {1 , 0 , 2 , 3 });
150
+ const auto input0_transpose = std::make_shared<ov::op::v1::Transpose>(param1, input0_order);
151
+ const auto input0_shape_pattern = ov::op::v0::Constant::create (ov::element::i32, Shape{4 }, input2_shape);
152
+ const auto input0_reshape = std::make_shared<ov::op::v1::Reshape>(input0_transpose, input0_shape_pattern, false );
153
+
154
+ // input1
155
+ const auto input1_order = ov::op::v0::Constant::create (ov::element::i32, Shape{4 }, {0 , 1 , 3 , 2 });
156
+ const auto input1_transpose = std::make_shared<ov::op::v1::Transpose>(param2, input1_order);
157
+
158
+ // matmul & softmax
159
+ const auto matmul1 = std::make_shared<ov::op::v0::MatMul>(input0_reshape, input1_transpose, false , false );
160
+ const auto softmax = std::make_shared<ov::op::v8::Softmax>(matmul1, -1 );
161
+
162
+ // input3
163
+ // const auto input3_order = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, {1, 0, 2, 3});
164
+ const auto input3_transpose = std::make_shared<ov::op::v1::Transpose>(param3, input0_order);
165
+ const auto input3_shape_pattern = ov::op::v0::Constant::create (ov::element::i32, Shape{4 }, input2_shape);
166
+ const auto input3_reshape = std::make_shared<ov::op::v1::Reshape>(input3_transpose, input3_shape_pattern, false );
167
+
168
+ // target matmul
169
+ const auto matmul2 = std::make_shared<ov::op::v0::MatMul>(softmax, input3_reshape, false , false );
170
+ const auto order = ov::op::v0::Constant::create (ov::element::i32, Shape{4 }, {2 , 0 , 1 , 3 });
171
+ const auto transpose = std::make_shared<ov::op::v1::Transpose>(matmul2, order);
172
+
173
+ function = std::make_shared<ov::Model>(transpose, ov::ParameterVector{param1, param2, param3});
174
+ }
175
+ };
176
+
177
+
178
+ } // namespace test
179
+ } // namespace ov
180
+
181
+
182
+ namespace {
183
+ INSTANTIATE_TEST_SUITE_P (smoke_MatMulTransposeFusion, MatMulTransposeFusionOnGPU,
184
+ ::testing::Values (
185
+ MatMulTransposeFusionParams ({3 , 8 , 16 , 1 }, {2 , 4 , 3 , 16 }, {3 , 8 , 16 , 1 })),
186
+ MatMulTransposeFusionOnGPU::getTestCaseName);
187
+
188
+ TEST_P (MatMulTransposeFusionOnGPU, CompareWithRefs){
189
+ run ();
190
+ };
191
+ } // namespace
0 commit comments