Skip to content

Commit 0185923

Browse files
committed
graph: backend: dnnl: run graph rewriter explicitly
1 parent e852cdd commit 0185923

File tree

5 files changed

+11
-18
lines changed

5 files changed

+11
-18
lines changed

src/graph/backend/dnnl/kernels/mqa_decomp_config.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,6 @@ status_t mqa_decomp_config_t::record_input_offset(
358358
}
359359

360360
status_t mqa_decomp_config_t::record_mqa_ops(std::shared_ptr<subgraph_t> &sg) {
361-
subgraph_rewriter_t rewriter(sg);
362361
op_ptr reorder1, reorder2, matmul1, softmax, matmul2;
363362
for (const auto &cur_op : sg->get_ops()) {
364363
if (cur_op->get_kind() != op_kind::dnnl_matmul) continue;

src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -511,8 +511,6 @@ impl::status_t sdp_decomp_config_t::record_sdp_ops(
511511
return nullptr;
512512
};
513513

514-
subgraph_rewriter_t rewriter(sg);
515-
516514
for (const auto &cur_op : sg->get_ops()) {
517515
if (!cur_op || cur_op->get_kind() != op_kind::dnnl_matmul) continue;
518516
auto post_op = get_post_op(cur_op);

src/graph/backend/dnnl/passes/transform.cpp

+10-8
Original file line numberDiff line numberDiff line change
@@ -787,8 +787,6 @@ status_t fuse_post_ops(std::shared_ptr<subgraph_t> &sg) {
787787
// lambda function to fuse one post op into base primitive
788788
auto fuse_post_ops_func = [&](bool &changed) -> status_t {
789789
auto &mgr = sg->fusion_info_mgr_;
790-
subgraph_rewriter_t rewriter(sg);
791-
792790
std::vector<std::pair<op_t *, op_t *>> fuse_groups;
793791

794792
std::set<op_t *> visited;
@@ -834,7 +832,7 @@ status_t fuse_post_ops(std::shared_ptr<subgraph_t> &sg) {
834832
changed = false;
835833
return status::success;
836834
}
837-
835+
subgraph_rewriter_t rewriter(sg);
838836
for (auto &fuse_group : fuse_groups) {
839837
auto base_op = fuse_group.first;
840838
auto post_op = fuse_group.second;
@@ -3365,6 +3363,7 @@ impl::status_t lift_up_typecast(std::shared_ptr<subgraph_t> &sg) {
33653363
rewriter.swap_neighboring_si_ops(
33663364
producer->shared_from_this(), tc->shared_from_this());
33673365
}
3366+
rewriter.run();
33683367
}
33693368
return infer_shape(sg);
33703369
}
@@ -3402,6 +3401,7 @@ impl::status_t lift_up_quantize(std::shared_ptr<subgraph_t> &sg) {
34023401
rewriter.swap_neighboring_si_ops(
34033402
producer->shared_from_this(), quant->shared_from_this());
34043403
}
3404+
rewriter.run();
34053405
}
34063406
return infer_shape(sg);
34073407
}
@@ -3528,7 +3528,7 @@ impl::status_t lift_up_weight_reshape_for_depthwiseconv(
35283528
rewriter.swap_neighboring_reshape_ops(
35293529
swapped->shared_from_this(), baseop->shared_from_this());
35303530
}
3531-
3531+
rewriter.run();
35323532
return infer_shape(sg);
35333533
}
35343534

@@ -3569,7 +3569,7 @@ impl::status_t fuse_src_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg) {
35693569
if (axis < 0) axis += ltw(in_lt).ndims();
35703570
}
35713571
} else {
3572-
return impl::status::success;
3572+
break;
35733573
}
35743574

35753575
std::vector<int> axes = dnnl_impl::utils::fmap(order,
@@ -3593,6 +3593,7 @@ impl::status_t fuse_src_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg) {
35933593
const auto &strides = expected_in_md.get_strides();
35943594
out_val->set_strides(strides);
35953595
}
3596+
rewriter.run();
35963597
return impl::status::success;
35973598
}
35983599

@@ -3630,7 +3631,7 @@ impl::status_t fuse_dst_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg) {
36303631
if (axis < 0) axis += ltw(in_lt).ndims();
36313632
}
36323633
} else {
3633-
return impl::status::success;
3634+
break;
36343635
}
36353636

36363637
std::vector<int> axes = dnnl_impl::utils::fmap(order,
@@ -3656,14 +3657,14 @@ impl::status_t fuse_dst_transpose_to_matmul(std::shared_ptr<subgraph_t> &sg) {
36563657
// Special check to avoid low matmul performance with adbc layout.
36573658
// TODO: remove this once the performance is improved.
36583659
if (get_format_tag(expected_out_md) == dnnl::memory::format_tag::adbc) {
3659-
return impl::status::success;
3660+
break;
36603661
}
36613662
const auto &strides = expected_out_md.get_strides();
36623663
in_val->set_strides(strides);
36633664
auto &matmul = transpose_op->get_input_value(0)->get_producer();
36643665
matmul.set_attr(op_attr::keep_dst_layout, true);
36653666
}
3666-
3667+
rewriter.run();
36673668
return impl::status::success;
36683669
}
36693670

@@ -3744,6 +3745,7 @@ impl::status_t swap_relu_mul_scales(std::shared_ptr<subgraph_t> &sg) {
37443745
rewriter.swap_neighboring_si_ops(
37453746
relu->shared_from_this(), mul_scales->shared_from_this());
37463747
}
3748+
rewriter.run();
37473749
}
37483750
return infer_shape(sg);
37493751
}

src/graph/backend/dnnl/subgraph.cpp

-4
Original file line numberDiff line numberDiff line change
@@ -384,10 +384,6 @@ void subgraph_rewriter_t::run() {
384384
to_be_inserted_ops_.clear();
385385
}
386386

387-
subgraph_rewriter_t::~subgraph_rewriter_t() {
388-
run();
389-
}
390-
391387
void subgraph_rewriter_t::fuse_op_to_successor(const op_ptr &op) {
392388
assertm(op->num_inputs() == 1, "this op should have only one input value.");
393389
value_ptr in_val = op->get_input_value(0);

src/graph/backend/dnnl/subgraph.hpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2022-2023 Intel Corporation
2+
* Copyright 2022-2024 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -140,8 +140,6 @@ class subgraph_rewriter_t {
140140
public:
141141
subgraph_rewriter_t(std::shared_ptr<subgraph_t> &sg) : subgraph_(sg) {}
142142

143-
~subgraph_rewriter_t();
144-
145143
// Finalize the rewriting, which actually insert/remove the op to/from
146144
// subgraph op list
147145
void run();

0 commit comments

Comments
 (0)