@@ -22,6 +22,12 @@ struct typed_program_node<loop> : public typed_program_node_base<loop> {
22
22
private:
23
23
using parent = typed_program_node_base<loop>;
24
24
25
+ primitive_id trip_count_id;
26
+ primitive_id initial_execution_id;
27
+ primitive_id current_iteration_id;
28
+ primitive_id execution_condition_id;
29
+ primitive_id num_iterations_id;
30
+
25
31
std::vector<loop::io_primitive_map>& input_primitive_maps;
26
32
std::vector<loop::io_primitive_map>& output_primitive_maps;
27
33
std::vector<loop::backedge_mapping>& back_edges;
@@ -31,21 +37,32 @@ struct typed_program_node<loop> : public typed_program_node_base<loop> {
31
37
parent (prim, prog),
32
38
input_primitive_maps (prim->input_primitive_maps),
33
39
output_primitive_maps (prim->output_primitive_maps),
34
- back_edges (prim->back_edges) {}
40
+ back_edges (prim->back_edges) {
41
+ set_primitive_ids (prim);
42
+ }
35
43
36
44
program::ptr get_body_program () const { return get_primitive ()->body_program ; }
37
45
38
- const primitive_id& get_trip_count_id () const { return get_primitive ()->trip_count_id ; }
39
- const primitive_id& get_initial_execution_id () const { return get_primitive ()->first_execution_condition_id ; }
40
- const primitive_id& get_current_iteration_id () const { return get_primitive ()->body_current_iteration_id ; }
41
- const primitive_id& get_execution_condition_id () const { return get_primitive ()->body_execution_condition_id ; }
42
- const primitive_id& get_num_iterations_id () const { return get_primitive ()->num_iteration_id ; }
46
+ const primitive_id& get_trip_count_id () const { return trip_count_id; }
47
+ const primitive_id& get_initial_execution_id () const { return initial_execution_id; }
48
+ const primitive_id& get_current_iteration_id () const { return current_iteration_id; }
49
+ const primitive_id& get_execution_condition_id () const { return execution_condition_id; }
50
+ const primitive_id& get_num_iterations_id () const { return num_iterations_id; }
51
+
43
52
const int32_t get_max_num_iteration () const { return get_primitive ()->max_num_iterations ; }
44
53
45
54
const std::vector<loop::io_primitive_map>& get_input_primitive_maps () const { return input_primitive_maps; }
46
55
const std::vector<loop::io_primitive_map>& get_output_primitive_maps () const { return output_primitive_maps; }
47
56
const std::vector<loop::backedge_mapping>& get_back_edges () const { return back_edges;}
48
57
58
+ void set_primitive_ids (std::shared_ptr<loop> prim) {
59
+ trip_count_id = prim->trip_count_id ;
60
+ initial_execution_id = prim->first_execution_condition_id ;
61
+ current_iteration_id = prim->body_current_iteration_id ;
62
+ execution_condition_id = prim->body_execution_condition_id ;
63
+ num_iterations_id = prim->num_iteration_id ;
64
+ }
65
+
49
66
void update_primitive_map (const primitive_id& prevID, const primitive_id& newID, bool external_id = true ) {
50
67
if (external_id) {
51
68
for (auto & pm : input_primitive_maps) {
@@ -78,6 +95,18 @@ struct typed_program_node<loop> : public typed_program_node_base<loop> {
78
95
}
79
96
}
80
97
}
98
+
99
+ // Update ids
100
+ if (get_trip_count_id () == prevID)
101
+ trip_count_id = newID;
102
+ if (get_initial_execution_id () == prevID)
103
+ initial_execution_id = newID;
104
+ if (get_current_iteration_id () == prevID)
105
+ current_iteration_id = newID;
106
+ if (get_execution_condition_id () == prevID)
107
+ execution_condition_id = newID;
108
+ if (get_num_iterations_id () == prevID)
109
+ num_iterations_id = newID;
81
110
}
82
111
83
112
// current_iteration is necessary to calculate output layout in dynamic shape
0 commit comments