@@ -787,8 +787,6 @@ status_t fuse_post_ops(std::shared_ptr<subgraph_t> &sg) {
     // lambda function to fuse one post op into base primitive
     auto fuse_post_ops_func = [&](bool &changed) -> status_t {
         auto &mgr = sg->fusion_info_mgr_;
-        subgraph_rewriter_t rewriter(sg);
-
         std::vector<std::pair<op_t *, op_t *>> fuse_groups;
 
         std::set<op_t *> visited;
@@ -834,7 +832,7 @@ status_t fuse_post_ops(std::shared_ptr<subgraph_t> &sg) {
             changed = false;
             return status::success;
         }
-
+        subgraph_rewriter_t rewriter(sg);
         for (auto &fuse_group : fuse_groups) {
             auto base_op = fuse_group.first;
             auto post_op = fuse_group.second;
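Note on the two hunks above: `subgraph_rewriter_t` appears to queue graph mutations and commit them later (its `run()` applies the queued changes), so constructing it before the `return status::success;` early-out meant building a rewriter that was never used when `fuse_groups` was empty. Moving the construction after the early-out scopes it to actual work. A minimal, self-contained sketch of this deferred-mutation pattern; `graph_t`, `rewriter_t`, and `queue_remove` are illustrative stand-ins, not the oneDNN API:

    #include <algorithm>
    #include <functional>
    #include <iostream>
    #include <vector>

    // Toy "graph": just op ids.
    struct graph_t {
        std::vector<int> ops;
    };

    // Hypothetical stand-in for subgraph_rewriter_t: mutations are queued
    // and only applied to the graph when run() is called.
    class rewriter_t {
    public:
        explicit rewriter_t(graph_t &g) : g_(g) {}

        // Queue a removal; the graph is untouched, so the caller can keep
        // iterating over g.ops without invalidating anything.
        void queue_remove(int op_id) {
            pending_.push_back([op_id](graph_t &g) {
                g.ops.erase(std::remove(g.ops.begin(), g.ops.end(), op_id),
                        g.ops.end());
            });
        }

        // Commit all queued mutations in order.
        void run() {
            for (auto &fn : pending_) fn(g_);
            pending_.clear();
        }

    private:
        graph_t &g_;
        std::vector<std::function<void(graph_t &)>> pending_;
    };

    int main() {
        graph_t g {{1, 2, 3}};
        rewriter_t rewriter(g); // as in the diff: built only once work exists
        for (int op : g.ops)
            if (op == 2) rewriter.queue_remove(op); // safe: no mutation yet
        rewriter.run(); // commit before anything re-reads the graph
        for (int op : g.ops) std::cout << op << '\n'; // prints 1 and 3
    }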
@@ -3365,6 +3363,7 @@ impl::status_t lift_up_typecast(std::shared_ptr<subgraph_t> &sg) {
             rewriter.swap_neighboring_si_ops(
                     producer->shared_from_this(), tc->shared_from_this());
         }
+        rewriter.run();
     }
     return infer_shape(sg);
 }
@@ -3402,6 +3401,7 @@ impl::status_t lift_up_quantize(std::shared_ptr<subgraph_t> &sg) {
             rewriter.swap_neighboring_si_ops(
                     producer->shared_from_this(), quant->shared_from_this());
         }
+        rewriter.run();
     }
     return infer_shape(sg);
 }
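The two `lift_up_*` hunks above add `rewriter.run()` at the end of each round of swaps rather than relying on a later commit; assuming the swaps are queued rather than applied immediately, the next scan of the subgraph only sees the lifted op order once `run()` executes, which is what lets the pass converge. A toy, runnable sketch of that scan/queue/commit fixpoint loop, with ops modeled as strings and all names illustrative:

    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    // Illustrative only: lift "tc" above its producer by swapping
    // neighbors, queuing the swaps per round and committing them before
    // the next scan, mirroring the shape of lift_up_typecast.
    int main() {
        std::vector<std::string> chain {"conv", "mul", "tc"};
        bool swapped = true;
        while (swapped) {
            swapped = false;
            // Scan: find producers immediately followed by "tc".
            std::vector<size_t> pending;
            for (size_t i = 0; i + 1 < chain.size(); ++i)
                if (chain[i + 1] == "tc" && chain[i] != "conv")
                    pending.push_back(i);
            // Commit: the "run()" step. Without it, the next scan would
            // see the old order and the loop could not make progress.
            for (size_t i : pending) {
                std::swap(chain[i], chain[i + 1]);
                swapped = true;
            }
        }
        for (auto &op : chain) std::cout << op << ' '; // prints: conv tc mul
    }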
@@ -3528,7 +3528,7 @@ impl::status_t lift_up_weight_reshape_for_depthwiseconv(
         rewriter.swap_neighboring_reshape_ops(
                 swapped->shared_from_this(), baseop->shared_from_this());
     }
-
+    rewriter.run();
     return infer_shape(sg);
 }
 
@@ -3569,7 +3569,7 @@ impl::status_t fuse_src_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg) {
                 if (axis < 0) axis += ltw(in_lt).ndims();
             }
         } else {
-            return impl::status::success;
+            break;
         }
 
         std::vector<int> axes = dnnl_impl::utils::fmap(order,
@@ -3593,6 +3593,7 @@ impl::status_t fuse_src_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg) {
         const auto &strides = expected_in_md.get_strides();
         out_val->set_strides(strides);
     }
+    rewriter.run();
     return impl::status::success;
 }
 
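In `fuse_src_transpose_to_matmul` (and `fuse_dst_transpose_to_matmul` below), swapping `return impl::status::success;` for `break;` matters because a `rewriter.run()` call now sits between the loop and the final return: an early `return` would skip the commit and drop any rewrites already queued for earlier transposes, while `break` merely stops scanning and still falls through to `run()`. A toy illustration of that control-flow fix, using a stand-in `rewriter_t` rather than the real class:

    #include <iostream>
    #include <vector>

    // Stand-in rewriter: counts queued rewrites, "commits" them in run().
    struct rewriter_t {
        std::vector<int> queued;
        void queue(int v) { queued.push_back(v); }
        void run() {
            std::cout << "committed " << queued.size() << " rewrites\n";
        }
    };

    int main() {
        std::vector<int> ops {1, 2, -1, 3}; // -1: an op the pass cannot handle
        rewriter_t rewriter;
        for (int op : ops) {
            if (op < 0) break; // was `return` before the fix: queued work leaked
            rewriter.queue(op);
        }
        rewriter.run(); // still commits the two rewrites queued before bail-out
        return 0;
    }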
@@ -3630,7 +3631,7 @@ impl::status_t fuse_dst_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg) {
                 if (axis < 0) axis += ltw(in_lt).ndims();
             }
         } else {
-            return impl::status::success;
+            break;
         }
 
         std::vector<int> axes = dnnl_impl::utils::fmap(order,
@@ -3656,14 +3657,14 @@ impl::status_t fuse_dst_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg) {
         // Special check to avoid low matmul performance with adbc layout.
         // TODO: remove this once the performance is improved.
         if (get_format_tag(expected_out_md) == dnnl::memory::format_tag::adbc) {
-            return impl::status::success;
+            break;
         }
         const auto &strides = expected_out_md.get_strides();
         in_val->set_strides(strides);
         auto &matmul = transpose_op->get_input_value(0)->get_producer();
         matmul.set_attr(op_attr::keep_dst_layout, true);
     }
-
+    rewriter.run();
     return impl::status::success;
 }
@@ -3744,6 +3745,7 @@ impl::status_t swap_relu_mul_scales(std::shared_ptr<subgraph_t> &sg) {
             rewriter.swap_neighboring_si_ops(
                     relu->shared_from_this(), mul_scales->shared_from_this());
         }
+        rewriter.run();
     }
     return infer_shape(sg);
 }