Skip to content

Commit cf6cb43

Browse files
authored
[GPU] Match concat input format (#25891)
### Details: - *Make the input format of concatenation same as output format* ### Tickets: - *147692*
1 parent f534e5a commit cf6cb43

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

src/plugins/intel_gpu/src/graph/graph_optimizer/reorder_inputs.cpp

+23-2
Original file line numberDiff line numberDiff line change
@@ -937,14 +937,35 @@ void reorder_inputs::run(program& p, layout_optimizer& lo, reorder_factory& rf)
937937
}
938938
};
939939

940+
const auto reorder_input_concat = [&p, &rf](typed_program_node<concatenation>& concat_node) {
941+
auto output_layout = concat_node.get_output_layout();
942+
// Iterate over all dependencies of the concat node
943+
for (size_t i = 0; i < concat_node.get_dependencies().size(); ++i) {
944+
auto dep = concat_node.get_dependency_with_port(i);
945+
const auto& input = dep.first;
946+
auto input_layout = input->get_output_layout();
947+
// Change input data type of concat node from input format to output format
948+
if (input_layout.format != output_layout.format) {
949+
auto new_layout = input_layout;
950+
new_layout.format = output_layout.format;
951+
auto new_input = rf.get_reorder(input->id(), dep.second, input_layout, new_layout);
952+
if (new_input.first) {
953+
p.add_intermediate(new_input.first, concat_node, i);
954+
concat_node.get_dependency_with_port(i).first->recalc_output_layout();
955+
}
956+
}
957+
}
958+
};
959+
940960
for (auto& prim : p.get_processing_order()) {
941-
program_helpers::do_for_types<detection_output, deconvolution, convolution, fully_connected, pooling>(
961+
program_helpers::do_for_types<detection_output, deconvolution, convolution, fully_connected, pooling, concatenation>(
942962
*prim,
943963
reorder_input_detection_output,
944964
reorder_input_and_weights_deconvolution,
945965
reorder_convolution,
946966
reorder_input_fully_connected,
947-
reorder_input_pooling);
967+
reorder_input_pooling,
968+
reorder_input_concat);
948969
}
949970

950971
for (auto n : p.get_processing_order()) {

0 commit comments

Comments
 (0)