@@ -664,9 +664,32 @@ event::ptr primitive_inst::realloc_if_needed() {
             updated_layouts[0] = layout(current_buf_shape, updated_layouts[0].data_type, updated_layouts[0].format);
         }

+    int32_t tmp_prealloc_count = get_prealloc_iter_num();
+    GPU_DEBUG_IF(debug_config->mem_preallocation_params.is_initialized) {
+        // If the debug config is set, it takes precedence
+        tmp_prealloc_count = -1;
+    }
+
     // If we allocated too large memory, reclaim the memory.
     for (size_t i = 0; i < updated_layouts.size(); ++i) {
-        if (updated_layouts[i].get_buffer_size().count() * 10 < _max_output_layout_count[i]) {
+        bool reclaim = false;
+        size_t required_buffer_size = 0;
+        if (_node->is_type<kv_cache>() && i == 0) {
+            // Relax the reclaiming condition for kv cache
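+            // The expected size accounts for preallocation along the sequence axis, so a
+            // buffer that was grown ahead of time is not reclaimed right after allocation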
+            const auto& desc = _node->as<kv_cache>().get_primitive();
+            auto prealloc_shape = updated_layouts[i].get_shape();
+            const auto shape_rank = prealloc_shape.size();
+            auto seq_axis =
+                static_cast<int32_t>(desc->concat_axis >= 0 ? desc->concat_axis : shape_rank + desc->concat_axis);
+            prealloc_shape[seq_axis] += tmp_prealloc_count;
+            required_buffer_size = std::accumulate(prealloc_shape.begin(), prealloc_shape.end(), size_t(1), std::multiplies<size_t>());
+        } else {
+            required_buffer_size = (updated_layouts[i].get_buffer_size().count());
+        }
+        if (required_buffer_size * 10 < _max_output_layout_count[i]) {
+            reclaim = true;
+        }
+        if (reclaim) {
             GPU_DEBUG_TRACE_DETAIL << id() << ": Updated output[" << i << "] size " << updated_layouts[i].get_buffer_size().count()
                                    << " is much smaller than current memory size! " << _max_output_layout_count[i]
                                    << " Reset memory of output " << i << std::endl;
@@ -681,31 +704,51 @@ event::ptr primitive_inst::realloc_if_needed() {
         return ev;
     }

-    int32_t tmp_prealloc_count = get_prealloc_iter_num();
-    GPU_DEBUG_IF(debug_config->mem_preallocation_params.is_initialized) {
-        // If the debug config is set, it takes precedence
-        tmp_prealloc_count = -1;
-    }
-
     for (size_t i = 0; i < actual_layouts.size(); ++i) {
         bool can_reuse_buffer = (_outputs[i] && updated_layouts[i].get_buffer_size().count() <= _max_output_layout_count[i]);
-        std::pair<bool, ov::Shape> prealloc_info =
-            sp.predict_preallocation_shape(id(), updated_layouts[i], can_reuse_buffer, i, tmp_prealloc_count);
+        std::pair<bool, ov::Shape> prealloc_info;
+        if (_node->is_type<kv_cache>() && i == 0) {
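+            // The kv cache output grows along the sequence axis, so hint the predictor
+            // with that axis; reuse of the existing buffer is handled separately below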
+            const auto& desc = _node->as<kv_cache>().get_primitive();
+            auto shape_rank = updated_layouts[i].get_shape().size();
+            auto seq_axis =
+                static_cast<int32_t>(desc->concat_axis >= 0 ? desc->concat_axis : shape_rank + desc->concat_axis);
+            prealloc_info = sp.predict_preallocation_shape(id(), updated_layouts[i], false, i, tmp_prealloc_count, seq_axis);
+        } else {
+            prealloc_info = sp.predict_preallocation_shape(id(), updated_layouts[i], can_reuse_buffer, i, tmp_prealloc_count);
+        }
         if (prealloc_info.first && sp.can_preallocate(ov::shape_size(prealloc_info.second) * (dt_sizes_in_B[i]))) {
             auto new_layout = updated_layouts[i];
             new_layout.set_partial_shape(prealloc_info.second);
             updated_params.output_layouts[i] = new_layout;
         }
-
         if (updated_params.output_layouts[i].get_buffer_size().count() < updated_layouts[i].get_buffer_size().count()) {
             updated_params.output_layouts[i] = updated_layouts[i];
         }

         if (can_reuse_buffer) {
-            GPU_DEBUG_TRACE_DETAIL << id() << ": reuse previously allocated output buffer - "
+            GPU_DEBUG_TRACE_DETAIL << id() << ": reuse previously allocated output buffer[" << i << "] - "
                                    << actual_layouts[i].get_buffer_size().count() << "/" << _max_output_layout_count[i]
                                    << std::endl;
-            _outputs[i] = _network.get_engine().reinterpret_buffer(*_outputs[i], actual_layouts[i]);
+            if (_node->is_type<kv_cache>() && (i == 0)) {
+                // kv_cache memory has already been assigned, so there is no need to
+                // reinterpret the output buffer; only its padding must be updated
+                const auto& desc = _node->as<kv_cache>().get_primitive();
+                auto& present_layout = _impl_params->output_layouts[i];
+                const auto present_layout_rank = present_layout.get_partial_shape().size();
+                const auto sequence_axis =
+                    desc->concat_axis >= 0 ? desc->concat_axis : present_layout_rank + desc->concat_axis;
+
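+                // Recompute the remaining sequence-axis headroom of the allocated buffer
+                // and record it as dynamic padding, so new tokens can be appended in place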
+                const auto sequence_axis_legacy = kv_cache_inst::get_sequence_axis_legacy(sequence_axis, present_layout_rank);
+                auto max_pad = kv_cache_inst::get_max_pad(present_layout,
+                                                          _max_output_layout_count[i],
+                                                          sequence_axis_legacy,
+                                                          "present_layout");
+                kv_cache_inst::update_pad(present_layout, max_pad, sequence_axis_legacy);
+                GPU_DEBUG_TRACE_DETAIL << _impl_params->output_layouts[i].to_string() << std::endl;
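+                // Mark the shape as changed so the impl picks up the new padding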
+                set_shape_change();
+            } else {
+                _outputs[i] = _network.get_engine().reinterpret_buffer(*_outputs[i], actual_layouts[i]);
+            }
             // TODO: check need_reset_output_memory per output
             if (need_reset_output_memory() && !can_be_optimized()) {
                 GPU_DEBUG_TRACE_DETAIL << id() << ": Need reset output memory considering user" << std::endl;
@@ -740,7 +783,7 @@ event::ptr primitive_inst::realloc_if_needed() {

     // Set variable memory same as output memory
     if (_node->is_type<kv_cache>()) {
-        auto desc = _node->as<kv_cache>().get_primitive();
+        const auto& desc = _node->as<kv_cache>().get_primitive();
         auto& variable = get_network().get_variable(desc->variable_info.variable_id);
         auto present_layout = _impl_params->output_layouts[0];
         auto present_layout_rank = present_layout.get_partial_shape().size();
@@ -760,7 +803,7 @@ event::ptr primitive_inst::realloc_if_needed() {
         if (present_layout.data_padding.get_dynamic_pad_dims().sizes()[sequence_axis_legacy] == 1) {
             // Apply the variable's padding so it can be optimized in the next iteration
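+            // Size the pad from the actually allocated buffer capacity rather than the
+            // layout requested in this iteration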
             auto max_pad = kv_cache_inst::get_max_pad(present_layout,
-                                                      updated_params.output_layouts[0].get_buffer_size().count(),
+                                                      _max_output_layout_count[0],
                                                       sequence_axis_legacy,
                                                       "present_layout");
             if (max_pad > 0) {
@@ -783,7 +826,7 @@ event::ptr primitive_inst::realloc_if_needed() {
             GPU_DEBUG_TRACE_DETAIL << id() << ": Update variable " << variable.get_name()
                                    << "'s layout with allocated kv cache output: " << present_layout.to_short_string()
                                    << " (is_set = " << variable.is_set() << ") " << std::endl;
-            variable.set_layout(present_layout);
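+            // Hand the kv cache output buffer itself to the variable, not just its layout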
+            variable.set_memory(_outputs[0], present_layout);
         }
     } else {
         GPU_DEBUG_TRACE_DETAIL << id() << ": Update variable " << variable.get_name()
@@ -1036,8 +1079,10 @@ void primitive_inst::update_paddings() {
     auto reset_pad = [](kernel_impl_params& params, const program_node* node) {
         params.output_layouts[0].data_padding = node->get_output_layout(0).data_padding;
     };
-    if (_node->is_type<read_value>()) {
-        auto& variable = get_network().get_variable(_node->as<read_value>().get_primitive()->variable_id);
+    if (_node->is_type<read_value>() || _node->is_type<kv_cache>()) {
+        auto variable_id = _node->is_type<read_value>() ? (_node->as<read_value>().get_primitive()->variable_id)
+                                                        : (_node->as<kv_cache>().get_primitive()->variable_info.variable_id);
+        auto& variable = get_network().get_variable(variable_id);
         // Reset paddings for read_value and users with dynamic pad when variable is reset
         // to avoid wrong pad used for some nodes due to pad propagation logic (which uses previous iter pad values)
         if (!variable.is_set()) {
@@ -1054,6 +1099,7 @@ void primitive_inst::update_paddings() {
         }
         return;
     }
+
     if (_node->is_type<gather>() && _impl_params->output_layouts[0].data_padding.get_dynamic_pad_dims() != tensor(0)) {
         if (can_be_optimized())
             _impl_params->output_layouts[0] = _impl_params->input_layouts[0];
@@ -1141,7 +1187,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() {
     if (_impl_params->get_input_layout(0).count() == 0) {
         return;
     }
-    auto desc = _node->as<kv_cache>().get_primitive();
+    const auto& desc = _node->as<kv_cache>().get_primitive();
     auto& past_layout = _impl_params->input_layouts[0];
     auto& present_layout = _impl_params->output_layouts[0];
     const auto& sequence_axis = desc->concat_axis;