@@ -1266,6 +1266,78 @@ TEST(test_utils_pattern_matcher, ComplexRepetition) {
1266
1266
ASSERT_EQ (fusion_ops.size (), 3U );
1267
1267
}
1268
1268
1269
+ TEST (test_utils_pattern_matcher, SharedInput) {
1270
+ /* Pattern that captures shared input to two MatMuls
1271
+ |
1272
+ / \
1273
+ MatMul MatMul
1274
+ \ /
1275
+ Multiply
1276
+ |
1277
+ */
1278
+ auto graphp = std::make_shared<pb_graph_t >();
1279
+ auto pmm1 = graphp->append_op (MatMul);
1280
+ auto pmm2 = graphp->append_op (MatMul);
1281
+ graphp->create_input_port (0 , pmm1, 0 );
1282
+ graphp->create_input_port (0 , pmm2, 0 );
1283
+ auto pmul = graphp->append_op (
1284
+ Multiply, {in_edge (0 , pmm1, 0 ), in_edge (1 , pmm2, 0 )});
1285
+ UNUSED (pmul);
1286
+
1287
+ // test with a graph that has the shared input
1288
+ graph_t agraph;
1289
+ op_t matmul1 {0 , MatMul, " matmul1" };
1290
+ op_t matmul2 {1 , MatMul, " matmul2" };
1291
+ op_t multiply {2 , Multiply, " multiply" };
1292
+
1293
+ std::vector<logical_tensor_t > lt_vec = create_logical_tensors (6 );
1294
+ matmul1.add_input (lt_vec[0 ]);
1295
+ matmul1.add_input (lt_vec[1 ]);
1296
+ matmul1.add_output (lt_vec[2 ]);
1297
+ matmul2.add_input (lt_vec[0 ]);
1298
+ matmul2.add_input (lt_vec[3 ]);
1299
+ matmul2.add_output (lt_vec[4 ]);
1300
+ multiply.add_input (lt_vec[2 ]);
1301
+ multiply.add_input (lt_vec[4 ]);
1302
+ multiply.add_output (lt_vec[5 ]);
1303
+
1304
+ ASSERT_EQ (agraph.add_op (&matmul1), status::success);
1305
+ ASSERT_EQ (agraph.add_op (&matmul2), status::success);
1306
+ ASSERT_EQ (agraph.add_op (&multiply), status::success);
1307
+ agraph.finalize ();
1308
+
1309
+ std::vector<op_t *> fusion_ops;
1310
+ EXPECT_TRUE (match_pattern (agraph.get_ops ()[0 ].get (), graphp, fusion_ops));
1311
+ ASSERT_EQ (fusion_ops.size (), 3U );
1312
+
1313
+ // test with a graph that does not have the shared input
1314
+ graph_t agraph2;
1315
+ op_t matmul3 {0 , MatMul, " matmul1" };
1316
+ op_t matmul4 {1 , MatMul, " matmul2" };
1317
+ op_t multiply2 {2 , Multiply, " multiply" };
1318
+
1319
+ std::vector<logical_tensor_t > lt_vec2 = create_logical_tensors (7 );
1320
+ matmul3.add_input (lt_vec2[0 ]);
1321
+ matmul3.add_input (lt_vec2[1 ]);
1322
+ matmul3.add_output (lt_vec2[2 ]);
1323
+ matmul4.add_input (lt_vec2[3 ]);
1324
+ matmul4.add_input (lt_vec2[4 ]);
1325
+ matmul4.add_output (lt_vec2[5 ]);
1326
+ multiply2.add_input (lt_vec2[2 ]);
1327
+ multiply2.add_input (lt_vec2[5 ]);
1328
+ multiply2.add_output (lt_vec2[6 ]);
1329
+
1330
+ ASSERT_EQ (agraph2.add_op (&matmul3), status::success);
1331
+ ASSERT_EQ (agraph2.add_op (&matmul4), status::success);
1332
+ ASSERT_EQ (agraph2.add_op (&multiply2), status::success);
1333
+ agraph2.finalize ();
1334
+
1335
+ std::vector<op_t *> fusion_ops2;
1336
+ EXPECT_FALSE (
1337
+ match_pattern (agraph2.get_ops ()[0 ].get (), graphp, fusion_ops2));
1338
+ ASSERT_EQ (fusion_ops2.size (), 0U );
1339
+ }
1340
+
1269
1341
TEST (test_utils_pattern_matcher, ParallelMatmul) {
1270
1342
auto graphp = std::make_shared<pb_graph_t >();
1271
1343
// Pattern that captures shared input to three MatMuls
0 commit comments