@@ -937,14 +937,35 @@ void reorder_inputs::run(program& p, layout_optimizer& lo, reorder_factory& rf)
937
937
}
938
938
};
939
939
940
+ const auto reorder_input_concat = [&p, &rf](typed_program_node<concatenation>& concat_node) {
941
+ auto output_layout = concat_node.get_output_layout ();
942
+ // Iterate over all dependencies of the concat node
943
+ for (size_t i = 0 ; i < concat_node.get_dependencies ().size (); ++i) {
944
+ auto dep = concat_node.get_dependency_with_port (i);
945
+ const auto & input = dep.first ;
946
+ auto input_layout = input->get_output_layout ();
947
+ // Change input data type of concat node from input format to output format
948
+ if (input_layout.format != output_layout.format ) {
949
+ auto new_layout = input_layout;
950
+ new_layout.format = output_layout.format ;
951
+ auto new_input = rf.get_reorder (input->id (), dep.second , input_layout, new_layout);
952
+ if (new_input.first ) {
953
+ p.add_intermediate (new_input.first , concat_node, i);
954
+ concat_node.get_dependency_with_port (i).first ->recalc_output_layout ();
955
+ }
956
+ }
957
+ }
958
+ };
959
+
940
960
for (auto & prim : p.get_processing_order ()) {
941
- program_helpers::do_for_types<detection_output, deconvolution, convolution, fully_connected, pooling>(
961
+ program_helpers::do_for_types<detection_output, deconvolution, convolution, fully_connected, pooling, concatenation >(
942
962
*prim,
943
963
reorder_input_detection_output,
944
964
reorder_input_and_weights_deconvolution,
945
965
reorder_convolution,
946
966
reorder_input_fully_connected,
947
- reorder_input_pooling);
967
+ reorder_input_pooling,
968
+ reorder_input_concat);
948
969
}
949
970
950
971
for (auto n : p.get_processing_order ()) {
0 commit comments