@@ -512,6 +512,12 @@ bool crop_in_place_optimization::match(const program_node& node,
512
512
if (node.get_program ().is_body_program () && node.get_dependency (0 ).is_type <lstm_elt>()) {
513
513
return false ;
514
514
}
515
+
516
+ GPU_DEBUG_GET_INSTANCE (debug_config);
517
+ GPU_DEBUG_IF (debug_config->disable_runtime_buffer_fusing && node.is_dynamic ()) {
518
+ return false ;
519
+ }
520
+
515
521
// optimization is available for cropping across depth(features) or batch
516
522
// if output padding has defined padding across features already it wouldn't
517
523
// work because it expect to have zeros in the padded area.
@@ -553,18 +559,22 @@ bool crop_in_place_optimization::optimize(crop_node& node) {
553
559
node.get_primitive ()->axis ,
554
560
false );
555
561
} else if (can_crop_be_optimized_simple_data_format (crop_layout, input_layout)) {
556
- std::vector<layout> reshape_layouts;
557
- if (node.get_users ().front ()->is_type <reshape>() && node.get_users ().front ()->as <reshape>().is_runtime_propagatable_padding ()) {
558
- reshape_layouts.push_back (node.get_users ().front ()->get_output_layout ());
562
+ std::pair<const program_node*, layout> user_info;
563
+ if (node.get_users ().front ()->is_type <reshape>()) {
564
+ auto & reshape_node = node.get_users ().front ()->as <reshape>();
565
+ if (reshape_node.is_runtime_propagatable_padding ()) {
566
+ user_info.first = &reshape_node;
567
+ user_info.second = reshape_node.get_output_layout ();
568
+ }
559
569
}
560
570
update_in_place_crop_padding_simple_data_format (crop_layout,
561
571
input_layout,
562
- reshape_layouts ,
572
+ user_info ,
563
573
crop_params->input_offsets [0 ],
564
574
node.get_primitive ()->axis ,
565
575
false );
566
- if (reshape_layouts. size () > 0 ) {
567
- node.get_users ().front ()->set_output_layout (reshape_layouts[ 0 ] );
576
+ if (user_info. first ) {
577
+ node.get_users ().front ()->set_output_layout (user_info. second );
568
578
}
569
579
}
570
580
node.set_output_layout (crop_layout);
@@ -632,24 +642,51 @@ void crop_in_place_optimization::update_in_place_crop_padding_along_feature(cons
632
642
633
643
void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format (layout& crop_layout,
634
644
layout& input_layout,
635
- std::vector< layout>& user_layouts ,
645
+ std::pair< const program_node*, layout>& user_info ,
636
646
const tensor offsets,
637
647
size_t crop_axis,
638
648
bool is_runtime) {
639
- auto crop_axis_legacy = crop_axis;
640
- if (crop_axis_legacy >= 2 ) {
641
- auto spatial_axis = crop_axis_legacy - 2 ;
642
- // Default and minimum number of dimensions is 4
643
- auto spatial_size = std::max<size_t >(crop_layout.get_partial_shape ().size (), 4 ) - 2 ;
644
- crop_axis_legacy = spatial_size - spatial_axis - 1 + 2 ;
645
- }
649
+ auto convert_axis_to_legacy = [](size_t axis, size_t rank) {
650
+ auto axis_legacy = axis;
651
+ if (axis_legacy >= 2 ) {
652
+ auto spatial_axis = axis_legacy - 2 ;
653
+ // Default and minimum number of dimensions is 4
654
+ auto spatial_size = std::max<size_t >(rank, 4 ) - 2 ;
655
+ axis_legacy = spatial_size - spatial_axis - 1 + 2 ;
656
+ }
657
+
658
+ return axis_legacy;
659
+ };
660
+
661
+ auto crop_axis_legacy = convert_axis_to_legacy (crop_axis, crop_layout.get_partial_shape ().size ());
662
+
646
663
// If it's build-time and node is dynamic, only dynamic padding is set first
647
664
if ((crop_layout.is_dynamic () || input_layout.is_dynamic ()) && !is_runtime) {
648
665
auto dyn_pad_sizes = tensor (0 ).sizes ();
649
666
dyn_pad_sizes[crop_axis_legacy] = 1 ;
650
667
crop_layout.data_padding .set_dynamic_pad (tensor (dyn_pad_sizes));
651
- for (auto & user_layout : user_layouts) {
652
- user_layout.data_padding .set_dynamic_pad (tensor (dyn_pad_sizes));
668
+
669
+ if (user_info.first && user_info.first ->is_type <reshape>()) {
670
+ auto reshape_desc = user_info.first ->as <reshape>().get_primitive ();
671
+ auto reshape_mode = reshape_desc->mode ;
672
+ if (reshape_mode == reshape::reshape_mode::base) {
673
+ user_info.second .data_padding .set_dynamic_pad (tensor (dyn_pad_sizes));
674
+ } else if (reshape_mode == reshape::reshape_mode::unsqueeze || reshape_mode == reshape::reshape_mode::squeeze) {
675
+ auto reshape_ps = user_info.second .get_partial_shape ();
676
+ auto output_pattern = reshape_desc->output_pattern ;
677
+
678
+ auto reshape_axis = crop_axis;
679
+ for (size_t i = 0 ; i < output_pattern.size (); i++) {
680
+ if (output_pattern[i] <= static_cast <int64_t >(reshape_axis)) {
681
+ reshape_axis += reshape_mode == reshape::reshape_mode::unsqueeze ? 1 : -1 ;
682
+ }
683
+ }
684
+
685
+ auto dyn_pad_mask = tensor (0 ).sizes ();
686
+ auto reshape_axis_legacy = convert_axis_to_legacy (reshape_axis, reshape_ps.size ());
687
+ dyn_pad_mask[reshape_axis_legacy] = 1 ;
688
+ user_info.second .data_padding .set_dynamic_pad (tensor (dyn_pad_mask));
689
+ }
653
690
}
654
691
return ;
655
692
}
@@ -673,14 +710,40 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format
673
710
auto dyn_pad_sizes = lower_sizes;
674
711
dyn_pad_sizes[crop_axis_legacy] = 1 ;
675
712
crop_layout.data_padding = padding (lower_sizes, upper_sizes, 0 .f , tensor (dyn_pad_sizes));
676
- for (auto & user_layout : user_layouts) {
677
- auto reshape_rank = user_layout.get_partial_shape ().size ();
678
- auto reshape_last_dim = user_layout.get_partial_shape ().to_shape ()[reshape_rank - 1 ];
679
- if (lower_sizes[crop_axis_legacy])
680
- lower_sizes[crop_axis_legacy] /= reshape_last_dim;
681
- if (upper_sizes[crop_axis_legacy])
682
- upper_sizes[crop_axis_legacy] /= reshape_last_dim;
683
- user_layout.data_padding = padding (lower_sizes, upper_sizes, 0 .f , tensor (dyn_pad_sizes));
713
+ if (user_info.first ) {
714
+ auto reshape_desc = user_info.first ->as <reshape>().get_primitive ();
715
+ auto reshape_mode = reshape_desc->mode ;
716
+ if (reshape_mode == reshape::reshape_mode::base) {
717
+ auto reshape_rank = user_info.second .get_partial_shape ().size ();
718
+ auto reshape_last_dim = user_info.second .get_partial_shape ().to_shape ()[reshape_rank - 1 ];
719
+ if (lower_sizes[crop_axis_legacy])
720
+ lower_sizes[crop_axis_legacy] /= reshape_last_dim;
721
+ if (upper_sizes[crop_axis_legacy])
722
+ upper_sizes[crop_axis_legacy] /= reshape_last_dim;
723
+ user_info.second .data_padding = padding (lower_sizes, upper_sizes, 0 .f , tensor (dyn_pad_sizes));
724
+ } else {
725
+ auto reshape_ps = user_info.second .get_partial_shape ();
726
+ auto output_pattern = reshape_desc->output_pattern ;
727
+
728
+ auto reshape_axis = crop_axis;
729
+ for (size_t i = 0 ; i < output_pattern.size (); i++) {
730
+ if (output_pattern[i] <= static_cast <int64_t >(reshape_axis)) {
731
+ reshape_axis += reshape_mode == reshape::reshape_mode::unsqueeze ? 1 : -1 ;
732
+ }
733
+ }
734
+
735
+ const auto output_rank = std::max (reshape_ps.size (), static_cast <size_t >(4 ));
736
+ std::vector<int32_t > reshape_lower_sizes (output_rank, 0 );
737
+ std::vector<int32_t > reshape_upper_sizes (output_rank, 0 );
738
+ std::vector<int32_t > reshape_dyn_pad_mask (output_rank, 0 );
739
+
740
+ const auto reshape_axis_legacy = convert_axis_to_legacy (reshape_axis, reshape_ps.size ());
741
+ reshape_lower_sizes[reshape_axis_legacy] = lower_sizes[crop_axis_legacy];
742
+ reshape_upper_sizes[reshape_axis_legacy] = upper_sizes[crop_axis_legacy];
743
+ reshape_dyn_pad_mask[reshape_axis_legacy] = 1 ;
744
+
745
+ user_info.second .data_padding = padding (reshape_lower_sizes, reshape_upper_sizes, 0 .f , tensor (reshape_dyn_pad_mask));
746
+ }
684
747
}
685
748
} else {
686
749
crop_layout.data_padding = padding (lower_sizes, upper_sizes);
@@ -743,18 +806,23 @@ void prepare_buffer_fusing::run(program& p) {
743
806
node.get_primitive ()->axis ,
744
807
false );
745
808
} else if (crop_in_place_optimization::can_crop_be_optimized_simple_data_format (crop_layout, pred_layout)) {
809
+ std::pair<const program_node*, layout> user_info;
746
810
std::vector<layout> reshape_layouts;
747
- if (node.get_users ().front ()->is_type <reshape>() && node.get_users ().front ()->as <reshape>().is_runtime_propagatable_padding ()) {
748
- reshape_layouts.push_back (node.get_users ().front ()->get_output_layout ());
811
+ if (node.get_users ().front ()->is_type <reshape>()) {
812
+ auto & reshape_node = node.get_users ().front ()->as <reshape>();
813
+ if (reshape_node.is_runtime_propagatable_padding ()) {
814
+ user_info.first = &reshape_node;
815
+ user_info.second = reshape_node.get_output_layout ();
816
+ }
749
817
}
750
818
crop_in_place_optimization::update_in_place_crop_padding_simple_data_format (crop_layout,
751
819
pred_layout,
752
- reshape_layouts ,
820
+ user_info ,
753
821
crop_params->input_offsets [0 ],
754
822
node.get_primitive ()->axis ,
755
823
false );
756
- if (reshape_layouts. size () > 0 ) {
757
- node.get_users ().front ()->set_output_layout (reshape_layouts[ 0 ] );
824
+ if (user_info. first ) {
825
+ node.get_users ().front ()->set_output_layout (user_info. second );
758
826
}
759
827
}
760
828
node.set_output_layout (crop_layout);
0 commit comments