
Commit 19c036c

[GPU] Output-transposed gemm should select the ocl impl instead of the onednn impl, because a transposed output order is not supported by onednn
1 parent c4d6d2b commit 19c036c

2 files changed (+100, -0 lines)

src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.hpp

+8
@@ -72,6 +72,14 @@ struct GemmImplementationManager : public ImplementationManager {
         if (gemm_prim->indirect_a || gemm_prim->indirect_b)
             return false;
 
+        // Keep this condition until gemm_onednn supports transposed order of output
+        const int64_t OTO_SIZE = static_cast<int64_t>(gemm_prim->output_transpose_order.size());
+        if (OTO_SIZE > 0 &&
+            !(gemm_prim->output_transpose_order[OTO_SIZE - 2] == (OTO_SIZE - 2) &&
+              gemm_prim->output_transpose_order[OTO_SIZE - 1] == (OTO_SIZE - 1))) {
+            return false;
+        }
+
         return true;
     }
 
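
To make the new guard's intent concrete, below is a minimal standalone sketch of the same condition (the helper name onednn_supports_output_order is hypothetical and not part of the plugin): the oneDNN gemm implementation is kept only when the output transpose order leaves the two innermost axes in place; any other order makes the manager reject oneDNN so the gemm falls back to the OCL implementation.

// Minimal sketch (C++17, standalone; not OpenVINO code). The helper name is
// hypothetical and only mirrors the condition added in the hunk above.
#include <cstdint>
#include <iostream>
#include <vector>

static bool onednn_supports_output_order(const std::vector<int64_t>& order) {
    const int64_t n = static_cast<int64_t>(order.size());
    if (n < 2)
        return true;  // empty or degenerate order: nothing to reject
    // oneDNN gemm is only acceptable when the two innermost output axes stay in place
    return order[n - 2] == n - 2 && order[n - 1] == n - 1;
}

int main() {
    std::cout << onednn_supports_output_order({0, 1, 2, 3}) << "\n";  // 1: oneDNN allowed
    std::cout << onednn_supports_output_order({0, 1, 3, 2}) << "\n";  // 0: fall back to OCL
    std::cout << onednn_supports_output_order({2, 0, 1, 3}) << "\n";  // 0: fall back to OCL
    return 0;
}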

src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/transpose_matmul_fusion.cpp

+92
@@ -96,3 +96,95 @@ TEST_P(TransposeMatMulFusionOnGPU, CompareWithRefs){
 };
 
 } // namespace
+
+
+//=================================================================================
+// Transpose + MatMul + Transpose pattern fusion (TransposeMatMulTransposeMatcher)
+//=================================================================================
+namespace ov {
+namespace test {
+
+using MatMulTransposeFusionParams = std::tuple<ov::PartialShape,   // input A shapes
+                                               ov::PartialShape,   // input B shapes
+                                               ov::PartialShape>;  // input C shapes
+class MatMulTransposeFusionOnGPU: public testing::WithParamInterface<MatMulTransposeFusionParams>,
+                                  virtual public ov::test::SubgraphBaseTest {
+public:
+    static std::string getTestCaseName(testing::TestParamInfo<MatMulTransposeFusionParams> obj) {
+        ov::PartialShape input0;
+        ov::PartialShape input1;
+        ov::PartialShape input2;
+
+        std::tie(input0, input1, input2) = obj.param;
+
+        std::ostringstream result;
+        result << "device=(" << std::string(utils::DEVICE_GPU) << ")_";
+        result << ov::test::utils::partialShape2str({input0}) << "_";
+        result << ov::test::utils::partialShape2str({input1}) << "_";
+        result << ov::test::utils::partialShape2str({input2}) << "_";
+        return result.str();
+    }
+protected:
+    void SetUp() override {
+        targetDevice = ov::test::utils::DEVICE_GPU;
+
+        ov::PartialShape shape1;
+        ov::PartialShape shape2;
+        ov::PartialShape shape3;
+
+        std::tie(shape1, shape2, shape3) = GetParam();
+
+        InputShape input_shape1 = {shape1, {shape1.get_shape()}};
+        InputShape input_shape2 = {shape2, {shape2.get_shape()}};
+        InputShape input_shape3 = {shape3, {shape3.get_shape()}};
+        init_input_shapes({input_shape1, input_shape2, input_shape3});
+
+        const auto param1 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, shape1);
+        const auto param2 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, shape2);
+        const auto param3 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, shape3);
+
+        auto input2_shape = shape2.get_shape();
+
+        // input0
+        const auto input0_order = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, {1, 0, 2, 3});
+        const auto input0_transpose = std::make_shared<ov::op::v1::Transpose>(param1, input0_order);
+        const auto input0_shape_pattern = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, input2_shape);
+        const auto input0_reshape = std::make_shared<ov::op::v1::Reshape>(input0_transpose, input0_shape_pattern, false);
+
+        // input1
+        const auto input1_order = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, {0, 1, 3, 2});
+        const auto input1_transpose = std::make_shared<ov::op::v1::Transpose>(param2, input1_order);
+
+        // matmul & softmax
+        const auto matmul1 = std::make_shared<ov::op::v0::MatMul>(input0_reshape, input1_transpose, false, false);
+        const auto softmax = std::make_shared<ov::op::v8::Softmax>(matmul1, -1);
+
+        // input3
+        const auto input3_transpose = std::make_shared<ov::op::v1::Transpose>(param3, input0_order);
+        const auto input3_shape_pattern = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, input2_shape);
+        const auto input3_reshape = std::make_shared<ov::op::v1::Reshape>(input3_transpose, input3_shape_pattern, false);
+
+        // target matmul
+        const auto matmul2 = std::make_shared<ov::op::v0::MatMul>(softmax, input3_reshape, false, false);
+        const auto order = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, {2, 0, 1, 3});
+        const auto transpose = std::make_shared<ov::op::v1::Transpose>(matmul2, order);
+
+        function = std::make_shared<ov::Model>(transpose, ov::ParameterVector{param1, param2, param3});
+    }
+};
+
+
+} // namespace test
+} // namespace ov
+
+
+namespace {
+INSTANTIATE_TEST_SUITE_P(smoke_MatMulTransposeFusion, MatMulTransposeFusionOnGPU,
+                         ::testing::Values(
+                             MatMulTransposeFusionParams({3, 8, 16, 1}, {2, 4, 3, 16}, {3, 8, 16, 1})),
+                         MatMulTransposeFusionOnGPU::getTestCaseName);
+
+TEST_P(MatMulTransposeFusionOnGPU, CompareWithRefs){
+    run();
+};
+} // namespace
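
How this test relates to the guard above (an interpretation of the test's intent, not something stated in the diff): the TransposeMatMulTransposeMatcher is expected to fuse the trailing Transpose with order {2, 0, 1, 3} into the gemm's output_transpose_order. The last two entries of that order are 1 and 3 rather than 2 and 3, so the new condition in gemm_onednn.hpp rejects the oneDNN implementation and this subgraph exercises the OCL fallback; with the hypothetical sketch above, onednn_supports_output_order({2, 0, 1, 3}) returns false.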
