Skip to content

Commit 5147721

Browse files
[GPU] Update loop inst ids instead of getting from prim
1 parent 3777479 commit 5147721

File tree

1 file changed

+35
-6
lines changed

1 file changed

+35
-6
lines changed

src/plugins/intel_gpu/src/graph/include/loop_inst.h

+35-6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ struct typed_program_node<loop> : public typed_program_node_base<loop> {
2222
private:
2323
using parent = typed_program_node_base<loop>;
2424

25+
primitive_id trip_count_id;
26+
primitive_id initial_execution_id;
27+
primitive_id current_iteration_id;
28+
primitive_id execution_condition_id;
29+
primitive_id num_iterations_id;
30+
2531
std::vector<loop::io_primitive_map>& input_primitive_maps;
2632
std::vector<loop::io_primitive_map>& output_primitive_maps;
2733
std::vector<loop::backedge_mapping>& back_edges;
@@ -31,21 +37,32 @@ struct typed_program_node<loop> : public typed_program_node_base<loop> {
3137
parent(prim, prog),
3238
input_primitive_maps(prim->input_primitive_maps),
3339
output_primitive_maps(prim->output_primitive_maps),
34-
back_edges(prim->back_edges) {}
40+
back_edges(prim->back_edges) {
41+
set_primitive_ids(prim);
42+
}
3543

3644
program::ptr get_body_program() const { return get_primitive()->body_program; }
3745

38-
const primitive_id& get_trip_count_id() const { return get_primitive()->trip_count_id; }
39-
const primitive_id& get_initial_execution_id() const { return get_primitive()->first_execution_condition_id; }
40-
const primitive_id& get_current_iteration_id() const { return get_primitive()->body_current_iteration_id; }
41-
const primitive_id& get_execution_condition_id() const { return get_primitive()->body_execution_condition_id; }
42-
const primitive_id& get_num_iterations_id() const { return get_primitive()->num_iteration_id; }
46+
const primitive_id& get_trip_count_id() const { return trip_count_id; }
47+
const primitive_id& get_initial_execution_id() const { return initial_execution_id; }
48+
const primitive_id& get_current_iteration_id() const { return current_iteration_id; }
49+
const primitive_id& get_execution_condition_id() const { return execution_condition_id; }
50+
const primitive_id& get_num_iterations_id() const { return num_iterations_id; }
51+
4352
const int32_t get_max_num_iteration() const { return get_primitive()->max_num_iterations; }
4453

4554
const std::vector<loop::io_primitive_map>& get_input_primitive_maps() const { return input_primitive_maps; }
4655
const std::vector<loop::io_primitive_map>& get_output_primitive_maps() const { return output_primitive_maps; }
4756
const std::vector<loop::backedge_mapping>& get_back_edges() const { return back_edges;}
4857

58+
void set_primitive_ids(std::shared_ptr<loop> prim) {
59+
trip_count_id = prim->trip_count_id;
60+
initial_execution_id = prim->first_execution_condition_id;
61+
current_iteration_id = prim->body_current_iteration_id;
62+
execution_condition_id = prim->body_execution_condition_id;
63+
num_iterations_id = prim->num_iteration_id;
64+
}
65+
4966
void update_primitive_map(const primitive_id& prevID, const primitive_id& newID, bool external_id = true) {
5067
if (external_id) {
5168
for (auto& pm : input_primitive_maps) {
@@ -78,6 +95,18 @@ struct typed_program_node<loop> : public typed_program_node_base<loop> {
7895
}
7996
}
8097
}
98+
99+
// Update ids
100+
if (get_trip_count_id() == prevID)
101+
trip_count_id = newID;
102+
if (get_initial_execution_id() == prevID)
103+
initial_execution_id = newID;
104+
if (get_current_iteration_id() == prevID)
105+
current_iteration_id = newID;
106+
if (get_execution_condition_id() == prevID)
107+
execution_condition_id = newID;
108+
if (get_num_iterations_id() == prevID)
109+
num_iterations_id = newID;
81110
}
82111

83112
// current_iteration is necessary to calculate output layout in dynamic shape

0 commit comments

Comments
 (0)