@@ -39,12 +39,15 @@ InitLoops::InitLoops() : Pass() {}
39
39
40
40
void InitLoops::init_ptr_increments (const LinearIR::LoopManager::LoopInfoPtr& loop_info) {
41
41
const auto work_amount = loop_info->get_work_amount ();
42
+ auto loop_entries = loop_info->get_entry_points ();
43
+ auto loop_exits = loop_info->get_exit_points ();
42
44
43
- auto init_entry_port_increment = [&work_amount](LoopPort & loop_entry) {
45
+ for ( auto & loop_entry : loop_entries ) {
44
46
loop_entry.ptr_increment = 0 ;
45
47
if (loop_entry.is_incremented ) {
46
48
const auto & port = loop_entry.expr_port ;
47
49
const auto source = *port->get_connected_ports ().begin ();
50
+ const auto loop_ids = port->get_expr ()->get_loop_ids ();
48
51
const auto & layout = port->get_descriptor_ptr ()->get_layout ();
49
52
const auto & shape = port->get_descriptor_ptr ()->get_shape ();
50
53
const auto & dim = *(layout.rbegin () + loop_entry.dim_idx );
@@ -54,11 +57,13 @@ void InitLoops::init_ptr_increments(const LinearIR::LoopManager::LoopInfoPtr& lo
54
57
loop_entry.ptr_increment = get_input_stride (dim, source.get_descriptor_ptr ()->get_layout (), shape);
55
58
}
56
59
}
57
- };
58
- auto init_exit_port_increment = [&work_amount](LoopPort& loop_exit) {
60
+ }
61
+
62
+ for (auto & loop_exit : loop_exits) {
59
63
loop_exit.ptr_increment = 0 ;
60
64
if (loop_exit.is_incremented ) {
61
65
const auto & port = loop_exit.expr_port ;
66
+ const auto loop_ids = port->get_expr ()->get_loop_ids ();
62
67
const auto & layout = port->get_descriptor_ptr ()->get_layout ();
63
68
const auto & shape = port->get_descriptor_ptr ()->get_shape ();
64
69
const auto original_dim = layout.size () - 1 - loop_exit.dim_idx ;
@@ -69,34 +74,38 @@ void InitLoops::init_ptr_increments(const LinearIR::LoopManager::LoopInfoPtr& lo
69
74
loop_exit.ptr_increment = get_output_stride (dim, shape);
70
75
}
71
76
}
72
- };
73
-
74
- loop_info->update_entry_points (init_entry_port_increment);
75
- loop_info->update_exit_points (init_exit_port_increment);
77
+ }
78
+ loop_info->set_entry_points (loop_entries);
79
+ loop_info->set_exit_points (loop_exits);
76
80
}
77
81
78
82
void InitLoops::init_finalization_offsets (const LinearIR::LoopManager::LoopInfoPtr& loop_info) {
79
83
const auto work_amount = loop_info->get_work_amount ();
80
- auto init_port_finalization_offset = [&work_amount](LoopPort& loop_port) {
81
- loop_port.finalization_offset = -1 * loop_port.ptr_increment * work_amount;
82
- };
83
-
84
- loop_info->update_entry_points (init_port_finalization_offset);
85
- loop_info->update_exit_points (init_port_finalization_offset);
84
+ auto loop_entries = loop_info->get_entry_points ();
85
+ auto loop_exits = loop_info->get_exit_points ();
86
+ for (auto & loop_entry : loop_entries) {
87
+ loop_entry.finalization_offset = -1 * loop_entry.ptr_increment * work_amount;
88
+ }
89
+ for (auto & loop_exit : loop_exits) {
90
+ loop_exit.finalization_offset = -1 * loop_exit.ptr_increment * work_amount;
91
+ }
92
+ loop_info->set_entry_points (loop_entries);
93
+ loop_info->set_exit_points (loop_exits);
86
94
}
87
95
88
96
void InitLoops::init_element_type_sizes (const LinearIR::LoopManager::LoopInfoPtr& loop_info) {
89
- auto init_entry_port_data_size = [](LoopPort& loop_entry) {
97
+ auto loop_entries = loop_info->get_entry_points ();
98
+ auto loop_exits = loop_info->get_exit_points ();
99
+ for (auto & loop_entry : loop_entries) {
90
100
const auto & port = loop_entry.expr_port ;
91
101
loop_entry.data_size = static_cast <int64_t >(port->get_expr ()->get_node ()->get_input_element_type (port->get_index ()).size ());
92
- };
93
- auto init_exit_port_data_size = [](LoopPort & loop_exit) {
102
+ }
103
+ for ( auto & loop_exit : loop_exits ) {
94
104
const auto & port = loop_exit.expr_port ;
95
105
loop_exit.data_size = static_cast <int64_t >(port->get_expr ()->get_node ()->get_output_element_type (port->get_index ()).size ());
96
- };
97
-
98
- loop_info->update_entry_points (init_entry_port_data_size);
99
- loop_info->update_exit_points (init_exit_port_data_size);
106
+ }
107
+ loop_info->set_entry_points (loop_entries);
108
+ loop_info->set_exit_points (loop_exits);
100
109
}
101
110
102
111
bool InitLoops::run (LinearIR& linear_ir) {
0 commit comments