Skip to content

Commit 2f2c380

Browse files
ElaineBaoTaoLv
authored andcommitted
gtests: graph: unit: add gtests for verifying multi-consumers in pm
1 parent 3d64643 commit 2f2c380

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

tests/gtests/graph/unit/utils/test_pattern_matcher_cpu.cpp

+72
Original file line numberDiff line numberDiff line change
@@ -1266,6 +1266,78 @@ TEST(test_utils_pattern_matcher, ComplexRepetition) {
12661266
ASSERT_EQ(fusion_ops.size(), 3U);
12671267
}
12681268

1269+
TEST(test_utils_pattern_matcher, SharedInput) {
1270+
/* Pattern that captures shared input to two MatMuls
1271+
|
1272+
/ \
1273+
MatMul MatMul
1274+
\ /
1275+
Multiply
1276+
|
1277+
*/
1278+
auto graphp = std::make_shared<pb_graph_t>();
1279+
auto pmm1 = graphp->append_op(MatMul);
1280+
auto pmm2 = graphp->append_op(MatMul);
1281+
graphp->create_input_port(0, pmm1, 0);
1282+
graphp->create_input_port(0, pmm2, 0);
1283+
auto pmul = graphp->append_op(
1284+
Multiply, {in_edge(0, pmm1, 0), in_edge(1, pmm2, 0)});
1285+
UNUSED(pmul);
1286+
1287+
// test with a graph that has the shared input
1288+
graph_t agraph;
1289+
op_t matmul1 {0, MatMul, "matmul1"};
1290+
op_t matmul2 {1, MatMul, "matmul2"};
1291+
op_t multiply {2, Multiply, "multiply"};
1292+
1293+
std::vector<logical_tensor_t> lt_vec = create_logical_tensors(6);
1294+
matmul1.add_input(lt_vec[0]);
1295+
matmul1.add_input(lt_vec[1]);
1296+
matmul1.add_output(lt_vec[2]);
1297+
matmul2.add_input(lt_vec[0]);
1298+
matmul2.add_input(lt_vec[3]);
1299+
matmul2.add_output(lt_vec[4]);
1300+
multiply.add_input(lt_vec[2]);
1301+
multiply.add_input(lt_vec[4]);
1302+
multiply.add_output(lt_vec[5]);
1303+
1304+
ASSERT_EQ(agraph.add_op(&matmul1), status::success);
1305+
ASSERT_EQ(agraph.add_op(&matmul2), status::success);
1306+
ASSERT_EQ(agraph.add_op(&multiply), status::success);
1307+
agraph.finalize();
1308+
1309+
std::vector<op_t *> fusion_ops;
1310+
EXPECT_TRUE(match_pattern(agraph.get_ops()[0].get(), graphp, fusion_ops));
1311+
ASSERT_EQ(fusion_ops.size(), 3U);
1312+
1313+
// test with a graph that does not have the shared input
1314+
graph_t agraph2;
1315+
op_t matmul3 {0, MatMul, "matmul1"};
1316+
op_t matmul4 {1, MatMul, "matmul2"};
1317+
op_t multiply2 {2, Multiply, "multiply"};
1318+
1319+
std::vector<logical_tensor_t> lt_vec2 = create_logical_tensors(7);
1320+
matmul3.add_input(lt_vec2[0]);
1321+
matmul3.add_input(lt_vec2[1]);
1322+
matmul3.add_output(lt_vec2[2]);
1323+
matmul4.add_input(lt_vec2[3]);
1324+
matmul4.add_input(lt_vec2[4]);
1325+
matmul4.add_output(lt_vec2[5]);
1326+
multiply2.add_input(lt_vec2[2]);
1327+
multiply2.add_input(lt_vec2[5]);
1328+
multiply2.add_output(lt_vec2[6]);
1329+
1330+
ASSERT_EQ(agraph2.add_op(&matmul3), status::success);
1331+
ASSERT_EQ(agraph2.add_op(&matmul4), status::success);
1332+
ASSERT_EQ(agraph2.add_op(&multiply2), status::success);
1333+
agraph2.finalize();
1334+
1335+
std::vector<op_t *> fusion_ops2;
1336+
EXPECT_FALSE(
1337+
match_pattern(agraph2.get_ops()[0].get(), graphp, fusion_ops2));
1338+
ASSERT_EQ(fusion_ops2.size(), 0U);
1339+
}
1340+
12691341
TEST(test_utils_pattern_matcher, ParallelMatmul) {
12701342
auto graphp = std::make_shared<pb_graph_t>();
12711343
// Pattern that captures shared input to three MatMuls

0 commit comments

Comments
 (0)