@@ -124,6 +124,64 @@ TEST_F(TransformationTestsF, TranposeMatmulFusion4) {
124
124
}
125
125
}
126
126
127
+ TEST_F (TransformationTestsF, TranposeMatmulFusion5) {
128
+ {
129
+ auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic (3 ));
130
+ auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic (3 ));
131
+ auto matmul = std::make_shared<ov::op::v0::MatMul>(input_a, input_b);
132
+ auto tranpose_c_const = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 2 , 1 });
133
+ auto tranpose_c = std::make_shared<ov::op::v1::Transpose>(matmul, tranpose_c_const);
134
+
135
+ model = std::make_shared<ov::Model>(ov::NodeVector{ tranpose_c }, ov::ParameterVector{ input_a, input_b });
136
+
137
+ const auto supports_immad = false ;
138
+ manager.register_pass <TransposeFusion>(supports_immad);
139
+ }
140
+ {
141
+ std::vector<int64_t > order_a = {0 , 1 , 2 };
142
+ std::vector<int64_t > order_b = {0 , 1 , 2 };
143
+ std::vector<int64_t > order_c = {0 , 2 , 1 };
144
+ auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic (3 ));
145
+ auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic (3 ));
146
+ auto gemm = std::make_shared<ov::intel_gpu::op::Gemm>(input_a, input_b, order_a, order_b, order_c, ov::element::undefined);
147
+
148
+ model_ref = std::make_shared<ov::Model>(ov::NodeVector{ gemm }, ov::ParameterVector{ input_a, input_b });
149
+ comparator.enable (FunctionsComparator::ATTRIBUTES);
150
+ }
151
+ }
152
+
153
+ TEST_F (TransformationTestsF, TranposeMatmulFusion6) {
154
+ auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic (2 ));
155
+ auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic (2 ));
156
+ auto matmul = std::make_shared<ov::op::v0::MatMul>(input_a, input_b);
157
+ auto tranpose_c_const = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{2 }, {1 , 0 });
158
+ auto tranpose_c = std::make_shared<ov::op::v1::Transpose>(matmul, tranpose_c_const);
159
+
160
+ model = std::make_shared<ov::Model>(ov::NodeVector{ tranpose_c }, ov::ParameterVector{ input_a, input_b });
161
+
162
+ const auto supports_immad = false ;
163
+ manager.register_pass <TransposeFusion>(supports_immad);
164
+
165
+ model_ref = model->clone ();
166
+ comparator.enable (FunctionsComparator::ATTRIBUTES);
167
+ }
168
+
169
+ TEST_F (TransformationTestsF, TranposeMatmulFusion7) {
170
+ auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{2 , 4 });
171
+ auto input_b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{4 , 2 });
172
+ auto matmul = std::make_shared<ov::op::v0::MatMul>(input_a, input_b);
173
+ auto tranpose_c_const = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{2 }, {1 , 0 });
174
+ auto tranpose_c = std::make_shared<ov::op::v1::Transpose>(matmul, tranpose_c_const);
175
+
176
+ model = std::make_shared<ov::Model>(ov::NodeVector{ tranpose_c }, ov::ParameterVector{ input_a, input_b });
177
+
178
+ const auto supports_immad = false ;
179
+ manager.register_pass <TransposeFusion>(supports_immad);
180
+
181
+ model_ref = model->clone ();
182
+ comparator.enable (FunctionsComparator::ATTRIBUTES);
183
+ }
184
+
127
185
TEST_F (TransformationTestsF, TranposeMatmulFusion_Illegal_1) {
128
186
{
129
187
auto input_a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{10 , 20 });
0 commit comments